29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/ErrorHandling.h"
36#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
37#include "mlir/Conversion/Passes.h.inc"
53 IntegerType i32 = rewriter.getI32Type();
55 auto valTy = cast<IntegerType>(val.
getType());
58 return valTy.getWidth() > 32
59 ?
Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
60 :
Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
65 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
71 IntegerType i64 = rewriter.getI64Type();
73 auto valTy = cast<IntegerType>(val.
getType());
76 return valTy.getWidth() > 64
77 ?
Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
78 :
Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
83 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
90 IntegerType i32 = rewriter.getI32Type();
92 for (
auto [i, increment, stride] : llvm::enumerate(
indices, strides)) {
95 ShapedType::isDynamic(stride)
97 memRefDescriptor.
stride(rewriter, loc, i))
98 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
99 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
111 MemRefType memrefType,
115 if (chipset >=
kGfx1250 && !boundsCheck) {
116 constexpr int64_t first45bits = (1ll << 45) - 1;
119 if (memrefType.hasStaticShape() &&
120 !llvm::any_of(strides, ShapedType::isDynamic)) {
121 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
123 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
124 size = std::max(
shape[i] * strides[i], size);
125 size = size * elementByteWidth;
129 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
130 Value size = memrefDescriptor.
size(rewriter, loc, i);
131 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
132 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
134 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
139 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
145 Value cacheSwizzleStride =
nullptr,
146 unsigned addressSpace = 8) {
150 Type i16 = rewriter.getI16Type();
153 Value cacheStrideZext =
154 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
155 Value swizzleBit = LLVM::ConstantOp::create(
156 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
157 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
160 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
161 rewriter.getI16IntegerAttr(0));
190 flags |= (7 << 12) | (4 << 15);
193 uint32_t oob = boundsCheck ? 3 : 2;
194 flags |= (oob << 28);
199 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
200 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
201 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
206struct FatRawBufferCastLowering
208 FatRawBufferCastLowering(
const LLVMTypeConverter &converter, Chipset chipset)
209 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
215 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter)
const override {
217 Location loc = op.getLoc();
218 Value memRef = adaptor.getSource();
219 Value unconvertedMemref = op.getSource();
220 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
221 MemRefDescriptor descriptor(memRef);
223 DataLayout dataLayout = DataLayout::closest(op);
224 int64_t elementByteWidth =
227 int64_t unusedOffset = 0;
228 SmallVector<int64_t, 5> strideVals;
229 if (
failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
230 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
232 Value numRecords = adaptor.getValidBytes();
235 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
236 elementByteWidth, chipset, adaptor.getBoundsCheck());
239 adaptor.getResetOffset()
240 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
242 : descriptor.alignedPtr(rewriter, loc);
244 Value offset = adaptor.getResetOffset()
245 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
246 rewriter.getIndexAttr(0))
247 : descriptor.offset(rewriter, loc);
249 bool hasSizes = memrefType.getRank() > 0;
252 Value sizes = hasSizes
253 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
257 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
262 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
263 chipset, adaptor.getCacheSwizzleStride(), 7);
265 Value
result = MemRefDescriptor::poison(
267 getTypeConverter()->convertType(op.getResult().getType()));
269 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr, pos);
270 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr,
272 result = LLVM::InsertValueOp::create(rewriter, loc,
result, offset,
275 result = LLVM::InsertValueOp::create(rewriter, loc,
result, sizes,
277 result = LLVM::InsertValueOp::create(rewriter, loc,
result, strides,
280 rewriter.replaceOp(op,
result);
286template <
typename GpuOp,
typename Intrinsic>
288 RawBufferOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
289 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
292 static constexpr uint32_t maxVectorOpWidth = 128;
295 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
296 ConversionPatternRewriter &rewriter)
const override {
297 Location loc = gpuOp.getLoc();
298 Value memref = adaptor.getMemref();
299 Value unconvertedMemref = gpuOp.getMemref();
300 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
302 if (chipset.majorVersion < 9)
303 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
305 Value storeData = adaptor.getODSOperands(0)[0];
306 if (storeData == memref)
310 wantedDataType = storeData.
getType();
312 wantedDataType = gpuOp.getODSResults(0)[0].getType();
314 Value atomicCmpData = Value();
317 Value maybeCmpData = adaptor.getODSOperands(1)[0];
318 if (maybeCmpData != memref)
319 atomicCmpData = maybeCmpData;
322 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
324 Type i32 = rewriter.getI32Type();
327 DataLayout dataLayout = DataLayout::closest(gpuOp);
328 int64_t elementByteWidth =
337 Type llvmBufferValType = llvmWantedDataType;
339 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
340 llvmBufferValType = this->getTypeConverter()->convertType(
341 rewriter.getIntegerType(floatType.getWidth()));
343 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
344 uint32_t vecLen = dataVector.getNumElements();
347 uint32_t totalBits = elemBits * vecLen;
349 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
350 if (totalBits > maxVectorOpWidth)
351 return gpuOp.emitOpError(
352 "Total width of loads or stores must be no more than " +
353 Twine(maxVectorOpWidth) +
" bits, but we call for " +
355 " bits. This should've been caught in validation");
356 if (!usePackedFp16 && elemBits < 32) {
357 if (totalBits > 32) {
358 if (totalBits % 32 != 0)
359 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
360 "doesn't fit into words. Can't happen\n");
361 llvmBufferValType = this->typeConverter->convertType(
362 VectorType::get(totalBits / 32, i32));
364 llvmBufferValType = this->typeConverter->convertType(
365 rewriter.getIntegerType(totalBits));
369 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
372 if (vecType.getNumElements() == 1)
373 llvmBufferValType = vecType.getElementType();
376 SmallVector<Value, 6> args;
378 if (llvmBufferValType != llvmWantedDataType) {
379 Value castForStore = LLVM::BitcastOp::create(
380 rewriter, loc, llvmBufferValType, storeData);
381 args.push_back(castForStore);
383 args.push_back(storeData);
388 if (llvmBufferValType != llvmWantedDataType) {
389 Value castForCmp = LLVM::BitcastOp::create(
390 rewriter, loc, llvmBufferValType, atomicCmpData);
391 args.push_back(castForCmp);
393 args.push_back(atomicCmpData);
399 SmallVector<int64_t, 5> strides;
400 if (
failed(memrefType.getStridesAndOffset(strides, offset)))
401 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
403 MemRefDescriptor memrefDescriptor(memref);
405 Value ptr = memrefDescriptor.bufferPtr(
406 rewriter, loc, *this->getTypeConverter(), memrefType);
408 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
409 elementByteWidth, chipset, adaptor.getBoundsCheck());
411 adaptor.getBoundsCheck(), chipset);
412 args.push_back(resource);
416 adaptor.getIndices(), strides);
417 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
418 indexOffset && *indexOffset > 0) {
420 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
424 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
425 args.push_back(voffset);
428 Value sgprOffset = adaptor.getSgprOffset();
431 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
432 args.push_back(sgprOffset);
439 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
441 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
442 ArrayRef<NamedAttribute>());
445 if (llvmBufferValType != llvmWantedDataType) {
446 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
451 rewriter.eraseOp(gpuOp);
469 unsigned expcnt,
unsigned lgkmcnt) {
471 vmcnt = std::min(15u, vmcnt);
472 expcnt = std::min(7u, expcnt);
473 lgkmcnt = std::min(15u, lgkmcnt);
474 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
477 vmcnt = std::min(63u, vmcnt);
478 expcnt = std::min(7u, expcnt);
479 lgkmcnt = std::min(15u, lgkmcnt);
480 unsigned lowBits = vmcnt & 0xF;
481 unsigned highBits = (vmcnt >> 4) << 14;
482 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
483 return lowBits | highBits | otherCnts;
486 vmcnt = std::min(63u, vmcnt);
487 expcnt = std::min(7u, expcnt);
488 lgkmcnt = std::min(63u, lgkmcnt);
489 unsigned lowBits = vmcnt & 0xF;
490 unsigned highBits = (vmcnt >> 4) << 14;
491 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
492 return lowBits | highBits | otherCnts;
495 vmcnt = std::min(63u, vmcnt);
496 expcnt = std::min(7u, expcnt);
497 lgkmcnt = std::min(63u, lgkmcnt);
498 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
503struct MemoryCounterWaitOpLowering
513 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter)
const override {
515 if (
chipset.majorVersion >= 12) {
517 if (std::optional<int> ds = adaptor.getDs())
518 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
520 if (std::optional<int>
load = adaptor.getLoad())
521 ROCDL::WaitLoadcntOp::create(rewriter, loc, *
load);
523 if (std::optional<int> store = adaptor.getStore())
524 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
526 if (std::optional<int> exp = adaptor.getExp())
527 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
529 if (std::optional<int>
tensor = adaptor.getTensor())
530 ROCDL::WaitTensorcntOp::create(rewriter, loc, *
tensor);
532 rewriter.eraseOp(op);
536 if (adaptor.getTensor())
537 return op.emitOpError(
"unsupported chipset");
539 auto getVal = [](
Attribute attr) ->
unsigned {
541 return cast<IntegerAttr>(attr).getInt();
546 unsigned ds = getVal(adaptor.getDsAttr());
547 unsigned exp = getVal(adaptor.getExpAttr());
549 unsigned vmcnt = 1024;
550 Attribute
load = adaptor.getLoadAttr();
551 Attribute store = adaptor.getStoreAttr();
553 vmcnt = getVal(
load) + getVal(store);
555 vmcnt = getVal(
load);
557 vmcnt = getVal(store);
560 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
562 return op.emitOpError(
"unsupported chipset");
564 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
570 LDSBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
571 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
576 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
577 ConversionPatternRewriter &rewriter)
const override {
578 Location loc = op.getLoc();
581 bool requiresInlineAsm = chipset <
kGfx90a;
584 rewriter.getAttr<LLVM::MMRATagAttr>(
"amdgpu-synchronize-as",
"local");
593 StringRef scope =
"workgroup";
595 auto relFence = LLVM::FenceOp::create(rewriter, loc,
596 LLVM::AtomicOrdering::release, scope);
597 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
598 if (requiresInlineAsm) {
599 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
600 LLVM::AsmDialect::AD_ATT);
601 const char *asmStr =
";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
602 const char *constraints =
"";
603 LLVM::InlineAsmOp::create(
606 asmStr, constraints,
true,
607 false, LLVM::TailCallKind::None,
610 }
else if (chipset.majorVersion < 12) {
611 ROCDL::SBarrierOp::create(rewriter, loc);
613 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
614 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
617 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
618 LLVM::AtomicOrdering::acquire, scope);
619 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
620 rewriter.replaceOp(op, acqFence);
626 SchedBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
627 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
632 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
633 ConversionPatternRewriter &rewriter)
const override {
634 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
635 (uint32_t)op.getOpts());
659 bool allowBf16 =
true) {
661 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
662 if (vectorType.getElementType().isBF16() && !allowBf16)
663 return LLVM::BitcastOp::create(
664 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
665 if (vectorType.getElementType().isInteger(8) &&
666 vectorType.getNumElements() <= 8)
667 return LLVM::BitcastOp::create(
669 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
670 if (isa<IntegerType>(vectorType.getElementType()) &&
671 vectorType.getElementTypeBitWidth() <= 8) {
672 int64_t numWords = llvm::divideCeil(
673 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
675 return LLVM::BitcastOp::create(
676 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
686 bool allowBf16 =
true) {
688 auto vectorType = cast<VectorType>(inputType);
690 if (vectorType.getElementType().isBF16() && !allowBf16)
691 return LLVM::BitcastOp::create(
692 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
694 if (isa<IntegerType>(vectorType.getElementType()) &&
695 vectorType.getElementTypeBitWidth() <= 8) {
696 int64_t numWords = llvm::divideCeil(
697 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
698 return LLVM::BitcastOp::create(
699 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
717 .Case([&](IntegerType) {
719 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
722 .Case([&](VectorType vectorType) {
724 int64_t numElements = vectorType.getNumElements();
725 assert((numElements == 4 || numElements == 8) &&
726 "scale operand must be a vector of length 4 or 8");
727 IntegerType outputType =
728 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
729 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
731 .DefaultUnreachable(
"unexpected input type for scale operand");
737 .Case([](Float8E8M0FNUType) {
return 0; })
738 .Case([](Float8E4M3FNType) {
return 2; })
739 .Default(std::nullopt);
744static std::optional<StringRef>
746 if (m == 16 && n == 16 && k == 128)
748 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
749 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
751 if (m == 32 && n == 16 && k == 128)
752 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
753 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
767 ConversionPatternRewriter &rewriter,
Location loc,
772 auto vectorType = dyn_cast<VectorType>(inputType);
774 operands.push_back(llvmInput);
777 Type elemType = vectorType.getElementType();
779 operands.push_back(llvmInput);
786 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
787 bool isInputInteger = mlirInputType.getElementType().isInteger();
788 if (isInputInteger) {
790 bool localIsUnsigned = isUnsigned;
792 localIsUnsigned =
true;
794 localIsUnsigned =
false;
797 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
802 Type i32 = rewriter.getI32Type();
803 Type intrinsicInType = numBits <= 32
804 ? (
Type)rewriter.getIntegerType(numBits)
805 : (
Type)VectorType::get(numBits / 32, i32);
806 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
807 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
808 loc, llvmIntrinsicInType, llvmInput);
813 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
814 operands.push_back(castInput);
827 Value output, int32_t subwordOffset,
831 auto vectorType = dyn_cast<VectorType>(inputType);
832 Type elemType = vectorType.getElementType();
833 operands.push_back(output);
845 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
846 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
852 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
853 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
861 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
862 b = mfma.getBlocks();
867 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
868 if (m == 32 && n == 32 && k == 4 &&
b == 1)
869 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
870 if (m == 16 && n == 16 && k == 8 &&
b == 1)
871 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
873 if (m == 32 && n == 32 && k == 1 &&
b == 2)
874 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
875 if (m == 16 && n == 16 && k == 1 &&
b == 4)
876 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
877 if (m == 4 && n == 4 && k == 1 &&
b == 16)
878 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
879 if (m == 32 && n == 32 && k == 2 &&
b == 1)
880 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
881 if (m == 16 && n == 16 && k == 4 &&
b == 1)
882 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
887 if (m == 32 && n == 32 && k == 16 &&
b == 1)
888 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
889 if (m == 16 && n == 16 && k == 32 &&
b == 1)
890 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
892 if (m == 32 && n == 32 && k == 4 &&
b == 2)
893 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
894 if (m == 16 && n == 16 && k == 4 &&
b == 4)
895 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
896 if (m == 4 && n == 4 && k == 4 &&
b == 16)
897 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
898 if (m == 32 && n == 32 && k == 8 &&
b == 1)
899 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
900 if (m == 16 && n == 16 && k == 16 &&
b == 1)
901 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
906 if (m == 32 && n == 32 && k == 16 &&
b == 1)
907 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
908 if (m == 16 && n == 16 && k == 32 &&
b == 1)
909 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
912 if (m == 32 && n == 32 && k == 4 &&
b == 2)
913 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
914 if (m == 16 && n == 16 && k == 4 &&
b == 4)
915 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
916 if (m == 4 && n == 4 && k == 4 &&
b == 16)
917 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
918 if (m == 32 && n == 32 && k == 8 &&
b == 1)
919 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
920 if (m == 16 && n == 16 && k == 16 &&
b == 1)
921 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
923 if (m == 32 && n == 32 && k == 2 &&
b == 2)
924 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
925 if (m == 16 && n == 16 && k == 2 &&
b == 4)
926 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
927 if (m == 4 && n == 4 && k == 2 &&
b == 16)
928 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
929 if (m == 32 && n == 32 && k == 4 &&
b == 1)
930 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
931 if (m == 16 && n == 16 && k == 8 &&
b == 1)
932 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
937 if (m == 32 && n == 32 && k == 32 &&
b == 1)
938 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
939 if (m == 16 && n == 16 && k == 64 &&
b == 1)
940 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
942 if (m == 32 && n == 32 && k == 4 &&
b == 2)
943 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
944 if (m == 16 && n == 16 && k == 4 &&
b == 4)
945 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
946 if (m == 4 && n == 4 && k == 4 &&
b == 16)
947 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
948 if (m == 32 && n == 32 && k == 8 &&
b == 1)
949 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
950 if (m == 16 && n == 16 && k == 16 &&
b == 1)
951 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
952 if (m == 32 && n == 32 && k == 16 &&
b == 1 && chipset >=
kGfx942)
953 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
954 if (m == 16 && n == 16 && k == 32 &&
b == 1 && chipset >=
kGfx942)
955 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
959 if (m == 16 && n == 16 && k == 4 &&
b == 1)
960 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
961 if (m == 4 && n == 4 && k == 4 &&
b == 4)
962 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
969 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
970 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
972 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
974 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
976 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
978 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
980 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
986 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
987 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
989 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
991 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
993 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
995 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
997 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
1001 return std::nullopt;
1006 .Case([](Float8E4M3FNType) {
return 0u; })
1007 .Case([](Float8E5M2Type) {
return 1u; })
1008 .Case([](Float6E2M3FNType) {
return 2u; })
1009 .Case([](Float6E3M2FNType) {
return 3u; })
1010 .Case([](Float4E2M1FNType) {
return 4u; })
1011 .Default(std::nullopt);
1021static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1023 uint32_t n, uint32_t k, uint32_t
b,
Chipset chipset) {
1029 return std::nullopt;
1030 if (!isa<Float32Type>(destType))
1031 return std::nullopt;
1035 if (!aTypeCode || !bTypeCode)
1036 return std::nullopt;
1038 if (m == 32 && n == 32 && k == 64 &&
b == 1)
1039 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1040 *aTypeCode, *bTypeCode};
1041 if (m == 16 && n == 16 && k == 128 &&
b == 1)
1043 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1046 return std::nullopt;
1049static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1052 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1053 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1054 mfma.getBlocks(), chipset);
1057static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1060 smfma.getSourceB().getType(),
1061 smfma.getDestC().getType(), smfma.getM(),
1062 smfma.getN(), smfma.getK(), 1u, chipset);
1067static std::optional<StringRef>
1069 Type elemDestType, uint32_t k,
bool isRDNA3) {
1070 using fp8 = Float8E4M3FNType;
1071 using bf8 = Float8E5M2Type;
1076 if (elemSourceType.
isF16() && elemDestType.
isF32())
1077 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1078 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1079 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1080 if (elemSourceType.
isF16() && elemDestType.
isF16())
1081 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1083 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1085 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1090 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1091 return std::nullopt;
1095 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1096 elemDestType.
isF32())
1097 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1098 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1099 elemDestType.
isF32())
1100 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1101 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1102 elemDestType.
isF32())
1103 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1104 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1105 elemDestType.
isF32())
1106 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1108 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1110 return std::nullopt;
1114 if (k == 32 && !isRDNA3) {
1116 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1119 return std::nullopt;
1125 Type elemBSourceType,
1128 using fp8 = Float8E4M3FNType;
1129 using bf8 = Float8E5M2Type;
1132 if (elemSourceType.
isF32() && elemDestType.
isF32())
1133 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1135 return std::nullopt;
1139 if (elemSourceType.
isF16() && elemDestType.
isF32())
1140 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1141 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1142 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1143 if (elemSourceType.
isF16() && elemDestType.
isF16())
1144 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1146 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1148 return std::nullopt;
1152 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1153 if (elemDestType.
isF32())
1154 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1155 if (elemDestType.
isF16())
1156 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1158 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1159 if (elemDestType.
isF32())
1160 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1161 if (elemDestType.
isF16())
1162 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1164 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1165 if (elemDestType.
isF32())
1166 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1167 if (elemDestType.
isF16())
1168 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1170 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1171 if (elemDestType.
isF32())
1172 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1173 if (elemDestType.
isF16())
1174 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1177 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1179 return std::nullopt;
1183 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1184 if (elemDestType.
isF32())
1185 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1186 if (elemDestType.
isF16())
1187 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1189 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1190 if (elemDestType.
isF32())
1191 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1192 if (elemDestType.
isF16())
1193 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1195 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1196 if (elemDestType.
isF32())
1197 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1198 if (elemDestType.
isF16())
1199 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1201 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1202 if (elemDestType.
isF32())
1203 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1204 if (elemDestType.
isF16())
1205 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1208 return std::nullopt;
1211 return std::nullopt;
1219 bool isGfx950 = chipset >=
kGfx950;
1223 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1228 if (m == 16 && n == 16 && k == 32) {
1230 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1232 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1235 if (m == 16 && n == 16 && k == 64) {
1238 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1240 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1244 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1245 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1246 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1247 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1248 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1249 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1250 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1251 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1252 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1255 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1258 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1259 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1260 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1261 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1262 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1263 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1264 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1265 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1266 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1269 if (m == 32 && n == 32 && k == 16) {
1271 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1273 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1276 if (m == 32 && n == 32 && k == 32) {
1279 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1281 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1285 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1286 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1287 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1288 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1289 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1290 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1291 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1292 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1293 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1296 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1299 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1300 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1301 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1302 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1303 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1304 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1305 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1306 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1307 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1310 return std::nullopt;
1318 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1319 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1320 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1321 Type elemSourceType = sourceVectorType.getElementType();
1322 Type elemBSourceType = sourceBVectorType.getElementType();
1323 Type elemDestType = destVectorType.getElementType();
1325 const uint32_t k = wmma.getK();
1330 if (isRDNA3 || isRDNA4)
1339 return std::nullopt;
1344 MFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1345 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1350 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1351 ConversionPatternRewriter &rewriter)
const override {
1352 Location loc = op.getLoc();
1353 Type outType = typeConverter->convertType(op.getDestD().getType());
1354 Type intrinsicOutType = outType;
1355 if (
auto outVecType = dyn_cast<VectorType>(outType))
1356 if (outVecType.getElementType().isBF16())
1357 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1359 if (chipset.majorVersion != 9 || chipset <
kGfx908)
1360 return op->emitOpError(
"MFMA only supported on gfx908+");
1361 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
1362 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1364 return op.emitOpError(
"negation unsupported on older than gfx942");
1366 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1369 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1371 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1372 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
1375 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1377 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1378 return op.emitOpError(
1379 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1380 "be scaled as those fields are used for type information");
1383 StringRef intrinsicName =
1384 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1387 bool allowBf16 = [&]() {
1392 return intrinsicName.contains(
"16x16x32.bf16") ||
1393 intrinsicName.contains(
"32x32x16.bf16");
1395 OperationState loweredOp(loc, intrinsicName);
1396 loweredOp.addTypes(intrinsicOutType);
1398 rewriter, loc, adaptor.getSourceA(), allowBf16),
1400 rewriter, loc, adaptor.getSourceB(), allowBf16),
1401 adaptor.getDestC()});
1404 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1414 Value lowered = rewriter.create(loweredOp)->getResult(0);
1415 if (outType != intrinsicOutType)
1416 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1417 rewriter.replaceOp(op, lowered);
1423 ScaledMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1424 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1429 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1430 ConversionPatternRewriter &rewriter)
const override {
1431 Location loc = op.getLoc();
1432 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1434 if (chipset.majorVersion != 9 || chipset <
kGfx950)
1435 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1436 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1438 if (!maybeScaledIntrinsic.has_value())
1439 return op.emitOpError(
1440 "no intrinsic matching scaled MFMA size on given chipset");
1442 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1443 OperationState loweredOp(loc, intrinsicName);
1444 loweredOp.addTypes(intrinsicOutType);
1445 loweredOp.addOperands(
1448 adaptor.getDestC()});
1453 loweredOp.addOperands(
1462 Value lowered = rewriter.create(loweredOp)->getResult(0);
1463 rewriter.replaceOp(op, lowered);
1469 SparseMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1470 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1475 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1476 ConversionPatternRewriter &rewriter)
const override {
1477 Location loc = op.getLoc();
1479 typeConverter->convertType<VectorType>(op.getDestC().
getType());
1481 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1484 if (chipset.majorVersion != 9 || chipset <
kGfx942)
1485 return op->emitOpError(
"sparse MFMA (smfmac) only supported on gfx942+");
1486 bool isGfx950 = chipset >=
kGfx950;
1489 adaptor.getSourceA(), isGfx950);
1491 adaptor.getSourceB(), isGfx950);
1492 Value c = adaptor.getDestC();
1495 if (!maybeIntrinsic.has_value())
1496 return op.emitOpError(
1497 "no intrinsic matching sparse MFMA on the given chipset");
1500 Value sparseIdx = LLVM::BitcastOp::create(
1501 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1503 OperationState loweredOp(loc, maybeIntrinsic.value());
1504 loweredOp.addTypes(outType);
1505 loweredOp.addOperands({a,
b, c, sparseIdx,
1508 Value lowered = rewriter.create(loweredOp)->getResult(0);
1509 rewriter.replaceOp(op, lowered);
1515 WMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1516 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1521 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1522 ConversionPatternRewriter &rewriter)
const override {
1523 Location loc = op.getLoc();
1525 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1527 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1529 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1530 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1532 bool isGFX1250 = chipset >=
kGfx1250;
1537 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1538 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1539 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1540 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1541 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1542 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1543 bool castOutToI16 = outType.getElementType().
isBF16() && !isGFX1250;
1544 VectorType rawOutType = outType;
1546 rawOutType = outType.clone(rewriter.getI16Type());
1547 Value a = adaptor.getSourceA();
1549 a = LLVM::BitcastOp::create(rewriter, loc,
1550 aType.clone(rewriter.getI16Type()), a);
1551 Value
b = adaptor.getSourceB();
1553 b = LLVM::BitcastOp::create(rewriter, loc,
1554 bType.clone(rewriter.getI16Type()),
b);
1555 Value destC = adaptor.getDestC();
1557 destC = LLVM::BitcastOp::create(
1558 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1562 if (!maybeIntrinsic.has_value())
1563 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
1565 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1566 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1568 SmallVector<Value, 4> operands;
1569 SmallVector<NamedAttribute, 4> attrs;
1571 op.getSourceA(), operands, attrs,
"signA");
1573 op.getSourceB(), operands, attrs,
"signB");
1575 op.getSubwordOffset(), op.getClamp(), operands,
1578 OperationState loweredOp(loc, *maybeIntrinsic);
1579 loweredOp.addTypes(rawOutType);
1580 loweredOp.addOperands(operands);
1581 loweredOp.addAttributes(attrs);
1582 Operation *lowered = rewriter.create(loweredOp);
1584 Operation *maybeCastBack = lowered;
1585 if (rawOutType != outType)
1586 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1588 rewriter.replaceOp(op, maybeCastBack->
getResults());
1595 ScaledWMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1596 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
1601 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
1602 ConversionPatternRewriter &rewriter)
const override {
1603 Location loc = op.getLoc();
1605 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1607 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1610 return op->emitOpError(
"WMMA scale only supported on gfx1250+");
1612 int64_t m = op.getM();
1613 int64_t n = op.getN();
1614 int64_t k = op.getK();
1622 if (!aFmtCode || !bFmtCode)
1623 return op.emitOpError(
"unsupported element types for scaled_wmma");
1626 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
1627 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
1629 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
1630 return op.emitOpError(
"scaleA and scaleB must have equal vector length");
1633 Type scaleAElemType = scaleAVecType.getElementType();
1634 Type scaleBElemType = scaleBVecType.getElementType();
1639 if (!scaleAFmt || !scaleBFmt)
1640 return op.emitOpError(
"unsupported scale element types");
1643 bool isScale16 = (scaleAVecType.getNumElements() == 8);
1644 std::optional<StringRef> intrinsicName =
1647 return op.emitOpError(
"unsupported scaled_wmma dimensions: ")
1648 << m <<
"x" << n <<
"x" << k;
1650 SmallVector<NamedAttribute, 8> attrs;
1653 bool is32x16 = (m == 32 && n == 16 && k == 128);
1655 attrs.emplace_back(
"fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
1656 attrs.emplace_back(
"fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
1660 attrs.emplace_back(
"modC", rewriter.getI16IntegerAttr(0));
1665 "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
1666 attrs.emplace_back(
"fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
1668 "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
1669 attrs.emplace_back(
"fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
1672 attrs.emplace_back(
"reuseA", rewriter.getBoolAttr(
false));
1673 attrs.emplace_back(
"reuseB", rewriter.getBoolAttr(
false));
1686 OperationState loweredOp(loc, *intrinsicName);
1687 loweredOp.addTypes(outType);
1688 loweredOp.addOperands(
1689 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
1690 loweredOp.addAttributes(attrs);
1692 Operation *lowered = rewriter.create(loweredOp);
1693 rewriter.replaceOp(op, lowered->
getResults());
1699struct TransposeLoadOpLowering
1701 TransposeLoadOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1702 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1707 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1708 ConversionPatternRewriter &rewriter)
const override {
1710 return op.emitOpError(
"Non-gfx950 chipset not supported");
1712 Location loc = op.getLoc();
1713 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1717 size_t srcElementSize =
1718 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1719 if (srcElementSize < 8)
1720 return op.emitOpError(
"Expect source memref to have at least 8 bits "
1721 "element size, got ")
1724 auto resultType = cast<VectorType>(op.getResult().getType());
1727 (adaptor.getSrcIndices()));
1729 size_t numElements = resultType.getNumElements();
1730 size_t elementTypeSize =
1731 resultType.getElementType().getIntOrFloatBitWidth();
1735 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1736 rewriter.getIntegerType(32));
1737 Type llvmResultType = typeConverter->convertType(resultType);
1739 switch (elementTypeSize) {
1741 assert(numElements == 16);
1742 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1743 rocdlResultType, srcPtr);
1744 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1748 assert(numElements == 16);
1749 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1750 rocdlResultType, srcPtr);
1751 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1755 assert(numElements == 8);
1756 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1757 rocdlResultType, srcPtr);
1758 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1762 assert(numElements == 4);
1763 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1768 return op.emitOpError(
"Unsupported element size for transpose load");
1775 GatherToLDSOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1776 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1781 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1782 ConversionPatternRewriter &rewriter)
const override {
1783 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1784 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
1786 Location loc = op.getLoc();
1788 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1789 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1794 Type transferType = op.getTransferType();
1795 int loadWidth = [&]() ->
int {
1796 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1797 return (transferVectorType.getNumElements() *
1798 transferVectorType.getElementTypeBitWidth()) /
1805 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1806 return op.emitOpError(
"chipset unsupported element size");
1808 if (chipset !=
kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1809 return op.emitOpError(
"Gather to LDS instructions with 12-byte and "
1810 "16-byte load widths are only supported on gfx950");
1814 (adaptor.getSrcIndices()));
1817 (adaptor.getDstIndices()));
1819 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1820 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1821 rewriter.getI32IntegerAttr(0),
1830struct ExtPackedFp8OpLowering final
1832 ExtPackedFp8OpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1833 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
1838 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1839 ConversionPatternRewriter &rewriter)
const override;
1842struct ScaledExtPackedMatrixOpLowering final
1844 ScaledExtPackedMatrixOpLowering(
const LLVMTypeConverter &converter,
1846 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
1851 matchAndRewrite(ScaledExtPackedMatrixOp op,
1852 ScaledExtPackedMatrixOpAdaptor adaptor,
1853 ConversionPatternRewriter &rewriter)
const override;
1856struct PackedTrunc2xFp8OpLowering final
1858 PackedTrunc2xFp8OpLowering(
const LLVMTypeConverter &converter,
1860 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
1865 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1866 ConversionPatternRewriter &rewriter)
const override;
1869struct PackedStochRoundFp8OpLowering final
1871 PackedStochRoundFp8OpLowering(
const LLVMTypeConverter &converter,
1873 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
1878 matchAndRewrite(PackedStochRoundFp8Op op,
1879 PackedStochRoundFp8OpAdaptor adaptor,
1880 ConversionPatternRewriter &rewriter)
const override;
1883struct ScaledExtPackedOpLowering final
1885 ScaledExtPackedOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1886 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1891 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1892 ConversionPatternRewriter &rewriter)
const override;
1895struct PackedScaledTruncOpLowering final
1897 PackedScaledTruncOpLowering(
const LLVMTypeConverter &converter,
1899 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1904 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1905 ConversionPatternRewriter &rewriter)
const override;
1910LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1911 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1912 ConversionPatternRewriter &rewriter)
const {
1913 Location loc = op.getLoc();
1915 return rewriter.notifyMatchFailure(
1916 loc,
"Fp8 conversion instructions are not available on target "
1917 "architecture and their emulation is not implemented");
1919 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1920 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1921 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1923 Value source = adaptor.getSource();
1924 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1925 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1928 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1929 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1930 if (!sourceVecType) {
1931 longVec = LLVM::InsertElementOp::create(
1934 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1936 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1938 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1943 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1944 if (resultVecType) {
1946 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1949 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1954 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1957 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1964int32_t getScaleSel(int32_t blockSize,
unsigned bitWidth, int32_t scaleWaveHalf,
1965 int32_t firstScaleByte) {
1971 assert(llvm::is_contained({16, 32}, blockSize));
1972 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
1974 const bool isFp8 = bitWidth == 8;
1975 const bool isBlock16 = blockSize == 16;
1978 int32_t bit0 = isBlock16;
1979 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
1980 int32_t bit1 = (firstScaleByte == 2) << 1;
1981 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1982 int32_t bit2 = scaleWaveHalf << 2;
1983 return bit2 | bit1 | bit0;
1986 int32_t bit0 = isBlock16;
1988 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
1989 int32_t bits2and1 = firstScaleByte << 1;
1990 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1991 int32_t bit3 = scaleWaveHalf << 3;
1992 int32_t bits = bit3 | bits2and1 | bit0;
1994 assert(!llvm::is_contained(
1995 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
1999static std::optional<StringRef>
2000scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2001 using fp4 = Float4E2M1FNType;
2002 using fp8 = Float8E4M3FNType;
2003 using bf8 = Float8E5M2Type;
2004 using fp6 = Float6E2M3FNType;
2005 using bf6 = Float6E3M2FNType;
2006 if (isa<fp4>(srcElemType)) {
2007 if (destElemType.
isF16())
2008 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2009 if (destElemType.
isBF16())
2010 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2011 if (destElemType.
isF32())
2012 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2013 return std::nullopt;
2015 if (isa<fp8>(srcElemType)) {
2016 if (destElemType.
isF16())
2017 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2018 if (destElemType.
isBF16())
2019 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2020 if (destElemType.
isF32())
2021 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2022 return std::nullopt;
2024 if (isa<bf8>(srcElemType)) {
2025 if (destElemType.
isF16())
2026 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2027 if (destElemType.
isBF16())
2028 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2029 if (destElemType.
isF32())
2030 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2031 return std::nullopt;
2033 if (isa<fp6>(srcElemType)) {
2034 if (destElemType.
isF16())
2035 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2036 if (destElemType.
isBF16())
2037 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2038 if (destElemType.
isF32())
2039 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2040 return std::nullopt;
2042 if (isa<bf6>(srcElemType)) {
2043 if (destElemType.
isF16())
2044 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2045 if (destElemType.
isBF16())
2046 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2047 if (destElemType.
isF32())
2048 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2049 return std::nullopt;
2051 llvm_unreachable(
"invalid combination of element types for packed conversion "
2055LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2056 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2057 ConversionPatternRewriter &rewriter)
const {
2058 using fp4 = Float4E2M1FNType;
2059 using fp8 = Float8E4M3FNType;
2060 using bf8 = Float8E5M2Type;
2061 using fp6 = Float6E2M3FNType;
2062 using bf6 = Float6E3M2FNType;
2063 Location loc = op.getLoc();
2065 return rewriter.notifyMatchFailure(
2067 "Scaled fp packed conversion instructions are not available on target "
2068 "architecture and their emulation is not implemented");
2072 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2073 int32_t firstScaleByte = op.getFirstScaleByte();
2074 int32_t blockSize = op.getBlockSize();
2075 auto sourceType = cast<VectorType>(op.getSource().getType());
2076 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2077 unsigned bitWidth = srcElemType.getWidth();
2079 auto targetType = cast<VectorType>(op.getResult().getType());
2080 auto destElemType = cast<FloatType>(targetType.getElementType());
2082 IntegerType i32 = rewriter.getI32Type();
2083 Value source = adaptor.getSource();
2084 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2085 Type packedType =
nullptr;
2086 if (isa<fp4>(srcElemType)) {
2088 packedType = getTypeConverter()->convertType(packedType);
2089 }
else if (isa<fp8, bf8>(srcElemType)) {
2090 packedType = VectorType::get(2, i32);
2091 packedType = getTypeConverter()->convertType(packedType);
2092 }
else if (isa<fp6, bf6>(srcElemType)) {
2093 packedType = VectorType::get(3, i32);
2094 packedType = getTypeConverter()->convertType(packedType);
2096 llvm_unreachable(
"invalid element type for packed scaled ext");
2099 if (!packedType || !llvmResultType) {
2100 return rewriter.notifyMatchFailure(op,
"type conversion failed");
2103 std::optional<StringRef> maybeIntrinsic =
2104 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2105 if (!maybeIntrinsic.has_value())
2106 return op.emitOpError(
2107 "no intrinsic matching packed scaled conversion on the given chipset");
2110 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2112 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2113 Value castedSource =
2114 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2116 OperationState loweredOp(loc, *maybeIntrinsic);
2117 loweredOp.addTypes({llvmResultType});
2118 loweredOp.addOperands({castedSource, castedScale});
2120 SmallVector<NamedAttribute, 1> attrs;
2122 NamedAttribute(
"scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2124 loweredOp.addAttributes(attrs);
2125 Operation *lowered = rewriter.create(loweredOp);
2126 rewriter.replaceOp(op, lowered);
2131LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2132 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2133 ConversionPatternRewriter &rewriter)
const {
2134 Location loc = op.getLoc();
2136 return rewriter.notifyMatchFailure(
2137 loc,
"Scaled fp conversion instructions are not available on target "
2138 "architecture and their emulation is not implemented");
2139 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2141 Value source = adaptor.getSource();
2142 Value scale = adaptor.getScale();
2144 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2145 Type sourceElemType = sourceVecType.getElementType();
2146 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2147 Type destElemType = destVecType.getElementType();
2149 VectorType packedVecType;
2150 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2151 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2152 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2153 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
2154 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2155 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2157 llvm_unreachable(
"invalid element type for scaled ext");
2161 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2162 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2163 if (!sourceVecType) {
2164 longVec = LLVM::InsertElementOp::create(
2167 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2169 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2171 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2176 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2178 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
2179 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2180 op, destVecType, i32Source, scale, op.getIndex());
2181 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
2182 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2183 op, destVecType, i32Source, scale, op.getIndex());
2184 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
2185 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2186 op, destVecType, i32Source, scale, op.getIndex());
2187 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
2188 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2189 op, destVecType, i32Source, scale, op.getIndex());
2190 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
2191 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2192 op, destVecType, i32Source, scale, op.getIndex());
2193 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
2194 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2195 op, destVecType, i32Source, scale, op.getIndex());
2196 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
2197 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2198 op, destVecType, i32Source, scale, op.getIndex());
2199 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
2200 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2201 op, destVecType, i32Source, scale, op.getIndex());
2202 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
2203 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2204 op, destVecType, i32Source, scale, op.getIndex());
2211LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2212 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2213 ConversionPatternRewriter &rewriter)
const {
2214 Location loc = op.getLoc();
2216 return rewriter.notifyMatchFailure(
2217 loc,
"Scaled fp conversion instructions are not available on target "
2218 "architecture and their emulation is not implemented");
2219 Type v2i16 = getTypeConverter()->convertType(
2220 VectorType::get(2, rewriter.getI16Type()));
2221 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2223 Type resultType = op.getResult().getType();
2225 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2226 Type sourceElemType = sourceVecType.getElementType();
2228 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
2230 Value source = adaptor.getSource();
2231 Value scale = adaptor.getScale();
2232 Value existing = adaptor.getExisting();
2234 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2236 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
2238 if (sourceVecType.getNumElements() < 2) {
2240 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2241 VectorType v2 = VectorType::get(2, sourceElemType);
2242 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2243 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
2246 Value sourceA, sourceB;
2247 if (sourceElemType.
isF32()) {
2250 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2251 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
2255 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
2256 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2257 existing, sourceA, sourceB,
2258 scale, op.getIndex());
2259 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
2260 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2261 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2262 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
2263 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2264 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2265 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
2266 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2267 existing, sourceA, sourceB,
2268 scale, op.getIndex());
2269 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
2270 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2271 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2272 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
2273 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2274 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2275 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
2276 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2277 existing, sourceA, sourceB,
2278 scale, op.getIndex());
2279 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
2280 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2281 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2282 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
2283 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
2284 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2288 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2289 op, getTypeConverter()->convertType(resultType),
result);
2293LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2294 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2295 ConversionPatternRewriter &rewriter)
const {
2296 Location loc = op.getLoc();
2298 return rewriter.notifyMatchFailure(
2299 loc,
"Fp8 conversion instructions are not available on target "
2300 "architecture and their emulation is not implemented");
2301 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2303 Type resultType = op.getResult().getType();
2306 Value sourceA = adaptor.getSourceA();
2307 Value sourceB = adaptor.getSourceB();
2309 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.
getType());
2310 Value existing = adaptor.getExisting();
2312 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2314 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2318 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2319 existing, op.getWordIndex());
2321 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2322 existing, op.getWordIndex());
2324 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2325 op, getTypeConverter()->convertType(resultType),
result);
2329LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2330 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2331 ConversionPatternRewriter &rewriter)
const {
2332 Location loc = op.getLoc();
2334 return rewriter.notifyMatchFailure(
2335 loc,
"Fp8 conversion instructions are not available on target "
2336 "architecture and their emulation is not implemented");
2337 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2339 Type resultType = op.getResult().getType();
2342 Value source = adaptor.getSource();
2343 Value stoch = adaptor.getStochiasticParam();
2344 Value existing = adaptor.getExisting();
2346 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2348 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2352 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2353 existing, op.getStoreIndex());
2355 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2356 existing, op.getStoreIndex());
2358 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2359 op, getTypeConverter()->convertType(resultType),
result);
2365struct AMDGPUDPPLowering :
public ConvertOpToLLVMPattern<DPPOp> {
2366 AMDGPUDPPLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2367 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2371 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2372 ConversionPatternRewriter &rewriter)
const override {
2375 Location loc = DppOp.getLoc();
2376 Value src = adaptor.getSrc();
2377 Value old = adaptor.getOld();
2380 Type llvmType =
nullptr;
2382 llvmType = rewriter.getI32Type();
2383 }
else if (isa<FloatType>(srcType)) {
2385 ? rewriter.getF32Type()
2386 : rewriter.getF64Type();
2387 }
else if (isa<IntegerType>(srcType)) {
2389 ? rewriter.getI32Type()
2390 : rewriter.getI64Type();
2392 auto llvmSrcIntType = typeConverter->convertType(
2396 auto convertOperand = [&](Value operand, Type operandType) {
2397 if (operandType.getIntOrFloatBitWidth() <= 16) {
2398 if (llvm::isa<FloatType>(operandType)) {
2400 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2402 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2403 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2404 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2406 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2408 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2413 src = convertOperand(src, srcType);
2414 old = convertOperand(old, oldType);
2417 enum DppCtrl :
unsigned {
2426 ROW_HALF_MIRROR = 0x141,
2431 auto kind = DppOp.getKind();
2432 auto permArgument = DppOp.getPermArgument();
2433 uint32_t DppCtrl = 0;
2437 case DPPPerm::quad_perm: {
2438 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
2440 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2441 uint32_t num = elem.getInt();
2442 DppCtrl |= num << (i * 2);
2447 case DPPPerm::row_shl: {
2448 auto intAttr = cast<IntegerAttr>(*permArgument);
2449 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2452 case DPPPerm::row_shr: {
2453 auto intAttr = cast<IntegerAttr>(*permArgument);
2454 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2457 case DPPPerm::row_ror: {
2458 auto intAttr = cast<IntegerAttr>(*permArgument);
2459 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
2462 case DPPPerm::wave_shl:
2463 DppCtrl = DppCtrl::WAVE_SHL1;
2465 case DPPPerm::wave_shr:
2466 DppCtrl = DppCtrl::WAVE_SHR1;
2468 case DPPPerm::wave_rol:
2469 DppCtrl = DppCtrl::WAVE_ROL1;
2471 case DPPPerm::wave_ror:
2472 DppCtrl = DppCtrl::WAVE_ROR1;
2474 case DPPPerm::row_mirror:
2475 DppCtrl = DppCtrl::ROW_MIRROR;
2477 case DPPPerm::row_half_mirror:
2478 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
2480 case DPPPerm::row_bcast_15:
2481 DppCtrl = DppCtrl::BCAST15;
2483 case DPPPerm::row_bcast_31:
2484 DppCtrl = DppCtrl::BCAST31;
2490 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
2491 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
2492 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>(
"bound_ctrl").getValue();
2496 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
2497 rowMask, bankMask, boundCtrl);
2499 Value
result = dppMovOp.getRes();
2501 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType,
result);
2502 if (!llvm::isa<IntegerType>(srcType)) {
2503 result = LLVM::BitcastOp::create(rewriter, loc, srcType,
result);
2514struct AMDGPUSwizzleBitModeLowering
2515 :
public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2519 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2520 ConversionPatternRewriter &rewriter)
const override {
2521 Location loc = op.getLoc();
2522 Type i32 = rewriter.getI32Type();
2523 Value src = adaptor.getSrc();
2524 SmallVector<Value> decomposed =
2525 LLVM::decomposeValue(rewriter, loc, src, i32);
2526 unsigned andMask = op.getAndMask();
2527 unsigned orMask = op.getOrMask();
2528 unsigned xorMask = op.getXorMask();
2532 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2534 SmallVector<Value> swizzled;
2535 for (Value v : decomposed) {
2537 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2538 swizzled.emplace_back(res);
2541 Value
result = LLVM::composeValue(rewriter, loc, swizzled, src.
getType());
2542 rewriter.replaceOp(op,
result);
2547struct AMDGPUPermlaneLowering :
public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2550 AMDGPUPermlaneLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2551 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2555 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2556 ConversionPatternRewriter &rewriter)
const override {
2558 return op->emitOpError(
"permlane_swap is only supported on gfx950+");
2560 Location loc = op.getLoc();
2561 Type i32 = rewriter.getI32Type();
2562 Value src = adaptor.getSrc();
2563 unsigned rowLength = op.getRowLength();
2564 bool fi = op.getFetchInactive();
2565 bool boundctrl = op.getBoundCtrl();
2567 SmallVector<Value> decomposed =
2568 LLVM::decomposeValue(rewriter, loc, src, i32);
2570 SmallVector<Value> permuted;
2571 for (Value v : decomposed) {
2573 Type i32pair = LLVM::LLVMStructType::getLiteral(
2574 rewriter.getContext(), {v.getType(), v.getType()});
2576 if (rowLength == 16)
2577 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2579 else if (rowLength == 32)
2580 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2583 llvm_unreachable(
"unsupported row length");
2585 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2586 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
2588 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2589 LLVM::ICmpPredicate::eq, vdst0, v);
2594 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2595 permuted.emplace_back(vdstNew);
2598 Value
result = LLVM::composeValue(rewriter, loc, permuted, src.
getType());
2599 rewriter.replaceOp(op,
result);
2604static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
2605 Value accumulator, Value value, int64_t shift) {
2610 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
2616 constexpr bool isDisjoint =
true;
2617 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
2620template <
typename BaseOp>
2621struct AMDGPUMakeDmaBaseLowering :
public ConvertOpToLLVMPattern<BaseOp> {
2622 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
2625 AMDGPUMakeDmaBaseLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2626 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
2630 matchAndRewrite(BaseOp op, Adaptor adaptor,
2631 ConversionPatternRewriter &rewriter)
const override {
2633 return op->emitOpError(
"make_dma_base is only supported on gfx1250");
2635 Location loc = op.getLoc();
2637 constexpr int32_t constlen = 4;
2638 Value consts[constlen];
2639 for (int64_t i = 0; i < constlen; ++i)
2642 constexpr int32_t sgprslen = constlen;
2643 Value sgprs[sgprslen];
2644 for (int64_t i = 0; i < sgprslen; ++i) {
2645 sgprs[i] = consts[0];
2648 sgprs[0] = consts[1];
2650 if constexpr (BaseOp::isGather()) {
2651 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
2653 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
2654 Type indexType = type.getIndexType();
2656 assert(llvm::is_contained({16u, 32u}, indexSize) &&
2657 "expected index_size to be 16 or 32");
2658 unsigned idx = (indexSize / 16) - 1;
2661 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
2664 ValueRange ldsIndices = adaptor.getLdsIndices();
2665 Value lds = adaptor.getLds();
2666 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
2669 rewriter, loc, ldsMemRefType, lds, ldsIndices);
2671 ValueRange globalIndices = adaptor.getGlobalIndices();
2672 Value global = adaptor.getGlobal();
2673 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
2676 rewriter, loc, globalMemRefType, global, globalIndices);
2678 Type i32 = rewriter.getI32Type();
2679 Type i64 = rewriter.getI64Type();
2681 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
2682 Value castForGlobalAddr =
2683 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
2685 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
2687 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
2690 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
2693 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
2695 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
2697 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
2698 assert(v4i32 &&
"expected type conversion to succeed");
2699 Value
result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
2701 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
2703 LLVM::InsertElementOp::create(rewriter, loc,
result, sgpr, constant);
2705 rewriter.replaceOp(op,
result);
2710template <
typename DescriptorOp>
2711struct AMDGPULowerDescriptor :
public ConvertOpToLLVMPattern<DescriptorOp> {
2712 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
2715 AMDGPULowerDescriptor(
const LLVMTypeConverter &converter, Chipset chipset)
2716 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
2719 Value getDGroup0(OpAdaptor adaptor)
const {
return adaptor.getBase(); }
2721 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
2722 ConversionPatternRewriter &rewriter, Location loc,
2723 Value sgpr0)
const {
2724 Value mask = op.getWorkgroupMask();
2728 Type i16 = rewriter.getI16Type();
2729 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
2730 Type i32 = rewriter.getI32Type();
2731 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
2732 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
2735 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
2736 ConversionPatternRewriter &rewriter, Location loc,
2737 Value sgpr0, ArrayRef<Value> consts)
const {
2738 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
2739 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
2740 "expected type width to be 8, 16, 32, or 64.");
2741 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
2742 Value size = consts[idx];
2743 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
2746 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
2747 ConversionPatternRewriter &rewriter, Location loc,
2748 Value sgpr0, ArrayRef<Value> consts)
const {
2749 if (!adaptor.getAtomicBarrierAddress())
2752 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
2755 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
2756 ConversionPatternRewriter &rewriter, Location loc,
2757 Value sgpr0, ArrayRef<Value> consts)
const {
2758 if (!adaptor.getGlobalIncrement())
2763 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
2766 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
2767 ConversionPatternRewriter &rewriter, Location loc,
2768 Value sgpr0, ArrayRef<Value> consts)
const {
2769 if (!op.getPadAmount())
2772 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
2775 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
2776 ConversionPatternRewriter &rewriter, Location loc,
2777 Value sgpr0, ArrayRef<Value> consts)
const {
2778 if (!op.getWorkgroupMask())
2781 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
2784 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
2785 ConversionPatternRewriter &rewriter, Location loc,
2786 Value sgpr0, ArrayRef<Value> consts)
const {
2787 if (!op.getPadAmount())
2796 IntegerType i32 = rewriter.getI32Type();
2797 Value padInterval = adaptor.getPadInterval();
2798 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
2799 padInterval,
false);
2800 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
2802 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
2805 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
2806 ConversionPatternRewriter &rewriter, Location loc,
2807 Value sgpr0, ArrayRef<Value> consts)
const {
2808 if (!op.getPadAmount())
2817 Value padAmount = adaptor.getPadAmount();
2818 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
2820 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
2823 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
2824 ConversionPatternRewriter &rewriter,
2825 Location loc, Value sgpr1,
2826 ArrayRef<Value> consts)
const {
2827 if (!adaptor.getAtomicBarrierAddress())
2830 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
2831 auto barrierAddressTy =
2832 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
2833 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
2835 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
2836 atomicBarrierIndices);
2837 IntegerType i32 = rewriter.getI32Type();
2843 atomicBarrierAddress =
2844 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
2845 atomicBarrierAddress =
2846 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
2848 atomicBarrierAddress =
2849 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
2850 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
2853 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
2854 ConversionPatternRewriter &rewriter,
2855 Location loc, Value sgpr1, Value sgpr2,
2856 ArrayRef<Value> consts, uint64_t dimX,
2857 uint32_t offset)
const {
2858 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
2859 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
2860 SmallVector<OpFoldResult> mixedGlobalSizes =
2862 if (mixedGlobalSizes.size() <= dimX)
2863 return {sgpr1, sgpr2};
2865 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
2872 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
2876 IntegerType i32 = rewriter.getI32Type();
2877 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
2878 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
2881 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
2884 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
2885 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
2886 return {sgpr1, sgpr2};
2889 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
2890 ConversionPatternRewriter &rewriter,
2891 Location loc, Value sgpr1, Value sgpr2,
2892 ArrayRef<Value> consts)
const {
2893 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
2897 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
2898 ConversionPatternRewriter &rewriter,
2899 Location loc, Value sgpr2, Value sgpr3,
2900 ArrayRef<Value> consts)
const {
2901 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
2905 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
2906 ConversionPatternRewriter &rewriter, Location loc,
2907 Value sgpr, ArrayRef<Value> consts,
size_t dimX,
2908 int64_t offset)
const {
2909 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
2910 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
2911 SmallVector<OpFoldResult> mixedSharedSizes =
2913 if (mixedSharedSizes.size() <= dimX)
2916 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
2925 if (
auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
2929 IntegerType i32 = rewriter.getI32Type();
2930 tileDimX = cast<Value>(tileDimXOpFoldResult);
2931 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
2934 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
2937 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
2938 ConversionPatternRewriter &rewriter, Location loc,
2939 Value sgpr3, ArrayRef<Value> consts)
const {
2940 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
2943 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
2944 ConversionPatternRewriter &rewriter, Location loc,
2945 Value sgpr4, ArrayRef<Value> consts)
const {
2946 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
2949 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
2950 ConversionPatternRewriter &rewriter, Location loc,
2951 Value sgpr4, ArrayRef<Value> consts)
const {
2952 auto type = cast<VectorType>(op.getIndices().getType());
2953 ArrayRef<int64_t> shape = type.getShape();
2954 assert(shape.size() == 1 &&
"expected shape to be of rank 1.");
2955 unsigned length = shape.back();
2956 assert(0 < length && length <= 16 &&
"expected length to be at most 16.");
2958 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
2961 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
2962 ConversionPatternRewriter &rewriter,
2963 Location loc, Value sgpr4,
2964 ArrayRef<Value> consts)
const {
2965 if constexpr (DescriptorOp::isGather())
2966 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
2967 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
2970 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
2971 ConversionPatternRewriter &rewriter, Location loc,
2972 Value sgpr4, ArrayRef<Value> consts)
const {
2974 if constexpr (DescriptorOp::isGather())
2976 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
2979 std::pair<Value, Value>
2980 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
2981 ConversionPatternRewriter &rewriter, Location loc,
2982 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
2983 size_t dimX, int64_t offset)
const {
2984 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
2985 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
2986 SmallVector<OpFoldResult> mixedGlobalStrides =
2987 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
2989 if (mixedGlobalStrides.size() <= (dimX + 1))
2990 return {sgprY, sgprZ};
2992 OpFoldResult tensorDimXStrideOpFoldResult =
2993 *(mixedGlobalStrides.rbegin() + dimX + 1);
2998 Value tensorDimXStride;
2999 if (
auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
3003 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
3005 constexpr int64_t first48bits = (1ll << 48) - 1;
3008 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
3009 IntegerType i32 = rewriter.getI32Type();
3010 Value tensorDimXStrideLow =
3011 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
3012 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
3014 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
3016 Value tensorDimXStrideHigh =
3017 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
3018 tensorDimXStrideHigh =
3019 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
3020 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
3022 return {sgprY, sgprZ};
3025 std::pair<Value, Value>
3026 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3027 ConversionPatternRewriter &rewriter, Location loc,
3028 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3029 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3033 std::pair<Value, Value>
3034 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3035 ConversionPatternRewriter &rewriter, Location loc,
3036 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3038 if constexpr (DescriptorOp::isGather())
3039 return {sgpr5, sgpr6};
3040 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3044 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
3045 ConversionPatternRewriter &rewriter, Location loc,
3046 ArrayRef<Value> consts)
const {
3048 for (int64_t i = 0; i < 8; ++i) {
3049 sgprs[i] = consts[0];
3052 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
3053 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
3054 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
3055 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3056 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3057 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
3058 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
3059 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
3062 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
3063 std::tie(sgprs[1], sgprs[2]) =
3064 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3065 std::tie(sgprs[2], sgprs[3]) =
3066 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3068 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
3070 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
3071 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
3072 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
3073 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
3074 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
3075 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
3077 IntegerType i32 = rewriter.getI32Type();
3078 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
3079 assert(v8i32 &&
"expected type conversion to succeed");
3080 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
3082 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
3084 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
3090 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3091 ConversionPatternRewriter &rewriter, Location loc,
3092 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3093 int64_t offset)
const {
3094 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3095 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3096 SmallVector<OpFoldResult> mixedGlobalSizes =
3098 if (mixedGlobalSizes.size() <=
static_cast<unsigned long>(dimX))
3101 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3103 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3107 IntegerType i32 = rewriter.getI32Type();
3108 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3109 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3112 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
3115 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3116 ConversionPatternRewriter &rewriter, Location loc,
3117 Value sgpr0, ArrayRef<Value> consts)
const {
3118 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3121 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3122 Location loc, Value accumulator,
3123 Value value, int64_t shift)
const {
3125 IntegerType i32 = rewriter.getI32Type();
3126 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3127 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3130 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3131 ConversionPatternRewriter &rewriter, Location loc,
3132 Value sgpr1, ArrayRef<Value> consts,
3133 int64_t offset)
const {
3134 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3135 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3138 std::pair<Value, Value>
3139 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3140 ConversionPatternRewriter &rewriter, Location loc,
3141 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
3142 int64_t offset)
const {
3143 Value globalAddrIncrement = adaptor.getGlobalIncrement();
3144 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
3145 globalAddrIncrement, offset);
3147 globalAddrIncrement =
3148 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
3149 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
3150 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
3151 globalAddrIncrement, offset + 32);
3153 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
3154 return {sgpr2, sgpr3};
3157 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3158 ConversionPatternRewriter &rewriter,
3159 Location loc, Value sgpr1,
3160 ArrayRef<Value> consts)
const {
3161 Value ldsIncrement = op.getLdsIncrement();
3162 constexpr int64_t dim = 3;
3163 constexpr int64_t offset = 32;
3165 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
3167 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
3171 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3172 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3173 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts)
const {
3174 Value globalIncrement = op.getGlobalIncrement();
3175 constexpr int32_t dim = 2;
3176 constexpr int32_t offset = 64;
3177 if (!globalIncrement)
3178 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3179 consts, dim, offset);
3180 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3184 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3185 ConversionPatternRewriter &rewriter, Location loc,
3186 Value sgpr3, ArrayRef<Value> consts,
3187 int32_t offset)
const {
3188 Value iterationCount = adaptor.getIterationCount();
3189 IntegerType i32 = rewriter.getI32Type();
3196 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
3198 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
3199 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
3202 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
3203 ConversionPatternRewriter &rewriter,
3204 Location loc, Value sgpr3,
3205 ArrayRef<Value> consts)
const {
3206 Value iterateCount = op.getIterationCount();
3207 constexpr int32_t dim = 2;
3208 constexpr int32_t offset = 112;
3210 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
3213 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
3216 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
3217 ConversionPatternRewriter &rewriter, Location loc,
3218 ArrayRef<Value> consts)
const {
3219 if constexpr (DescriptorOp::isGather())
3220 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
3221 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
3224 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
3225 ConversionPatternRewriter &rewriter, Location loc,
3226 ArrayRef<Value> consts)
const {
3227 IntegerType i32 = rewriter.getI32Type();
3228 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3229 assert(v4i32 &&
"expected type conversion to succeed.");
3231 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3232 if (onlyNeedsTwoDescriptors)
3233 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3235 constexpr int64_t sgprlen = 4;
3236 Value sgprs[sgprlen];
3237 for (
int i = 0; i < sgprlen; ++i)
3238 sgprs[i] = consts[0];
3240 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
3241 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
3243 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
3244 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3246 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
3248 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3249 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
3251 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
3256 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
3257 ConversionPatternRewriter &rewriter, Location loc,
3258 ArrayRef<Value> consts,
bool firstHalf)
const {
3259 IntegerType i32 = rewriter.getI32Type();
3260 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3261 assert(v4i32 &&
"expected type conversion to succeed.");
3263 Value
indices = adaptor.getIndices();
3264 auto vectorType = cast<VectorType>(
indices.getType());
3265 unsigned length = vectorType.getShape().back();
3266 Type elementType = vectorType.getElementType();
3267 unsigned maxLength = elementType == i32 ? 4 : 8;
3268 int32_t offset = firstHalf ? 0 : maxLength;
3269 unsigned discountedLength =
3270 std::max(
static_cast<int32_t
>(length - offset), 0);
3272 unsigned targetSize = std::min(maxLength, discountedLength);
3274 SmallVector<Value> indicesVector;
3275 for (
unsigned i = offset; i < targetSize + offset; ++i) {
3277 if (i < consts.size())
3281 Value elem = LLVM::ExtractElementOp::create(rewriter, loc,
indices, idx);
3282 indicesVector.push_back(elem);
3285 SmallVector<Value> indicesI32Vector;
3286 if (elementType == i32) {
3287 indicesI32Vector = indicesVector;
3289 for (
unsigned i = 0; i < targetSize; ++i) {
3290 Value index = indicesVector[i];
3291 indicesI32Vector.push_back(
3292 LLVM::ZExtOp::create(rewriter, loc, i32, index));
3294 if ((targetSize % 2) != 0)
3296 indicesI32Vector.push_back(consts[0]);
3299 SmallVector<Value> indicesToInsert;
3300 if (elementType == i32) {
3301 indicesToInsert = indicesI32Vector;
3303 unsigned size = indicesI32Vector.size() / 2;
3304 for (
unsigned i = 0; i < size; ++i) {
3305 Value first = indicesI32Vector[2 * i];
3306 Value second = indicesI32Vector[2 * i + 1];
3307 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
3308 indicesToInsert.push_back(joined);
3312 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3313 for (
auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
3315 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
3320 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
3321 ConversionPatternRewriter &rewriter, Location loc,
3322 ArrayRef<Value> consts)
const {
3323 return getGatherIndices(op, adaptor, rewriter, loc, consts,
true);
3326 std::pair<Value, Value>
3327 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
3328 ConversionPatternRewriter &rewriter, Location loc,
3329 Value sgpr0, Value sgpr1, ArrayRef<Value> consts)
const {
3330 constexpr int32_t dim = 3;
3331 constexpr int32_t offset = 0;
3332 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
3336 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
3337 ConversionPatternRewriter &rewriter,
3338 Location loc, Value sgpr1, Value sgpr2,
3339 ArrayRef<Value> consts)
const {
3340 constexpr int32_t dim = 4;
3341 constexpr int32_t offset = 48;
3342 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
3346 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
3347 ConversionPatternRewriter &rewriter, Location loc,
3348 Value sgpr2, ArrayRef<Value> consts)
const {
3349 constexpr int32_t dim = 4;
3350 constexpr int32_t offset = 80;
3351 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
3354 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
3355 ConversionPatternRewriter &rewriter, Location loc,
3356 ArrayRef<Value> consts)
const {
3357 if constexpr (DescriptorOp::isGather())
3358 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
3359 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
3362 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
3363 ConversionPatternRewriter &rewriter, Location loc,
3364 ArrayRef<Value> consts)
const {
3365 IntegerType i32 = rewriter.getI32Type();
3366 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3367 assert(v4i32 &&
"expected type conversion to succeed.");
3368 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3369 if (onlyNeedsTwoDescriptors)
3370 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3372 constexpr int32_t sgprlen = 4;
3373 Value sgprs[sgprlen];
3374 for (
int i = 0; i < sgprlen; ++i)
3375 sgprs[i] = consts[0];
3377 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
3378 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
3379 std::tie(sgprs[1], sgprs[2]) =
3380 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3381 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
3383 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3384 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
3386 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
3391 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
3392 ConversionPatternRewriter &rewriter, Location loc,
3393 ArrayRef<Value> consts)
const {
3394 return getGatherIndices(op, adaptor, rewriter, loc, consts,
false);
3398 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
3399 ConversionPatternRewriter &rewriter)
const override {
3401 return op->emitOpError(
3402 "make_dma_descriptor is only supported on gfx1250");
3404 Location loc = op.getLoc();
3406 SmallVector<Value> consts;
3407 for (int64_t i = 0; i < 8; ++i)
3410 Value dgroup0 = this->getDGroup0(adaptor);
3411 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
3412 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
3413 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
3414 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
3415 rewriter.replaceOpWithMultiple(op, {results});
3420template <
typename SourceOp,
typename TargetOp>
3421struct AMDGPUTensorLoadStoreOpLowering
3422 :
public ConvertOpToLLVMPattern<SourceOp> {
3423 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
3425 AMDGPUTensorLoadStoreOpLowering(
const LLVMTypeConverter &converter,
3427 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
3431 matchAndRewrite(SourceOp op, Adaptor adaptor,
3432 ConversionPatternRewriter &rewriter)
const override {
3434 return op->emitOpError(
"is only supported on gfx1250");
3437 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
3446struct ConvertAMDGPUToROCDLPass
3447 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
3450 void runOnOperation()
override {
3453 if (
failed(maybeChipset)) {
3454 emitError(UnknownLoc::get(ctx),
"Invalid chipset name: " + chipset);
3455 return signalPassFailure();
3459 LLVMTypeConverter converter(ctx);
3462 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
3464 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
3465 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
3466 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
3467 if (
failed(applyPartialConversion(getOperation(),
target,
3469 signalPassFailure();
3477 typeConverter, [](gpu::AddressSpace space) {
3479 case gpu::AddressSpace::Global:
3480 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
3481 case gpu::AddressSpace::Workgroup:
3482 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
3483 case gpu::AddressSpace::Private:
3484 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
3486 llvm_unreachable(
"unknown address space enum value");
3492 typeConverter.addTypeAttributeConversion(
3494 -> TypeConverter::AttributeConversionResult {
3496 Type i64 = IntegerType::get(ctx, 64);
3497 switch (as.getValue()) {
3498 case amdgpu::AddressSpace::FatRawBuffer:
3499 return IntegerAttr::get(i64, 7);
3500 case amdgpu::AddressSpace::BufferRsrc:
3501 return IntegerAttr::get(i64, 8);
3502 case amdgpu::AddressSpace::FatStructuredBuffer:
3503 return IntegerAttr::get(i64, 9);
3505 return TypeConverter::AttributeConversionResult::abort();
3507 typeConverter.addConversion([&](TDMBaseType type) ->
Type {
3509 return typeConverter.convertType(VectorType::get(4, i32));
3511 typeConverter.addConversion([&](TDMGatherBaseType type) ->
Type {
3513 return typeConverter.convertType(VectorType::get(4, i32));
3515 typeConverter.addConversion(
3516 [&](TDMDescriptorType type,
3519 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
3520 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
3521 llvm::append_values(
result, v4i32, v8i32, v4i32, v4i32);
3531 if (inputs.size() != 1)
3534 if (!isa<TDMDescriptorType>(inputs[0].
getType()))
3537 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
3538 return cast.getResults();
3541 typeConverter.addTargetMaterialization(addUnrealizedCast);
3549 .add<FatRawBufferCastLowering,
3550 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
3551 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
3552 RawBufferOpLowering<RawBufferAtomicFaddOp,
3553 ROCDL::RawPtrBufferAtomicFaddOp>,
3554 RawBufferOpLowering<RawBufferAtomicFmaxOp,
3555 ROCDL::RawPtrBufferAtomicFmaxOp>,
3556 RawBufferOpLowering<RawBufferAtomicSmaxOp,
3557 ROCDL::RawPtrBufferAtomicSmaxOp>,
3558 RawBufferOpLowering<RawBufferAtomicUminOp,
3559 ROCDL::RawPtrBufferAtomicUminOp>,
3560 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
3561 ROCDL::RawPtrBufferAtomicCmpSwap>,
3562 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
3563 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
3564 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
3565 ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
3566 ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
3567 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
3568 GatherToLDSOpLowering, TransposeLoadOpLowering,
3569 AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
3570 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
3571 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
3572 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
3573 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
3574 ROCDL::TensorLoadToLDSOp>,
3575 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
3576 ROCDL::TensorStoreFromLDSOp>>(
3577 converter, chipset);
3578 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< uint32_t > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static std::optional< uint32_t > smallFloatTypeToFormatCode(Type mlirElemType)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
typename SourceOp::Adaptor OpAdaptor
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
::mlir::Pass::Option< std::string > chipset
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
bool hasOcpFp8(const Chipset &chipset)
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.