29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/ErrorHandling.h"
36#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
37#include "mlir/Conversion/Passes.h.inc"
// NOTE(review): fragment of two integer width-normalization helpers (one
// producing i32, one i64) plus the matching ConstantOp builders. The embedded
// original line numbers skip (53→55→58…, 65, 71→73…), so lines are missing
// from this view — in particular the function signatures and, presumably, an
// early-return for the equal-width case. TODO confirm against the full file.
53 IntegerType i32 = rewriter.getI32Type();
// The value being normalized must already be of integer type.
55 auto valTy = cast<IntegerType>(val.
getType());
// Wider than 32 bits → truncate; narrower → zero-extend. A width of exactly
// 32 would make ZExt-to-same-width invalid — an elided line likely returns
// `val` unchanged in that case; verify in the unabridged source.
58 return valTy.getWidth() > 32
59 ?
Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
60 :
Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
// Helper: materialize an i32 constant at `loc`.
65 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
// Same normalization pattern as above, but targeting i64.
71 IntegerType i64 = rewriter.getI64Type();
73 auto valTy = cast<IntegerType>(val.
getType());
76 return valTy.getWidth() > 64
77 ?
Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
78 :
Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
// Helper: materialize an i64 constant at `loc`.
83 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
// NOTE(review): fragment of a linear-offset computation: each index is
// multiplied by its memref stride (dynamic strides read from the memref
// descriptor, static ones materialized as i32 constants). Lines are elided
// (91, 93-94, 96 missing), including the accumulation of the products into a
// single voffset — confirm against the full file.
90 IntegerType i32 = rewriter.getI32Type();
92 for (
auto [i, increment, stride] : llvm::enumerate(
indices, strides)) {
// Dynamic stride → load from descriptor; static stride → constant.
95 ShapedType::isDynamic(stride)
97 memRefDescriptor.
stride(rewriter, loc, i))
98 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
99 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
// NOTE(review): fragment of getNumRecords(): computes the byte extent of a
// memref for the buffer-resource `num_records` field. Static-shape path folds
// max(shape[i]*strides[i]) at compile time; dynamic path builds the same max
// with UMax ops from descriptor sizes/strides. Several lines (112-114, 117-118,
// 122, 126-128, 133, 135-138) are elided, including where `first45bits` is
// consumed — presumably a gfx1250 no-bounds-check clamp; TODO confirm.
111 MemRefType memrefType,
115 if (chipset >=
kGfx1250 && !boundsCheck) {
// Mask for the low 45 bits; its use is in an elided line.
116 constexpr int64_t first45bits = (1ll << 45) - 1;
// Fully static shape and strides: compute the extent as a constant.
119 if (memrefType.hasStaticShape() &&
120 !llvm::any_of(strides, ShapedType::isDynamic)) {
// Rank-0 memrefs still occupy one element.
121 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
123 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
124 size = std::max(
shape[i] * strides[i], size);
125 size = size * elementByteWidth;
// Dynamic path: same max-of-products, but with runtime SSA values.
129 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
130 Value size = memrefDescriptor.
size(rewriter, loc, i);
131 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
132 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
134 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
// Scale the element-count extent to bytes.
139 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
// NOTE(review): fragment of the buffer-resource (V#) construction helper.
// Builds the 16-bit stride field (optionally with the cache-swizzle enable
// bit 14 OR'd in), packs flag bits, and emits rocdl.make.buffer.rsrc. Lines
// 147-149, 151-152, 158-159, 162-189, 191-192, 195-198 are elided — the flag
// constants' full meaning (data/num format, OOB mode) follows the AMD buffer
// descriptor layout; confirm field semantics against the ISA docs.
145 Value cacheSwizzleStride =
nullptr,
146 unsigned addressSpace = 8) {
150 Type i16 = rewriter.getI16Type();
// Cache-swizzle path: zero-extend the stride and set bit 14 (swizzle enable).
153 Value cacheStrideZext =
154 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
155 Value swizzleBit = LLVM::ConstantOp::create(
156 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
157 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
// No swizzle: stride field is 0.
160 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
161 rewriter.getI16IntegerAttr(0));
// Format bits at [12..] — meaning depends on elided context; verify.
190 flags |= (7 << 12) | (4 << 15);
// OOB (out-of-bounds) mode at bits [29:28]: 3 with bounds checks, 2 without.
193 uint32_t oob = boundsCheck ? 3 : 2;
194 flags |= (oob << 28);
199 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
200 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
201 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
// NOTE(review): FatRawBufferCastLowering — converts amdgpu.fat_raw_buffer_cast
// into a buffer-resource fat pointer wrapped back into a memref descriptor.
// This view is elided throughout (207, 210-214, 222, 225-226, 231, 233-234,
// 237-238, 241, 243, 248, 250-251, 254-261, 264, 266, 268, 271, 273-274,
// 276, 278-279 missing), so the pattern is annotated, not restructured.
206struct FatRawBufferCastLowering
208 FatRawBufferCastLowering(
const LLVMTypeConverter &converter, Chipset chipset)
209 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
215 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter)
const override {
217 Location loc = op.getLoc();
218 Value memRef = adaptor.getSource();
// The pre-conversion value still carries the MemRefType we need to inspect.
219 Value unconvertedMemref = op.getSource();
220 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
221 MemRefDescriptor descriptor(memRef);
223 DataLayout dataLayout = DataLayout::closest(op);
224 int64_t elementByteWidth =
// Only strided+offset memref layouts can be lowered here.
227 int64_t unusedOffset = 0;
228 SmallVector<int64_t, 5> strideVals;
229 if (
failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
230 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
// num_records: explicit validBytes if given, else computed from the shape.
232 Value numRecords = adaptor.getValidBytes();
235 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
236 elementByteWidth, chipset, adaptor.getBoundsCheck());
// resetOffset folds the descriptor offset into the base pointer and zeroes
// the offset field of the result descriptor.
239 adaptor.getResetOffset()
240 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
242 : descriptor.alignedPtr(rewriter, loc);
244 Value offset = adaptor.getResetOffset()
245 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
246 rewriter.getIndexAttr(0))
247 : descriptor.offset(rewriter, loc);
// Rank-0 memrefs have no size/stride arrays to carry over.
249 bool hasSizes = memrefType.getRank() > 0;
252 Value sizes = hasSizes
253 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
257 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
// Build the fat pointer (address space 7) from the resource pieces.
262 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
263 chipset, adaptor.getCacheSwizzleStride(), 7);
// Reassemble a result descriptor field by field, starting from poison.
265 Value
result = MemRefDescriptor::poison(
267 getTypeConverter()->convertType(op.getResult().getType()));
269 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr, pos);
270 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr,
272 result = LLVM::InsertValueOp::create(rewriter, loc,
result, offset,
275 result = LLVM::InsertValueOp::create(rewriter, loc,
result, sizes,
277 result = LLVM::InsertValueOp::create(rewriter, loc,
result, strides,
280 rewriter.replaceOp(op,
result);
// NOTE(review): RawBufferOpLowering<GpuOp, Intrinsic> — shared lowering for
// the raw-buffer load/store/atomic family onto ROCDL buffer intrinsics.
// Heavily elided in this view (287, 290-291, 293-294, 301, 304, 307-309, 311,
// 313, 315-316, 320-321, 323, 325-326, 329-336, 338, 342, 345-346, 348, 354,
// 363, 366-368, 370-371, 374-375, 377, 382, 384-387, 392, 394-398, 402, 404,
// 407, 410, 413-415, 419, 421-423, 426-427, 429-430, 433-438, 440, 443-444,
// 447-450 missing) — comments below describe only what the visible lines show.
286template <
typename GpuOp,
typename Intrinsic>
288 RawBufferOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
289 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
// Hard cap on the per-op data width accepted by the buffer intrinsics.
292 static constexpr uint32_t maxVectorOpWidth = 128;
295 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
296 ConversionPatternRewriter &rewriter)
const override {
297 Location loc = gpuOp.getLoc();
298 Value memref = adaptor.getMemref();
299 Value unconvertedMemref = gpuOp.getMemref();
300 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
// Buffer instructions only exist on GFX9 (gfx9xx) and newer.
302 if (chipset.majorVersion < 9)
303 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
// ODS operand 0 aliases the memref for ops without store data; the
// comparison below appears to detect that case — TODO confirm.
305 Value storeData = adaptor.getODSOperands(0)[0];
306 if (storeData == memref)
310 wantedDataType = storeData.
getType();
312 wantedDataType = gpuOp.getODSResults(0)[0].getType();
// Same aliasing trick for the compare operand of cmpswap-style atomics.
314 Value atomicCmpData = Value();
317 Value maybeCmpData = adaptor.getODSOperands(1)[0];
318 if (maybeCmpData != memref)
319 atomicCmpData = maybeCmpData;
322 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
324 Type i32 = rewriter.getI32Type();
327 DataLayout dataLayout = DataLayout::closest(gpuOp);
328 int64_t elementByteWidth =
// The intrinsics operate on integer storage types; floats are bitcast.
337 Type llvmBufferValType = llvmWantedDataType;
339 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
340 llvmBufferValType = this->getTypeConverter()->convertType(
341 rewriter.getIntegerType(floatType.getWidth()));
343 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
344 uint32_t vecLen = dataVector.getNumElements();
347 uint32_t totalBits = elemBits * vecLen;
// 2xf16 packed atomic fadd keeps its vector form.
349 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
350 if (totalBits > maxVectorOpWidth)
351 return gpuOp.emitOpError(
352 "Total width of loads or stores must be no more than " +
353 Twine(maxVectorOpWidth) +
" bits, but we call for " +
355 " bits. This should've been caught in validation");
// Sub-32-bit element vectors are repacked into words (or one small int).
356 if (!usePackedFp16 && elemBits < 32) {
357 if (totalBits > 32) {
358 if (totalBits % 32 != 0)
359 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
360 "doesn't fit into words. Can't happen\n");
361 llvmBufferValType = this->typeConverter->convertType(
362 VectorType::get(totalBits / 32, i32));
364 llvmBufferValType = this->typeConverter->convertType(
365 rewriter.getIntegerType(totalBits));
// Single-element vectors collapse to their scalar element type.
369 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
372 if (vecType.getNumElements() == 1)
373 llvmBufferValType = vecType.getElementType();
376 SmallVector<Value, 6> args;
// Store data goes first, bitcast to the buffer storage type if needed.
378 if (llvmBufferValType != llvmWantedDataType) {
379 Value castForStore = LLVM::BitcastOp::create(
380 rewriter, loc, llvmBufferValType, storeData);
381 args.push_back(castForStore);
383 args.push_back(storeData);
// Then the compare value for cmpswap, under the same cast rule.
388 if (llvmBufferValType != llvmWantedDataType) {
389 Value castForCmp = LLVM::BitcastOp::create(
390 rewriter, loc, llvmBufferValType, atomicCmpData);
391 args.push_back(castForCmp);
393 args.push_back(atomicCmpData);
// Resource descriptor construction mirrors FatRawBufferCastLowering.
399 SmallVector<int64_t, 5> strides;
400 if (
failed(memrefType.getStridesAndOffset(strides, offset)))
401 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
403 MemRefDescriptor memrefDescriptor(memref);
405 Value ptr = memrefDescriptor.bufferPtr(
406 rewriter, loc, *this->getTypeConverter(), memrefType);
408 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
409 elementByteWidth, chipset, adaptor.getBoundsCheck());
411 adaptor.getBoundsCheck(), chipset);
412 args.push_back(resource);
// voffset: linearized index, plus optional static index offset, scaled
// to bytes before being passed to the intrinsic.
416 adaptor.getIndices(), strides);
417 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
418 indexOffset && *indexOffset > 0) {
420 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
424 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
425 args.push_back(voffset);
// sgprOffset is likewise scaled from elements to bytes.
428 Value sgprOffset = adaptor.getSgprOffset();
431 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
432 args.push_back(sgprOffset);
// Emit the target intrinsic; loads produce one result, stores none.
439 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
441 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
442 ArrayRef<NamedAttribute>());
// Cast the result back to the type callers expect, if we repacked.
445 if (llvmBufferValType != llvmWantedDataType) {
446 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
451 rewriter.eraseOp(gpuOp);
// NOTE(review): encodeWaitcnt — packs vmcnt/expcnt/lgkmcnt into the s_waitcnt
// immediate. The chipset-dispatch conditionals that select between the four
// branch bodies below are elided (lines 470, 475-476, 484-485, 493-494 of the
// original are missing), as is the trailing failure return. The field layouts
// look like the gfx9 / gfx10 / gfx11-ish / later variants of the encoding —
// confirm branch→architecture mapping against the ISA references.
468static FailureOr<unsigned> encodeWaitcnt(
Chipset chipset,
unsigned vmcnt,
469 unsigned expcnt,
unsigned lgkmcnt) {
// Layout A: vmcnt[3:0], expcnt at [6:4], lgkmcnt at [11:8].
471 vmcnt = std::min(15u, vmcnt);
472 expcnt = std::min(7u, expcnt);
473 lgkmcnt = std::min(15u, lgkmcnt);
474 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
// Layout B: 6-bit vmcnt split low[3:0] / high bits at [15:14].
477 vmcnt = std::min(63u, vmcnt);
478 expcnt = std::min(7u, expcnt);
479 lgkmcnt = std::min(15u, lgkmcnt);
480 unsigned lowBits = vmcnt & 0xF;
481 unsigned highBits = (vmcnt >> 4) << 14;
482 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
483 return lowBits | highBits | otherCnts;
// Layout C: same split vmcnt, but lgkmcnt widened to 6 bits.
486 vmcnt = std::min(63u, vmcnt);
487 expcnt = std::min(7u, expcnt);
488 lgkmcnt = std::min(63u, lgkmcnt);
489 unsigned lowBits = vmcnt & 0xF;
490 unsigned highBits = (vmcnt >> 4) << 14;
491 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
492 return lowBits | highBits | otherCnts;
// Layout D: vmcnt at [15:10], expcnt at [2:0], lgkmcnt at [9:4].
495 vmcnt = std::min(63u, vmcnt);
496 expcnt = std::min(7u, expcnt);
497 lgkmcnt = std::min(63u, lgkmcnt);
498 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
// NOTE(review): MemoryCounterWaitOpLowering — on gfx12+ each counter maps to
// its own dedicated ROCDL wait op; on older targets the counters are folded
// into a single encoded s_waitcnt immediate. Elided lines in this view: 504,
// 506, 508-512, 519, 522, 525, 528, 531, 533-535, 538, 540, 542-545, 548,
// 552, 554, 556, 558-559, 561, 563, 565+ — including the else branches and
// the load/store null checks around the vmcnt computation; confirm.
503struct MemoryCounterWaitOpLowering
505 MemoryCounterWaitOpLowering(
const LLVMTypeConverter &converter,
507 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
513 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter)
const override {
// gfx12+ path: one dedicated wait op per present counter.
515 if (chipset.majorVersion >= 12) {
516 Location loc = op.getLoc();
517 if (std::optional<int> ds = adaptor.getDs())
518 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
520 if (std::optional<int>
load = adaptor.getLoad())
521 ROCDL::WaitLoadcntOp::create(rewriter, loc, *
load);
523 if (std::optional<int> store = adaptor.getStore())
524 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
526 if (std::optional<int> exp = adaptor.getExp())
527 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
529 if (std::optional<int> tensor = adaptor.getTensor())
530 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
532 rewriter.eraseOp(op);
// Pre-gfx12: there is no tensor counter to wait on.
536 if (adaptor.getTensor())
537 return op.emitOpError(
"unsupported chipset");
// Missing attribute → no wait requested for that counter (default value
// supplied by elided code — presumably "max", TODO confirm).
539 auto getVal = [](Attribute attr) ->
unsigned {
541 return cast<IntegerAttr>(attr).getInt();
546 unsigned ds = getVal(adaptor.getDsAttr());
547 unsigned exp = getVal(adaptor.getExpAttr());
// vmcnt folds the separate load/store counters of the op into the single
// legacy vm counter (sum when both present).
549 unsigned vmcnt = 1024;
550 Attribute
load = adaptor.getLoadAttr();
551 Attribute store = adaptor.getStoreAttr();
553 vmcnt = getVal(
load) + getVal(store);
555 vmcnt = getVal(
load);
557 vmcnt = getVal(store);
560 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
562 return op.emitOpError(
"unsupported chipset");
564 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
// NOTE(review): LDSBarrierOpLowering — fragment. Emits a release fence, a
// barrier (inline asm s_barrier before gfx90a, s_barrier op before gfx12,
// signal+wait pair on gfx12+), then an acquire fence; both fences carry an
// MMRA tag restricting synchronization to LDS ("local"). Elided lines: 572-575,
// 579-580, 582-583, 585-592, 594, 604-605, 608-609, 612, 615-616, 621+.
570 LDSBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
571 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
576 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
577 ConversionPatternRewriter &rewriter)
const override {
578 Location loc = op.getLoc();
// Pre-gfx90a targets need the raw s_barrier via inline asm.
581 bool requiresInlineAsm = chipset <
kGfx90a;
// MMRA tag: fences only order LDS ("local") accesses, not global memory.
584 rewriter.getAttr<LLVM::MMRATagAttr>(
"amdgpu-synchronize-as",
"local");
593 StringRef scope =
"workgroup";
595 auto relFence = LLVM::FenceOp::create(rewriter, loc,
596 LLVM::AtomicOrdering::release, scope);
597 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
598 if (requiresInlineAsm) {
599 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
600 LLVM::AsmDialect::AD_ATT);
601 const char *asmStr =
";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
602 const char *constraints =
"";
603 LLVM::InlineAsmOp::create(
606 asmStr, constraints,
true,
607 false, LLVM::TailCallKind::None,
610 }
else if (chipset.majorVersion < 12) {
611 ROCDL::SBarrierOp::create(rewriter, loc);
// gfx12+: split barrier — signal then wait (-1 = default/all members).
613 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
614 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
617 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
618 LLVM::AtomicOrdering::acquire, scope);
619 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
620 rewriter.replaceOp(op, acqFence);
// NOTE(review): SchedBarrierOpLowering — fragment; direct 1:1 replacement of
// amdgpu.sched_barrier with rocdl.sched.barrier, forwarding the opts bitmask.
626 SchedBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
627 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
632 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
633 ConversionPatternRewriter &rewriter)
const override {
// The opts enum is passed through as its raw integer encoding.
634 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
635 (uint32_t)op.getOpts());
// NOTE(review): fragments of two operand-normalization helpers for matrix
// intrinsics (MFMA/WMMA): bf16 vectors may be reinterpreted as i16 vectors
// (when the intrinsic doesn't take bf16 directly), and sub-byte/8-bit integer
// vectors are packed into i32 words via bitcast. Elided lines: 660, 668, 674,
// 677-686, 688, 690, 694, 703+ — including both signatures and fall-through
// returns; confirm against the full file.
659 bool allowBf16 =
true) {
661 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
// bf16 → i16 reinterpret when the target intrinsic lacks bf16 operands.
662 if (vectorType.getElementType().isBF16() && !allowBf16)
663 return LLVM::BitcastOp::create(
664 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
// Short i8 vectors become one integer of matching total width.
665 if (vectorType.getElementType().isInteger(8) &&
666 vectorType.getNumElements() <= 8)
667 return LLVM::BitcastOp::create(
669 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
// Longer small-int vectors are packed into i32 words (rounded up).
670 if (isa<IntegerType>(vectorType.getElementType()) &&
671 vectorType.getElementTypeBitWidth() <= 8) {
672 int64_t numWords = llvm::divideCeil(
673 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
675 return LLVM::BitcastOp::create(
676 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
// Second helper: same idea, but the input must already be a vector, and a
// single packed word degrades to scalar i32 rather than <1 x i32>.
687 bool allowBf16 =
true) {
689 auto vectorType = cast<VectorType>(inputType);
691 if (vectorType.getElementType().isBF16() && !allowBf16)
692 return LLVM::BitcastOp::create(
693 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
695 if (isa<IntegerType>(vectorType.getElementType()) &&
696 vectorType.getElementTypeBitWidth() <= 8) {
697 int64_t numWords = llvm::divideCeil(
698 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
699 Type castType = (numWords > 1)
700 ?
Type{VectorType::get(numWords, rewriter.getI32Type())}
701 : rewriter.getI32Type();
702 return LLVM::BitcastOp::create(rewriter, loc, castType, input);
// NOTE(review): three fragments. (1) A TypeSwitch normalizing a scale operand:
// scalar integers are zero-extended to i32; 4/8-element vectors are bitcast to
// i32/i64. (2) A TypeSwitch mapping fp8 float types to small integer codes
// (E8M0→0, E4M3FN→2 — the meaning of the codes comes from elided context).
// (3) A selector returning the scaled-WMMA intrinsic name for the supported
// (m, n, k) shapes, choosing the scale16 variant when isScale16 is set.
// Elided lines: 721, 723-724, 726, 733, 735-739, 743-746, 748, 750, 753, 757+.
720 .Case([&](IntegerType) {
722 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
725 .Case([&](VectorType vectorType) {
727 int64_t numElements = vectorType.getNumElements();
728 assert((numElements == 4 || numElements == 8) &&
729 "scale operand must be a vector of length 4 or 8");
// 4 bytes of scales pack into i32, 8 bytes into i64.
730 IntegerType outputType =
731 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
732 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
734 .DefaultUnreachable(
"unexpected input type for scale operand");
// fp8 type → numeric format code (consumer is in elided code; verify).
740 .Case([](Float8E8M0FNUType) {
return 0; })
741 .Case([](Float8E4M3FNType) {
return 2; })
742 .Default(std::nullopt);
747static std::optional<StringRef>
// 16x16x128 handles the f8f6f4 family; 32x16x128 is f4-only.
749 if (m == 16 && n == 16 && k == 128)
751 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
752 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
754 if (m == 32 && n == 16 && k == 128)
755 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
756 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
// NOTE(review): fragments of the WMMA operand-marshalling helpers. The input
// helper forwards non-vector operands as-is, tags integer inputs with a
// signedness attribute (attrName), and bitcasts small-element vectors into the
// i32-based shape the intrinsic expects (zero-extending a short scalar to
// i32). The output helper pushes the accumulator plus (per the signature)
// subword-offset handling that is elided here. Missing lines: 771-774, 776,
// 778-779, 781, 783-788, 792, 794, 796, 798-799, 801-804, 812-815, 818-829,
// 831-833 — confirm details against the full file.
770 ConversionPatternRewriter &rewriter,
Location loc,
775 auto vectorType = dyn_cast<VectorType>(inputType);
// Non-vector operands pass straight through.
777 operands.push_back(llvmInput);
780 Type elemType = vectorType.getElementType();
782 operands.push_back(llvmInput);
// Integer element inputs: record signedness as a unit attribute on the
// intrinsic (set from explicit sign info when present).
789 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
790 bool isInputInteger = mlirInputType.getElementType().isInteger();
791 if (isInputInteger) {
793 bool localIsUnsigned = isUnsigned;
795 localIsUnsigned =
true;
797 localIsUnsigned =
false;
800 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
// Repack the vector into numBits worth of integer (scalar if <=32 bits,
// else a <numBits/32 x i32> vector), then widen small scalars to i32.
805 Type i32 = rewriter.getI32Type();
806 Type intrinsicInType = numBits <= 32
807 ? (
Type)rewriter.getIntegerType(numBits)
808 : (
Type)VectorType::get(numBits / 32, i32);
809 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
810 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
811 loc, llvmIntrinsicInType, llvmInput);
816 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
817 operands.push_back(castInput);
// Output-side helper: pushes the destination/accumulator operand.
830 Value output, int32_t subwordOffset,
834 auto vectorType = dyn_cast<VectorType>(inputType);
835 Type elemType = vectorType.getElementType();
836 operands.push_back(output);
// NOTE(review): fragments of two predicates classifying a type as the
// chipset's "bf8" / "fp8" flavor: gfx942 uses the FNUZ variants, OCP-fp8
// chipsets use the standard E5M2/E4M3FN types. Signatures elided.
848 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
849 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
// fp8 (E4M3) counterpart of the predicate above.
855 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
856 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
// NOTE(review): mfmaOpToIntrinsic dispatch table — maps (element type, M, N,
// K, blocks, chipset) to the ROCDL MFMA intrinsic name. The element-type
// guards and several chipset guards that select each group below are elided
// from this view (866-869, 875, 886-889, 894, 905-908, 913-914, 925, 936-939,
// 944, 959-961, 966-971, 974, 976, 978, 980, 982, 984-988, 991, 993, 995,
// 997, 999, 1001-1003 missing) — the grouping comments are inferred from the
// intrinsic names and should be confirmed against the full file.
864 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
865 b = mfma.getBlocks();
// xf32 (reduced-precision f32) variants, gfx942+.
870 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
871 if (m == 32 && n == 32 && k == 4 &&
b == 1)
872 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
873 if (m == 16 && n == 16 && k == 8 &&
b == 1)
874 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
// f32-input variants.
876 if (m == 32 && n == 32 && k == 1 &&
b == 2)
877 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
878 if (m == 16 && n == 16 && k == 1 &&
b == 4)
879 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
880 if (m == 4 && n == 4 && k == 1 &&
b == 16)
881 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
882 if (m == 32 && n == 32 && k == 2 &&
b == 1)
883 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
884 if (m == 16 && n == 16 && k == 4 &&
b == 1)
885 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
// f16-input variants (the first two look like gfx950 large-K shapes).
890 if (m == 32 && n == 32 && k == 16 &&
b == 1)
891 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
892 if (m == 16 && n == 16 && k == 32 &&
b == 1)
893 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
895 if (m == 32 && n == 32 && k == 4 &&
b == 2)
896 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
897 if (m == 16 && n == 16 && k == 4 &&
b == 4)
898 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
899 if (m == 4 && n == 4 && k == 4 &&
b == 16)
900 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
901 if (m == 32 && n == 32 && k == 8 &&
b == 1)
902 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
903 if (m == 16 && n == 16 && k == 16 &&
b == 1)
904 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
// bf16-input variants; "_1k" names are the gfx90a+ packed-bf16 forms.
909 if (m == 32 && n == 32 && k == 16 &&
b == 1)
910 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
911 if (m == 16 && n == 16 && k == 32 &&
b == 1)
912 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
915 if (m == 32 && n == 32 && k == 4 &&
b == 2)
916 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
917 if (m == 16 && n == 16 && k == 4 &&
b == 4)
918 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
919 if (m == 4 && n == 4 && k == 4 &&
b == 16)
920 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
921 if (m == 32 && n == 32 && k == 8 &&
b == 1)
922 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
923 if (m == 16 && n == 16 && k == 16 &&
b == 1)
924 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
926 if (m == 32 && n == 32 && k == 2 &&
b == 2)
927 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
928 if (m == 16 && n == 16 && k == 2 &&
b == 4)
929 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
930 if (m == 4 && n == 4 && k == 2 &&
b == 16)
931 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
932 if (m == 32 && n == 32 && k == 4 &&
b == 1)
933 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
934 if (m == 16 && n == 16 && k == 8 &&
b == 1)
935 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
// i8-input variants; some shapes are gated on gfx942+ explicitly below.
940 if (m == 32 && n == 32 && k == 32 &&
b == 1)
941 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
942 if (m == 16 && n == 16 && k == 64 &&
b == 1)
943 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
945 if (m == 32 && n == 32 && k == 4 &&
b == 2)
946 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
947 if (m == 16 && n == 16 && k == 4 &&
b == 4)
948 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
949 if (m == 4 && n == 4 && k == 4 &&
b == 16)
950 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
951 if (m == 32 && n == 32 && k == 8 &&
b == 1)
952 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
953 if (m == 16 && n == 16 && k == 16 &&
b == 1)
954 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
955 if (m == 32 && n == 32 && k == 16 &&
b == 1 && chipset >=
kGfx942)
956 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
957 if (m == 16 && n == 16 && k == 32 &&
b == 1 && chipset >=
kGfx942)
958 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
// f64 variants.
962 if (m == 16 && n == 16 && k == 4 &&
b == 1)
963 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
964 if (m == 4 && n == 4 && k == 4 &&
b == 4)
965 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
// Mixed fp8/bf8 variants: sourceA's flavor selected by elided guards, and
// sourceB's element type picks the second half of the intrinsic name.
972 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
973 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
975 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
977 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
979 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
981 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
983 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
989 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
990 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
992 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
994 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
996 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
998 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
1000 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
// No intrinsic matches this shape/type/chipset combination.
1004 return std::nullopt;
// NOTE(review): fragments. (1) TypeSwitch mapping the small-float source
// types to the scaled-MFMA format codes (fp8 e4m3→0, fp8 e5m2→1, fp6
// e2m3→2, fp6 e3m2→3, fp4 e2m1→4). (2) The shared shape checker for scaled
// MFMA: f32 destination only, shapes 32x32x64 and 16x16x128, returning the
// intrinsic name plus both operand type codes. (3) Two thin wrappers feeding
// it from a regular MFMA op and from a scaled-MFMA op (the latter with b=1).
// Elided lines: 1015-1023, 1025, 1027-1031, 1035-1037, 1040, 1045, 1047-1048,
// 1050-1051, 1053-1054, 1058-1059, 1061-1062 — confirm.
1009 .Case([](Float8E4M3FNType) {
return 0u; })
1010 .Case([](Float8E5M2Type) {
return 1u; })
1011 .Case([](Float6E2M3FNType) {
return 2u; })
1012 .Case([](Float6E3M2FNType) {
return 3u; })
1013 .Case([](Float4E2M1FNType) {
return 4u; })
1014 .Default(std::nullopt);
1024static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1026 uint32_t n, uint32_t k, uint32_t
b,
Chipset chipset) {
1032 return std::nullopt;
// Scaled MFMA only produces f32 accumulators.
1033 if (!isa<Float32Type>(destType))
1034 return std::nullopt;
1038 if (!aTypeCode || !bTypeCode)
1039 return std::nullopt;
1041 if (m == 32 && n == 32 && k == 64 &&
b == 1)
1042 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1043 *aTypeCode, *bTypeCode};
1044 if (m == 16 && n == 16 && k == 128 &&
b == 1)
1046 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1049 return std::nullopt;
// Wrapper: plain MFMA op → scaled intrinsic candidate.
1052static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1055 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1056 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1057 mfma.getBlocks(), chipset);
// Wrapper: scaled-MFMA op → intrinsic; block count fixed at 1.
1060static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1063 smfma.getSourceB().getType(),
1064 smfma.getDestC().getType(), smfma.getM(),
1065 smfma.getN(), smfma.getK(), 1u, chipset);
// NOTE(review): RDNA3/RDNA4 WMMA intrinsic selector fragment — 16x16x16 f16/
// bf16/fp8/bf8/iu8 plus iu4 shapes (k=16, and k=32 on RDNA4 per the visible
// guard). Element-type guard lines are elided (1071, 1075-1078, 1085, 1087,
// 1089-1092, 1095-1097, 1110, 1112, 1114-1116, 1118, 1120-1121 missing) —
// the grouping below is inferred from the intrinsic names; confirm.
1070static std::optional<StringRef>
1072 Type elemDestType, uint32_t k,
bool isRDNA3) {
1073 using fp8 = Float8E4M3FNType;
1074 using bf8 = Float8E5M2Type;
1079 if (elemSourceType.
isF16() && elemDestType.
isF32())
1080 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1081 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1082 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1083 if (elemSourceType.
isF16() && elemDestType.
isF16())
1084 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1086 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1088 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1093 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1094 return std::nullopt;
// fp8/bf8 combinations (RDNA4-era; surrounding guard elided).
1098 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1099 elemDestType.
isF32())
1100 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1101 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1102 elemDestType.
isF32())
1103 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1104 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1105 elemDestType.
isF32())
1106 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1107 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1108 elemDestType.
isF32())
1109 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1111 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1113 return std::nullopt;
// k=32 iu4 exists only on non-RDNA3 (i.e. RDNA4) targets.
1117 if (k == 32 && !isRDNA3) {
1119 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1122 return std::nullopt;
// NOTE(review): gfx1250 WMMA intrinsic selector fragment — handles the
// k=4 (f32), k=32 (f16/bf16), k=64 (fp8/bf8/iu8), and k=128 (fp8/bf8)
// generations, keyed on A/B element types and destination type. Guard lines
// selecting each k-group and several type checks are elided (1129-1130,
// 1133-1134, 1137, 1139-1141, 1148, 1150, 1152-1154, 1160, 1166, 1172,
// 1178-1179, 1181, 1183-1185, 1191, 1197, 1203, 1209-1210, 1212-1213
// missing) — grouping inferred from intrinsic names; confirm.
1128 Type elemBSourceType,
1131 using fp8 = Float8E4M3FNType;
1132 using bf8 = Float8E5M2Type;
// k=4: plain f32 WMMA.
1135 if (elemSourceType.
isF32() && elemDestType.
isF32())
1136 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1138 return std::nullopt;
// k=32: f16/bf16 inputs.
1142 if (elemSourceType.
isF16() && elemDestType.
isF32())
1143 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1144 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1145 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1146 if (elemSourceType.
isF16() && elemDestType.
isF16())
1147 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1149 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1151 return std::nullopt;
// k=64: fp8/bf8 pairings with f32 or f16 accumulators, plus iu8.
1155 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1156 if (elemDestType.
isF32())
1157 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1158 if (elemDestType.
isF16())
1159 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1161 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1162 if (elemDestType.
isF32())
1163 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1164 if (elemDestType.
isF16())
1165 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1167 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1168 if (elemDestType.
isF32())
1169 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1170 if (elemDestType.
isF16())
1171 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1173 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1174 if (elemDestType.
isF32())
1175 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1176 if (elemDestType.
isF16())
1177 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1180 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1182 return std::nullopt;
// k=128: fp8/bf8 pairings.
1186 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1187 if (elemDestType.
isF32())
1188 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1189 if (elemDestType.
isF16())
1190 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1192 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1193 if (elemDestType.
isF32())
1194 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1195 if (elemDestType.
isF16())
1196 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1198 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1199 if (elemDestType.
isF32())
1200 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1201 if (elemDestType.
isF16())
1202 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1204 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1205 if (elemDestType.
isF32())
1206 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1207 if (elemDestType.
isF16())
1208 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1211 return std::nullopt;
1214 return std::nullopt;
// NOTE(review): sparse-MFMA (smfmac) intrinsic selector fragment — matches
// (M, N, K) shapes against f16/bf16/i8/fp8/bf8 element-type combinations,
// with the doubled-K shapes gated on gfx950. Element-type guard lines are
// elided throughout (1223-1225, 1227-1230, 1232, 1234, 1236-1237, 1239-1240,
// 1242, 1244-1246, 1256-1257, 1259-1260, 1270-1271, 1273, 1275, 1277-1278,
// 1280-1281, 1283, 1285-1287, 1297-1298, 1300-1301, 1311-1312 missing).
1222 bool isGfx950 = chipset >=
kGfx950;
1226 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
// 16x16 shapes.
1231 if (m == 16 && n == 16 && k == 32) {
1233 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1235 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1238 if (m == 16 && n == 16 && k == 64) {
1241 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1243 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1247 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1248 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1249 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1250 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1251 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1252 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1253 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1254 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1255 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
// gfx950-only doubled-K 16x16 shapes.
1258 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1261 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1262 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1263 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1264 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1265 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1266 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1267 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1268 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1269 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
// 32x32 shapes.
1272 if (m == 32 && n == 32 && k == 16) {
1274 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1276 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1279 if (m == 32 && n == 32 && k == 32) {
1282 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1284 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1288 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1289 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1290 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1291 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1292 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1293 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1294 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1295 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1296 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
// gfx950-only doubled-K 32x32 shapes.
1299 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1302 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1303 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1304 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1305 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1306 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1307 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1308 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1309 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1310 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
// No matching sparse-MFMA intrinsic.
1313 return std::nullopt;
// NOTE(review): two fragments. First, WMMA-op info extraction (source A/B and
// destination element types, K, plus an RDNA3/RDNA4 dispatch whose branches
// are elided). Second, the sparse-WMMA (swmmac) selector: restricted to
// m==n==16, returning (intrinsic name, flags…) tuples; the boolean literals
// paired with each name appear to be op-info flags whose field names are in
// elided struct-init lines — TODO confirm their meaning. Elided lines:
// 1327, 1329-1332, 1334-1341, 1343-1354, 1356-1360, 1362, 1365-1370,
// 1372-1374, 1376-1378, 1380-1382, 1384-1387, 1389-1392, 1394-1397,
// 1399-1402, 1404-1407, 1409-1411, 1413-1418, 1420-1423, 1427-1429,
// 1431-1433, 1435-1437, 1439-1441, 1443-1448, 1450-1453, 1455-1458,
// 1460-1462, 1464-1467, 1469-1472, 1474-1477, 1479-1481, 1483-1486,
// 1488-1491, 1493-1496.
1321 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1322 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1323 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1324 Type elemSourceType = sourceVectorType.getElementType();
1325 Type elemBSourceType = sourceBVectorType.getElementType();
1326 Type elemDestType = destVectorType.getElementType();
1328 const uint32_t k = wmma.getK();
// RDNA3/RDNA4 share the pre-gfx1250 selector path.
1333 if (isRDNA3 || isRDNA4)
1342 return std::nullopt;
1355static std::optional<SparseWMMAOpInfo>
1361 uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();
// Sparse WMMA only exists for 16x16 tiles.
1363 if ((m != 16) || (n != 16))
1364 return std::nullopt;
// k=32 family (RDNA-era sparse WMMA).
1371 ROCDL::swmmac_f32_16x16x32_f16::getOperationName(),
false,
false,
1375 ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(),
false,
false,
1379 ROCDL::swmmac_f16_16x16x32_f16::getOperationName(),
false,
false,
1383 ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(),
false,
false,
1388 ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(),
true,
false,
1393 ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(),
true,
false,
1398 ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(),
false,
1403 ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(),
false,
1408 ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(),
false,
1412 ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(),
false,
1419 ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(),
true,
false,
// gfx1250 wave32-only sparse WMMA variants (k=64 and k=128).
1424 const bool isGFX1250 = chipset ==
kGfx1250;
1425 const bool isWavesize64 = swmmac.getWave64();
1426 if (isGFX1250 && !isWavesize64) {
1430 ROCDL::swmmac_f32_16x16x64_f16::getOperationName(),
true,
true,
1434 ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(),
true,
true,
1438 ROCDL::swmmac_f16_16x16x64_f16::getOperationName(),
true,
true,
1442 ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(),
true,
true,
1449 ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(),
false,
1454 ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(),
false,
1459 ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(),
false,
1463 ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(),
false,
1468 ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(),
false,
1473 ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(),
false,
1478 ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(),
false,
1482 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(),
false,
// NOTE(review): bf8_bf8 name appears twice (lines 1482/1487 of the
// original) — the elided type guards above each one would show whether
// this is intentional or a copy-paste slip; flag for verification.
1487 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(),
false,
1492 ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(),
true,
true,
1497 return std::nullopt;
// NOTE(review): MFMAOpLowering fragment — truncated at the end of this chunk
// (continues past line 1562 of the original) and internally elided (1504-1507,
// 1516, 1521, 1523, 1525-1526, 1528, 1531-1532, 1534, 1539-1540, 1543-1544,
// 1546-1549, 1552, 1555, 1557, 1560-1561 missing). Picks a plain or scaled
// MFMA intrinsic, validates chipset/field constraints, and assembles the
// lowered op's operands. Annotated only; no restructuring.
1502 MFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1503 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1508 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1509 ConversionPatternRewriter &rewriter)
const override {
1510 Location loc = op.getLoc();
1511 Type outType = typeConverter->convertType(op.getDestD().getType());
// bf16 results are produced as i16 vectors by the intrinsics.
1512 Type intrinsicOutType = outType;
1513 if (
auto outVecType = dyn_cast<VectorType>(outType))
1514 if (outVecType.getElementType().isBF16())
1515 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
// MFMA exists only on CDNA-class gfx9 targets, gfx908 onward.
1517 if (chipset.majorVersion != 9 || chipset <
kGfx908)
1518 return op->emitOpError(
"MFMA only supported on gfx908+");
1519 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
// Negation bits are packed into blgp (gfx942+ only; the guard is elided).
1520 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1522 return op.emitOpError(
"negation unsupported on older than gfx942");
1524 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1527 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1529 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1530 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
// Prefer the plain intrinsic; fall back to the scaled form.
1533 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1535 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1536 return op.emitOpError(
1537 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1538 "be scaled as those fields are used for type information");
1541 StringRef intrinsicName =
1542 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
// Only the gfx950-style bf16 intrinsics take bf16 operands natively.
1545 bool allowBf16 = [&]() {
1550 return intrinsicName.contains(
"16x16x32.bf16") ||
1551 intrinsicName.contains(
"32x32x16.bf16");
1553 OperationState loweredOp(loc, intrinsicName);
1554 loweredOp.addTypes(intrinsicOutType);
1556 rewriter, loc, adaptor.getSourceA(), allowBf16),
1558 rewriter, loc, adaptor.getSourceB(), allowBf16),
1559 adaptor.getDestC()});
1562 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1563 loweredOp.addOperands({zero, zero});
1564 loweredOp.addAttributes({{
"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1565 {
"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1566 {
"opselA", rewriter.getI32IntegerAttr(0)},
1567 {
"opselB", rewriter.getI32IntegerAttr(0)}});
1569 loweredOp.addAttributes(
1570 {{
"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1571 {
"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1572 {
"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
1574 Value lowered = rewriter.create(loweredOp)->getResult(0);
1575 if (outType != intrinsicOutType)
1576 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1577 rewriter.replaceOp(op, lowered);
1583 ScaledMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1584 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1589 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1590 ConversionPatternRewriter &rewriter)
const override {
1591 Location loc = op.getLoc();
1592 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1594 if (chipset.majorVersion != 9 || chipset <
kGfx950)
1595 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1596 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1598 if (!maybeScaledIntrinsic.has_value())
1599 return op.emitOpError(
1600 "no intrinsic matching scaled MFMA size on given chipset");
1602 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1603 OperationState loweredOp(loc, intrinsicName);
1604 loweredOp.addTypes(intrinsicOutType);
1605 loweredOp.addOperands(
1608 adaptor.getDestC()});
1609 loweredOp.addOperands(
1614 loweredOp.addAttributes(
1615 {{
"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1616 {
"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1617 {
"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1618 {
"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1620 Value lowered = rewriter.create(loweredOp)->getResult(0);
1621 rewriter.replaceOp(op, lowered);
1627 SparseMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1628 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1633 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1634 ConversionPatternRewriter &rewriter)
const override {
1635 Location loc = op.getLoc();
1637 typeConverter->convertType<VectorType>(op.getDestC().
getType());
1639 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1642 if (chipset.majorVersion != 9 || chipset <
kGfx942)
1643 return op->emitOpError(
"sparse MFMA (smfmac) only supported on gfx942+");
1644 bool isGfx950 = chipset >=
kGfx950;
1650 Value c = adaptor.getDestC();
1653 if (!maybeIntrinsic.has_value())
1654 return op.emitOpError(
1655 "no intrinsic matching sparse MFMA on the given chipset");
1658 Value sparseIdx = LLVM::BitcastOp::create(
1659 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1661 OperationState loweredOp(loc, maybeIntrinsic.value());
1662 loweredOp.addTypes(outType);
1663 loweredOp.addOperands({a,
b, c, sparseIdx});
1664 loweredOp.addAttributes(
1665 {{
"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1666 {
"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1667 Value lowered = rewriter.create(loweredOp)->getResult(0);
1668 rewriter.replaceOp(op, lowered);
1674 WMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1675 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1680 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1681 ConversionPatternRewriter &rewriter)
const override {
1682 Location loc = op.getLoc();
1684 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1686 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1688 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1689 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1691 bool isGFX1250 = chipset >=
kGfx1250;
1696 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1697 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1698 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1699 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1700 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1701 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1702 bool castOutToI16 = outType.getElementType().
isBF16() && !isGFX1250;
1703 VectorType rawOutType = outType;
1705 rawOutType = outType.clone(rewriter.getI16Type());
1706 Value a = adaptor.getSourceA();
1708 a = LLVM::BitcastOp::create(rewriter, loc,
1709 aType.clone(rewriter.getI16Type()), a);
1710 Value
b = adaptor.getSourceB();
1712 b = LLVM::BitcastOp::create(rewriter, loc,
1713 bType.clone(rewriter.getI16Type()),
b);
1714 Value destC = adaptor.getDestC();
1716 destC = LLVM::BitcastOp::create(
1717 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1721 if (!maybeIntrinsic.has_value())
1722 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
1724 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1725 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1727 SmallVector<Value, 4> operands;
1728 SmallVector<NamedAttribute, 4> attrs;
1730 op.getSourceA(), operands, attrs,
"signA");
1732 op.getSourceB(), operands, attrs,
"signB");
1734 op.getSubwordOffset(), op.getClamp(), operands,
1737 OperationState loweredOp(loc, *maybeIntrinsic);
1738 loweredOp.addTypes(rawOutType);
1739 loweredOp.addOperands(operands);
1740 loweredOp.addAttributes(attrs);
1741 Operation *lowered = rewriter.create(loweredOp);
1743 Operation *maybeCastBack = lowered;
1744 if (rawOutType != outType)
1745 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1747 rewriter.replaceOp(op, maybeCastBack->
getResults());
1754 SparseWMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1755 : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
1760 matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
1761 ConversionPatternRewriter &rewriter)
const override {
1762 Location loc = op.getLoc();
1764 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1766 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1768 std::optional<SparseWMMAOpInfo> maybeIntrinsic =
1771 if (!maybeIntrinsic.has_value())
1772 return op.emitOpError(
1773 "no intrinsic matching Sparse WMMA on the given chipset");
1774 SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();
1776 SmallVector<NamedAttribute> attrs;
1778 if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.
useSign)
1779 return op->emitOpError(
"intrinsic doesn't support unsign");
1781 if (
auto attr = op.getUnsignedAAttr())
1782 attrs.push_back({
"signA", attr});
1783 if (
auto attr = op.getUnsignedBAttr())
1784 attrs.push_back({
"signB", attr});
1787 if ((op.getReuseA() || op.getReuseB()) && !intrinsic.
useReuse)
1788 return op->emitOpError(
"intrinsic doesn't support reuse");
1790 if (
auto attr = op.getReuseAAttr())
1791 attrs.push_back({
"reuseA", attr});
1792 if (
auto attr = op.getReuseBAttr())
1793 attrs.push_back({
"reuseB", attr});
1796 if (op.getClamp() && !intrinsic.
useClamp)
1797 return op->emitOpError(
"intrinsic doesn't support clamp");
1798 if (intrinsic.
useClamp && op.getClampAttr())
1799 attrs.push_back({
"clamp", op.getClampAttr()});
1801 const bool isGFX1250orHigher =
1802 chipset.majorVersion == 12 && chipset.minorVersion >= 5;
1807 Value c = adaptor.getDestC();
1808 VectorType rawOutType = outType;
1809 if (!isGFX1250orHigher) {
1811 rawOutType = cast<VectorType>(c.
getType());
1815 Value sparseIdx = LLVM::BitcastOp::create(
1816 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1818 OperationState loweredOp(loc, intrinsic.
name);
1819 loweredOp.addTypes(rawOutType);
1820 loweredOp.addOperands({a,
b, c, sparseIdx});
1821 loweredOp.addAttributes(attrs);
1822 Operation *lowered = rewriter.create(loweredOp);
1824 Operation *maybeCastBack = lowered;
1825 if (rawOutType != outType)
1826 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1828 rewriter.replaceOp(op, maybeCastBack->
getResults());
1835 ScaledWMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1836 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
1841 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
1842 ConversionPatternRewriter &rewriter)
const override {
1843 Location loc = op.getLoc();
1845 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1847 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1850 return op->emitOpError(
"WMMA scale only supported on gfx1250+");
1852 int64_t m = op.getM();
1853 int64_t n = op.getN();
1854 int64_t k = op.getK();
1862 if (!aFmtCode || !bFmtCode)
1863 return op.emitOpError(
"unsupported element types for scaled_wmma");
1866 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
1867 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
1869 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
1870 return op.emitOpError(
"scaleA and scaleB must have equal vector length");
1873 Type scaleAElemType = scaleAVecType.getElementType();
1874 Type scaleBElemType = scaleBVecType.getElementType();
1879 if (!scaleAFmt || !scaleBFmt)
1880 return op.emitOpError(
"unsupported scale element types");
1883 bool isScale16 = (scaleAVecType.getNumElements() == 8);
1884 std::optional<StringRef> intrinsicName =
1887 return op.emitOpError(
"unsupported scaled_wmma dimensions: ")
1888 << m <<
"x" << n <<
"x" << k;
1890 SmallVector<NamedAttribute, 8> attrs;
1893 bool is32x16 = (m == 32 && n == 16 && k == 128);
1895 attrs.emplace_back(
"fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
1896 attrs.emplace_back(
"fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
1900 attrs.emplace_back(
"modC", rewriter.getI16IntegerAttr(0));
1905 "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
1906 attrs.emplace_back(
"fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
1908 "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
1909 attrs.emplace_back(
"fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
1912 attrs.emplace_back(
"reuseA", rewriter.getBoolAttr(
false));
1913 attrs.emplace_back(
"reuseB", rewriter.getBoolAttr(
false));
1926 OperationState loweredOp(loc, *intrinsicName);
1927 loweredOp.addTypes(outType);
1928 loweredOp.addOperands(
1929 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
1930 loweredOp.addAttributes(attrs);
1932 Operation *lowered = rewriter.create(loweredOp);
1933 rewriter.replaceOp(op, lowered->
getResults());
1939struct TransposeLoadOpLowering
1941 TransposeLoadOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1942 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1947 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1948 ConversionPatternRewriter &rewriter)
const override {
1950 return op.emitOpError(
"Non-gfx950 chipset not supported");
1952 Location loc = op.getLoc();
1953 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1957 size_t srcElementSize =
1958 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1959 if (srcElementSize < 8)
1960 return op.emitOpError(
"Expect source memref to have at least 8 bits "
1961 "element size, got ")
1964 auto resultType = cast<VectorType>(op.getResult().getType());
1967 (adaptor.getSrcIndices()));
1969 size_t numElements = resultType.getNumElements();
1970 size_t elementTypeSize =
1971 resultType.getElementType().getIntOrFloatBitWidth();
1975 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1976 rewriter.getIntegerType(32));
1977 Type llvmResultType = typeConverter->convertType(resultType);
1979 switch (elementTypeSize) {
1981 assert(numElements == 16);
1982 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1983 rocdlResultType, srcPtr);
1984 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1988 assert(numElements == 16);
1989 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1990 rocdlResultType, srcPtr);
1991 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1995 assert(numElements == 8);
1996 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1997 rocdlResultType, srcPtr);
1998 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2002 assert(numElements == 4);
2003 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
2008 return op.emitOpError(
"Unsupported element size for transpose load");
2015 GatherToLDSOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2016 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
2021 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
2022 ConversionPatternRewriter &rewriter)
const override {
2023 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
2024 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
2026 Location loc = op.getLoc();
2028 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2029 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2034 Type transferType = op.getTransferType();
2035 int loadWidth = [&]() ->
int {
2036 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
2037 return (transferVectorType.getNumElements() *
2038 transferVectorType.getElementTypeBitWidth()) /
2045 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
2046 return op.emitOpError(
"chipset unsupported element size");
2048 if (chipset !=
kGfx950 && llvm::is_contained({12, 16}, loadWidth))
2049 return op.emitOpError(
"Gather to LDS instructions with 12-byte and "
2050 "16-byte load widths are only supported on gfx950");
2054 (adaptor.getSrcIndices()));
2057 (adaptor.getDstIndices()));
2059 if (op.getAsync()) {
2060 rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
2061 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2062 rewriter.getI32IntegerAttr(0),
2066 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
2067 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2068 rewriter.getI32IntegerAttr(0),
2078struct ExtPackedFp8OpLowering final
2080 ExtPackedFp8OpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2081 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
2086 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2087 ConversionPatternRewriter &rewriter)
const override;
2090struct ScaledExtPackedMatrixOpLowering final
2092 ScaledExtPackedMatrixOpLowering(
const LLVMTypeConverter &converter,
2094 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
2099 matchAndRewrite(ScaledExtPackedMatrixOp op,
2100 ScaledExtPackedMatrixOpAdaptor adaptor,
2101 ConversionPatternRewriter &rewriter)
const override;
2104struct PackedTrunc2xFp8OpLowering final
2106 PackedTrunc2xFp8OpLowering(
const LLVMTypeConverter &converter,
2108 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
2113 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2114 ConversionPatternRewriter &rewriter)
const override;
2117struct PackedStochRoundFp8OpLowering final
2119 PackedStochRoundFp8OpLowering(
const LLVMTypeConverter &converter,
2121 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
2126 matchAndRewrite(PackedStochRoundFp8Op op,
2127 PackedStochRoundFp8OpAdaptor adaptor,
2128 ConversionPatternRewriter &rewriter)
const override;
2131struct ScaledExtPackedOpLowering final
2133 ScaledExtPackedOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2134 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
2139 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2140 ConversionPatternRewriter &rewriter)
const override;
2143struct PackedScaledTruncOpLowering final
2145 PackedScaledTruncOpLowering(
const LLVMTypeConverter &converter,
2147 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
2152 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2153 ConversionPatternRewriter &rewriter)
const override;
2158LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
2159 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2160 ConversionPatternRewriter &rewriter)
const {
2161 Location loc = op.getLoc();
2163 return rewriter.notifyMatchFailure(
2164 loc,
"Fp8 conversion instructions are not available on target "
2165 "architecture and their emulation is not implemented");
2167 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
2168 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2169 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
2171 Value source = adaptor.getSource();
2172 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
2173 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
2176 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
2177 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
2178 if (!sourceVecType) {
2179 longVec = LLVM::InsertElementOp::create(
2182 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2184 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2186 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2191 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2192 if (resultVecType) {
2194 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
2197 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
2202 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
2205 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
2212int32_t getScaleSel(int32_t blockSize,
unsigned bitWidth, int32_t scaleWaveHalf,
2213 int32_t firstScaleByte) {
2219 assert(llvm::is_contained({16, 32}, blockSize));
2220 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
2222 const bool isFp8 = bitWidth == 8;
2223 const bool isBlock16 = blockSize == 16;
2226 int32_t bit0 = isBlock16;
2227 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
2228 int32_t bit1 = (firstScaleByte == 2) << 1;
2229 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2230 int32_t bit2 = scaleWaveHalf << 2;
2231 return bit2 | bit1 | bit0;
2234 int32_t bit0 = isBlock16;
2236 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
2237 int32_t bits2and1 = firstScaleByte << 1;
2238 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2239 int32_t bit3 = scaleWaveHalf << 3;
2240 int32_t bits = bit3 | bits2and1 | bit0;
2242 assert(!llvm::is_contained(
2243 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
2247static std::optional<StringRef>
2248scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2249 using fp4 = Float4E2M1FNType;
2250 using fp8 = Float8E4M3FNType;
2251 using bf8 = Float8E5M2Type;
2252 using fp6 = Float6E2M3FNType;
2253 using bf6 = Float6E3M2FNType;
2254 if (isa<fp4>(srcElemType)) {
2255 if (destElemType.
isF16())
2256 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2257 if (destElemType.
isBF16())
2258 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2259 if (destElemType.
isF32())
2260 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2261 return std::nullopt;
2263 if (isa<fp8>(srcElemType)) {
2264 if (destElemType.
isF16())
2265 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2266 if (destElemType.
isBF16())
2267 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2268 if (destElemType.
isF32())
2269 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2270 return std::nullopt;
2272 if (isa<bf8>(srcElemType)) {
2273 if (destElemType.
isF16())
2274 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2275 if (destElemType.
isBF16())
2276 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2277 if (destElemType.
isF32())
2278 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2279 return std::nullopt;
2281 if (isa<fp6>(srcElemType)) {
2282 if (destElemType.
isF16())
2283 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2284 if (destElemType.
isBF16())
2285 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2286 if (destElemType.
isF32())
2287 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2288 return std::nullopt;
2290 if (isa<bf6>(srcElemType)) {
2291 if (destElemType.
isF16())
2292 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2293 if (destElemType.
isBF16())
2294 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2295 if (destElemType.
isF32())
2296 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2297 return std::nullopt;
2299 llvm_unreachable(
"invalid combination of element types for packed conversion "
2303LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2304 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2305 ConversionPatternRewriter &rewriter)
const {
2306 using fp4 = Float4E2M1FNType;
2307 using fp8 = Float8E4M3FNType;
2308 using bf8 = Float8E5M2Type;
2309 using fp6 = Float6E2M3FNType;
2310 using bf6 = Float6E3M2FNType;
2311 Location loc = op.getLoc();
2313 return rewriter.notifyMatchFailure(
2315 "Scaled fp packed conversion instructions are not available on target "
2316 "architecture and their emulation is not implemented");
2320 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2321 int32_t firstScaleByte = op.getFirstScaleByte();
2322 int32_t blockSize = op.getBlockSize();
2323 auto sourceType = cast<VectorType>(op.getSource().getType());
2324 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2325 unsigned bitWidth = srcElemType.getWidth();
2327 auto targetType = cast<VectorType>(op.getResult().getType());
2328 auto destElemType = cast<FloatType>(targetType.getElementType());
2330 IntegerType i32 = rewriter.getI32Type();
2331 Value source = adaptor.getSource();
2332 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2333 Type packedType =
nullptr;
2334 if (isa<fp4>(srcElemType)) {
2336 packedType = getTypeConverter()->convertType(packedType);
2337 }
else if (isa<fp8, bf8>(srcElemType)) {
2338 packedType = VectorType::get(2, i32);
2339 packedType = getTypeConverter()->convertType(packedType);
2340 }
else if (isa<fp6, bf6>(srcElemType)) {
2341 packedType = VectorType::get(3, i32);
2342 packedType = getTypeConverter()->convertType(packedType);
2344 llvm_unreachable(
"invalid element type for packed scaled ext");
2347 if (!packedType || !llvmResultType) {
2348 return rewriter.notifyMatchFailure(op,
"type conversion failed");
2351 std::optional<StringRef> maybeIntrinsic =
2352 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2353 if (!maybeIntrinsic.has_value())
2354 return op.emitOpError(
2355 "no intrinsic matching packed scaled conversion on the given chipset");
2358 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2360 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2361 Value castedSource =
2362 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2364 OperationState loweredOp(loc, *maybeIntrinsic);
2365 loweredOp.addTypes({llvmResultType});
2366 loweredOp.addOperands({castedSource, castedScale});
2368 SmallVector<NamedAttribute, 1> attrs;
2370 NamedAttribute(
"scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2372 loweredOp.addAttributes(attrs);
2373 Operation *lowered = rewriter.create(loweredOp);
2374 rewriter.replaceOp(op, lowered);
2379LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2380 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2381 ConversionPatternRewriter &rewriter)
const {
2382 Location loc = op.getLoc();
2384 return rewriter.notifyMatchFailure(
2385 loc,
"Scaled fp conversion instructions are not available on target "
2386 "architecture and their emulation is not implemented");
2387 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2389 Value source = adaptor.getSource();
2390 Value scale = adaptor.getScale();
2392 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2393 Type sourceElemType = sourceVecType.getElementType();
2394 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2395 Type destElemType = destVecType.getElementType();
2397 VectorType packedVecType;
2398 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2399 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2400 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2401 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
2402 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2403 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2405 llvm_unreachable(
"invalid element type for scaled ext");
2409 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2410 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2411 if (!sourceVecType) {
2412 longVec = LLVM::InsertElementOp::create(
2415 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2417 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2419 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2424 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2426 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
2427 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2428 op, destVecType, i32Source, scale, op.getIndex());
2429 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
2430 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2431 op, destVecType, i32Source, scale, op.getIndex());
2432 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
2433 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2434 op, destVecType, i32Source, scale, op.getIndex());
2435 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
2436 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2437 op, destVecType, i32Source, scale, op.getIndex());
2438 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
2439 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2440 op, destVecType, i32Source, scale, op.getIndex());
2441 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
2442 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2443 op, destVecType, i32Source, scale, op.getIndex());
2444 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
2445 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2446 op, destVecType, i32Source, scale, op.getIndex());
2447 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
2448 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2449 op, destVecType, i32Source, scale, op.getIndex());
2450 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
2451 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2452 op, destVecType, i32Source, scale, op.getIndex());
2459LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2460 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2461 ConversionPatternRewriter &rewriter)
const {
2462 Location loc = op.getLoc();
2464 return rewriter.notifyMatchFailure(
2465 loc,
"Scaled fp conversion instructions are not available on target "
2466 "architecture and their emulation is not implemented");
2467 Type v2i16 = getTypeConverter()->convertType(
2468 VectorType::get(2, rewriter.getI16Type()));
2469 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2471 Type resultType = op.getResult().getType();
2473 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2474 Type sourceElemType = sourceVecType.getElementType();
2476 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
2478 Value source = adaptor.getSource();
2479 Value scale = adaptor.getScale();
2480 Value existing = adaptor.getExisting();
2482 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2484 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
2486 if (sourceVecType.getNumElements() < 2) {
2488 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2489 VectorType v2 = VectorType::get(2, sourceElemType);
2490 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2491 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
2494 Value sourceA, sourceB;
2495 if (sourceElemType.
isF32()) {
2498 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2499 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
2503 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
2504 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2505 existing, sourceA, sourceB,
2506 scale, op.getIndex());
2507 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
2508 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2509 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2510 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
2511 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2512 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2513 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
2514 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2515 existing, sourceA, sourceB,
2516 scale, op.getIndex());
2517 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
2518 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2519 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2520 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
2521 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2522 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2523 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
2524 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2525 existing, sourceA, sourceB,
2526 scale, op.getIndex());
2527 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
2528 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2529 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2530 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
2531 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
2532 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2536 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2537 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu PackedTrunc2xFp8Op: packs two f32 sources into one word of an
// fp8/bf8 result via ROCDL cvt.pk.{bf8,fp8}.f32, then bitcasts back to the
// converted result type.
// NOTE(review): this chunk is a lossy extraction -- the guard conditions (the
// chipset check before notifyMatchFailure, the null checks for sourceB and
// existing, and the bf8-vs-fp8 dispatch) are missing; confirm against the
// original file before relying on the documented control flow.
2541LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2542    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2543    ConversionPatternRewriter &rewriter)
const {
2544  Location loc = op.getLoc();
// Reject targets without native Fp8 conversion instructions.
2546    return rewriter.notifyMatchFailure(
2547        loc,
"Fp8 conversion instructions are not available on target "
2548        "architecture and their emulation is not implemented");
2549  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2551  Type resultType = op.getResult().getType();
2554  Value sourceA = adaptor.getSourceA();
2555  Value sourceB = adaptor.getSourceB();
// An absent second source is materialized as undef of sourceA's type.
2557    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.
getType());
2558  Value existing = adaptor.getExisting();
// Reuse the `existing` packed word (bitcast to i32) when provided; undef otherwise.
2560    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2562    existing = LLVM::UndefOp::create(rewriter, loc, i32);
// The word index selects which half of the i32 result receives the pair.
2566    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2567                                          existing, op.getWordIndex());
2569    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2570                                          existing, op.getWordIndex());
// Replace the op with the intrinsic result bitcast to the converted type.
2572  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2573      op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu PackedStochRoundFp8Op: stochastic-rounding f32 -> fp8/bf8 via
// ROCDL cvt.sr.{bf8,fp8}.f32, storing into one byte of the packed result.
// NOTE(review): lossy extraction -- the chipset guard and the bf8-vs-fp8
// dispatch conditions are missing lines; verify against the original file.
2577LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2578    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2579    ConversionPatternRewriter &rewriter)
const {
2580  Location loc = op.getLoc();
// Reject targets without native Fp8 conversion instructions.
2582    return rewriter.notifyMatchFailure(
2583        loc,
"Fp8 conversion instructions are not available on target "
2584        "architecture and their emulation is not implemented");
2585  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2587  Type resultType = op.getResult().getType();
2590  Value source = adaptor.getSource();
// `stoch` is the stochastic-rounding parameter fed to the cvt.sr intrinsic.
2591  Value stoch = adaptor.getStochiasticParam();
2592  Value existing = adaptor.getExisting();
// Reuse `existing` bits (bitcast to i32) when present, otherwise undef.
2594    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2596    existing = LLVM::UndefOp::create(rewriter, loc, i32);
// Store index selects which byte of the i32 receives the converted value.
2600    result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2601                                          existing, op.getStoreIndex());
2603    result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2604                                          existing, op.getStoreIndex());
// Replace the op with the intrinsic result bitcast to the converted type.
2606  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2607      op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu.dpp to rocdl.update.dpp. Operands narrower than 32 bits are
// widened (via bitcast / vector-insert tricks) to a 32/64-bit type the DPP
// intrinsic accepts, the DPP permutation kind+argument are encoded into the
// hardware dpp_ctrl immediate, and the result is truncated/bitcast back.
// NOTE(review): lossy extraction -- the srcType/oldType declarations, several
// switch `break`s/closing braces, and parts of the DppCtrl enum are missing
// lines; confirm control flow against the original file.
2613struct AMDGPUDPPLowering :
public ConvertOpToLLVMPattern<DPPOp> {
2614  AMDGPUDPPLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2615      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2619  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2620                  ConversionPatternRewriter &rewriter)
const override {
2623    Location loc = DppOp.getLoc();
2624    Value src = adaptor.getSrc();
2625    Value old = adaptor.getOld();
// Pick the LLVM type the DPP intrinsic will operate on: i32 for small types,
// f32/f64 for floats, i32/i64 for integers (by width -- widths elided here).
2628    Type llvmType =
nullptr;
2630      llvmType = rewriter.getI32Type();
2631    }
else if (isa<FloatType>(srcType)) {
2633                     ? rewriter.getF32Type()
2634                     : rewriter.getF64Type();
2635    }
else if (isa<IntegerType>(srcType)) {
2637                     ? rewriter.getI32Type()
2638                     : rewriter.getI64Type();
2640    auto llvmSrcIntType = typeConverter->convertType(
// Widen sub-32-bit operands: floats are first bitcast to the same-width int,
// then inserted into a vector whose total width is 32 bits and bitcast to
// llvmType so the DPP intrinsic sees a full register.
2644    auto convertOperand = [&](Value operand, Type operandType) {
2645      if (operandType.getIntOrFloatBitWidth() <= 16) {
2646        if (llvm::isa<FloatType>(operandType)) {
2648              LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2650        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2651            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2652        Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2654            LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2656        operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2661    src = convertOperand(src, srcType);
2662    old = convertOperand(old, oldType);
// Hardware dpp_ctrl encodings (partial view -- most enumerators elided).
2665    enum DppCtrl :
unsigned {
2674      ROW_HALF_MIRROR = 0x141,
2679    auto kind = DppOp.getKind();
2680    auto permArgument = DppOp.getPermArgument();
2681    uint32_t DppCtrl = 0;
// Encode the permutation kind + argument into the dpp_ctrl immediate.
2685    case DPPPerm::quad_perm: {
// quad_perm packs four 2-bit lane selectors into the low byte.
2686      auto quadPermAttr = cast<ArrayAttr>(*permArgument);
2688      for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2689        uint32_t num = elem.getInt();
2690        DppCtrl |= num << (i * 2);
2695    case DPPPerm::row_shl: {
2696      auto intAttr = cast<IntegerAttr>(*permArgument);
2697      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2700    case DPPPerm::row_shr: {
2701      auto intAttr = cast<IntegerAttr>(*permArgument);
2702      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2705    case DPPPerm::row_ror: {
2706      auto intAttr = cast<IntegerAttr>(*permArgument);
2707      DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
2710    case DPPPerm::wave_shl:
2711      DppCtrl = DppCtrl::WAVE_SHL1;
2713    case DPPPerm::wave_shr:
2714      DppCtrl = DppCtrl::WAVE_SHR1;
2716    case DPPPerm::wave_rol:
2717      DppCtrl = DppCtrl::WAVE_ROL1;
2719    case DPPPerm::wave_ror:
2720      DppCtrl = DppCtrl::WAVE_ROR1;
2722    case DPPPerm::row_mirror:
2723      DppCtrl = DppCtrl::ROW_MIRROR;
2725    case DPPPerm::row_half_mirror:
2726      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
2728    case DPPPerm::row_bcast_15:
2729      DppCtrl = DppCtrl::BCAST15;
2731    case DPPPerm::row_bcast_31:
2732      DppCtrl = DppCtrl::BCAST31;
// Row/bank masks and bound_ctrl are read off the op's discrete attributes.
2738    auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
2739    auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
2740    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>(
"bound_ctrl").getValue();
// Emit the rocdl.update.dpp intrinsic with the encoded control word.
2744        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
2745                                   rowMask, bankMask, boundCtrl);
2747    Value
result = dppMovOp.getRes();
// Undo the widening: truncate back to the source int width, then bitcast to
// the original (non-integer) source type if needed.
2749      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType,
result);
2750      if (!llvm::isa<IntegerType>(srcType)) {
2751        result = LLVM::BitcastOp::create(rewriter, loc, srcType,
result);
// Lowers amdgpu.swizzle_bitmode to rocdl ds_swizzle: the source value is
// decomposed into i32 pieces, each piece is swizzled with a packed
// and/or/xor mask immediate, and the pieces are recomposed to the source type.
2762struct AMDGPUSwizzleBitModeLowering
2763    :
public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2767  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2768                  ConversionPatternRewriter &rewriter)
const override {
2769    Location loc = op.getLoc();
2770    Type i32 = rewriter.getI32Type();
2771    Value src = adaptor.getSrc();
// ds_swizzle operates on i32 lanes; split arbitrary-width values first.
2772    SmallVector<Value> decomposed;
2773    if (
failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
2774      return rewriter.notifyMatchFailure(op,
2775                                         "failed to decompose value to i32");
2776    unsigned andMask = op.getAndMask();
2777    unsigned orMask = op.getOrMask();
2778    unsigned xorMask = op.getXorMask();
// Bit-mode swizzle encoding: and[4:0] | or[9:5] | xor[14:10].
2782    unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2784    SmallVector<Value> swizzled;
// Swizzle each i32 piece independently with the same mask.
2785    for (Value v : decomposed) {
2787          ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2788      swizzled.emplace_back(res);
// Reassemble the swizzled pieces into a value of the original type.
2791    Value
result = LLVM::composeValue(rewriter, loc, swizzled, src.
getType());
2792    rewriter.replaceOp(op,
result);
// Lowers amdgpu.permlane_swap to rocdl permlane{16,32}.swap (gfx950+ only).
// The source is decomposed into i32 pieces; each swap intrinsic returns a
// {i32, i32} struct holding both exchanged halves, and a compare+select picks
// the half that differs from the input lane's own value.
// NOTE(review): lossy extraction -- the chipset guard condition and the
// trailing boundctrl arguments of the intrinsic calls are missing lines.
2797struct AMDGPUPermlaneLowering :
public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2800  AMDGPUPermlaneLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2801      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2805  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2806                  ConversionPatternRewriter &rewriter)
const override {
2808      return op->emitOpError(
"permlane_swap is only supported on gfx950+");
2810    Location loc = op.getLoc();
2811    Type i32 = rewriter.getI32Type();
2812    Value src = adaptor.getSrc();
2813    unsigned rowLength = op.getRowLength();
2814    bool fi = op.getFetchInactive();
2815    bool boundctrl = op.getBoundCtrl();
// Permlane intrinsics operate on i32 lanes; split wider values first.
2817    SmallVector<Value> decomposed;
2818    if (
failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
2819      return rewriter.notifyMatchFailure(op,
2820                                         "failed to decompose value to i32");
2822    SmallVector<Value> permuted;
2823    for (Value v : decomposed) {
// The intrinsic returns both exchanged values as a literal {i32, i32}.
2825      Type i32pair = LLVM::LLVMStructType::getLiteral(
2826          rewriter.getContext(), {v.getType(), v.getType()});
// Row length selects the 16-lane or 32-lane swap variant.
2828      if (rowLength == 16)
2829        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2831      else if (rowLength == 32)
2832        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2835        llvm_unreachable(
"unsupported row length");
2837      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2838      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
// Keep whichever result half is not this lane's own input value.
2840      Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2841                                           LLVM::ICmpPredicate::eq, vdst0, v);
2846          LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2847      permuted.emplace_back(vdstNew);
// Reassemble the permuted pieces into a value of the original type.
2850    Value
result = LLVM::composeValue(rewriter, loc, permuted, src.
getType());
2851    rewriter.replaceOp(op,
result);
// Bit layout of the 64-bit DS barrier state word used by the DsBarrier*
// lowerings below: bits [28:0] hold the pending count, bit 29 the phase,
// and the init count starts at bit 32. The mask selects the pending-count
// field from the low 32 bits.
2864constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
2865constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
2866constexpr int32_t kDsBarrierInitCountPos = 32;
2867constexpr int32_t kDsBarrierPendingCountMask =
2868    (1 << kDsBarrierPendingCountBitWidth) - 1;
// Lowers amdgpu DsBarrierInitOp (gfx1250+): builds the initial 64-bit barrier
// state word (init count and pending count both set to participants-1, per
// the sub below) and stores it atomically (release, align 8) to the barrier
// address derived from the base memref + indices.
// NOTE(review): lossy extraction -- the chipset guard condition, the ptr
// computation line, and the countMask/shift constants are missing lines.
2870struct DsBarrierInitOpLowering
2871    :
public ConvertOpToLLVMPattern<DsBarrierInitOp> {
2874  DsBarrierInitOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2875      : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}
2878  matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
2879                  ConversionPatternRewriter &rewriter)
const override {
2881      return op->emitOpError(
"only supported on gfx1250+");
2883    Location loc = op.getLoc();
2884    Type i64 = rewriter.getI64Type();
// Compute the barrier's LLVM pointer from the base memref and indices.
2886    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2888                                       adaptor.getBase(), adaptor.getIndices());
// The stored count is participants - 1 (hardware counts down to zero).
2895        LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
// Mask to the pending-count field width, then widen to the 64-bit state word.
2902    Value maskedCount32 =
2903        LLVM::AndOp::create(rewriter, loc, initCount, countMask);
2904    Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);
// Place the init count in its high field and OR in the pending count.
2906    Value initCountShifted = LLVM::ShlOp::create(
2907        rewriter, loc, maskedCount,
2909    Value barrierState =
2910        LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);
// Atomic release store (align 8) publishes the initialized state.
2912    LLVM::StoreOp::create(
2913        rewriter, loc, barrierState, ptr, 8,
false,
2915        false, LLVM::AtomicOrdering::release,
2918    rewriter.eraseOp(op);
// Lowers amdgpu DsBarrierPollStateOp (gfx1250+): an atomic acquire load
// (align 8) of the 64-bit barrier state word at the computed address.
// NOTE(review): lossy extraction -- chipset field/guard condition and the
// ptr computation line are missing lines.
2923struct DsBarrierPollStateOpLowering
2924    :
public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
2927  DsBarrierPollStateOpLowering(
const LLVMTypeConverter &converter,
2929      : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
2933  matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
2934                  ConversionPatternRewriter &rewriter)
const override {
2936      return op->emitOpError(
"only supported on gfx1250+");
2938    Location loc = op.getLoc();
2939    Type i64 = rewriter.getI64Type();
// Compute the barrier's LLVM pointer from the base memref and indices.
2941    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2943                                       adaptor.getBase(), adaptor.getIndices());
// Replace the op with an atomic acquire load of the i64 state.
2947    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
2948        op, i64, ptr, 8,
false,
2950        false, LLVM::AtomicOrdering::acquire,
// Lowers amdgpu DsAsyncBarrierArriveOp (gfx1250+) to the ROCDL
// ds_atomic_async_barrier_arrive intrinsic on the computed barrier pointer.
// NOTE(review): lossy extraction -- chipset field/guard condition and the
// ptr computation line are missing lines.
2956struct DsAsyncBarrierArriveOpLowering
2957    :
public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
2960  DsAsyncBarrierArriveOpLowering(
const LLVMTypeConverter &converter,
2962      : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
2966  matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
2967                  ConversionPatternRewriter &rewriter)
const override {
2969      return op->emitOpError(
"only supported on gfx1250+");
2971    Location loc = op.getLoc();
// Compute the barrier's LLVM pointer from the base memref and indices.
2973    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2975                                       adaptor.getBase(), adaptor.getIndices());
// Replace with the async-arrive intrinsic; optional attrs left unset.
2977    rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
2978        op, ptr,
nullptr,
nullptr,
// Lowers amdgpu DsBarrierArriveOp (gfx1250+) to the ROCDL
// ds_atomic_barrier_arrive_rtn intrinsic, which returns the i64 barrier state.
// NOTE(review): lossy extraction -- chipset guard condition and the ptr
// computation line are missing lines.
2984struct DsBarrierArriveOpLowering
2985    :
public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
2988  DsBarrierArriveOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2989      : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
2993  matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
2994                  ConversionPatternRewriter &rewriter)
const override {
2996      return op->emitOpError(
"only supported on gfx1250+");
2998    Location loc = op.getLoc();
2999    Type i64 = rewriter.getI64Type();
// Compute the barrier's LLVM pointer from the base memref and indices.
3001    MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3003                                       adaptor.getBase(), adaptor.getIndices());
// Replace with the arrive-with-return intrinsic (count from the adaptor).
3005    rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
3006        op, i64, ptr, adaptor.getCount(),
nullptr,
// Lowers amdgpu DsBarrierStatePhaseOp: extracts the phase bit from a 64-bit
// barrier state word by truncating to i32 and shifting right (by
// kDsBarrierPhasePos -- the shift amount line is elided in this extraction).
3012struct DsBarrierStatePhaseOpLowering
3013    :
public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
3017  matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
3018                  ConversionPatternRewriter &rewriter)
const override {
3019    Location loc = op.getLoc();
3020    Type i32 = rewriter.getI32Type();
3022    Value state = adaptor.getState();
// Low 32 bits hold phase + pending count (no init count).
3024    Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3025    Value phase = LLVM::LShrOp::create(
3026        rewriter, loc, noInitCount,
3029    rewriter.replaceOp(op, phase);
// Lowers amdgpu DsBarrierStatePendingCountOp: extracts the pending-count
// field from a 64-bit barrier state word by truncating to i32 and masking
// with kDsBarrierPendingCountMask (low 29 bits).
3034struct DsBarrierStatePendingCountOpLowering
3035    :
public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
3039  matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
3040                  ConversionPatternRewriter &rewriter)
const override {
3041    Location loc = op.getLoc();
3042    Type i32 = rewriter.getI32Type();
3044    Value state = adaptor.getState();
// Low 32 bits hold phase + pending count; mask selects the count field.
3046    Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3047    Value pendingCount = LLVM::AndOp::create(
3048        rewriter, loc, noInitCount,
3050                          static_cast<uint32_t
>(kDsBarrierPendingCountMask)));
3052    rewriter.replaceOp(op, pendingCount);
// Lowers amdgpu DsBarrierStateInitCountOp: extracts the init-count field from
// the high half of the 64-bit state word -- shift right (by
// kDsBarrierInitCountPos; shift amount line elided here) then truncate to i32.
3057struct DsBarrierStateInitCountOpLowering
3058    :
public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
3062  matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
3063                  ConversionPatternRewriter &rewriter)
const override {
3064    Location loc = op.getLoc();
3065    Type i32 = rewriter.getI32Type();
3067    Value state = adaptor.getState();
3069    Value initCountI64 = LLVM::LShrOp::create(
3070        rewriter, loc, state,
3072    Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
3074    rewriter.replaceOp(op, initCount);
// Lowers amdgpu DsBarrierStatePhaseParity: extracts the phase bit (shift of
// the low 32 bits, amount elided in this extraction) and truncates to i1 to
// yield the parity flag.
3079struct DsBarrierStatePhaseParityLowering
3080    :
public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
3084  matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
3085                  ConversionPatternRewriter &rewriter)
const override {
3086    Location loc = op.getLoc();
3087    Type i1 = rewriter.getI1Type();
3089    Value state = adaptor.getState();
// Low 32 bits hold phase + pending count (no init count).
3092        LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
3093    Value phase = LLVM::LShrOp::create(
3094        rewriter, loc, noInitCount,
// Only the lowest bit of the shifted value matters: parity is an i1.
3096    Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
3098    rewriter.replaceOp(op, parity);
// OR-merges `value`, shifted left by `shift` bits, into `accumulator`.
// The disjoint flag asserts the shifted value and the accumulator share no
// set bits, which lets LLVM treat the OR as an ADD.
// NOTE(review): lossy extraction -- the shiftAmount constant creation and the
// shift==0 fast path appear to be missing lines; confirm against the original.
3107static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
3108                              Value accumulator, Value value, int64_t shift) {
3113    value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
3119  constexpr bool isDisjoint =
true;
3120  return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
// Lowers an amdgpu make_dma_base op (gfx1250 only) to a 4 x i32 vector of
// "sgpr" words: [0] holds flag bits (gather mode / index size at bits 30-31),
// [1] the LDS address as i32, [2] the low half of the 64-bit global address,
// [3] its masked high half plus a flag bit at offset 30. The words are
// inserted into a poison v4i32 and replace the op.
// NOTE(review): lossy extraction -- the chipset guard condition, the
// per-index constant creation in the first loop, the pointer-materialization
// calls, and several mask/shift constants are missing lines.
3123template <
typename BaseOp>
3124struct AMDGPUMakeDmaBaseLowering :
public ConvertOpToLLVMPattern<BaseOp> {
3125  using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
3128  AMDGPUMakeDmaBaseLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3129      : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
3133  matchAndRewrite(BaseOp op, Adaptor adaptor,
3134                  ConversionPatternRewriter &rewriter)
const override {
3136      return op->emitOpError(
"make_dma_base is only supported on gfx1250");
3138    Location loc = op.getLoc();
// Small i32 constants 0..3, reused as both values and insert indices.
3140    constexpr int32_t constlen = 4;
3141    Value consts[constlen];
3142    for (int64_t i = 0; i < constlen; ++i)
// One sgpr word per vector lane, zero-initialized (consts[0]).
3145    constexpr int32_t sgprslen = constlen;
3146    Value sgprs[sgprslen];
3147    for (int64_t i = 0; i < sgprslen; ++i) {
3148      sgprs[i] = consts[0];
3151    sgprs[0] = consts[1];
// Gather variant: set the gather flag (bit 30) and encode the index size
// (16 -> idx 0, 32 -> idx 1) in sgpr0; bit 31 is set for the wider index.
3153    if constexpr (BaseOp::isGather()) {
3154      sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
3156      auto type = cast<TDMGatherBaseType>(op.getResult().getType());
3157      Type indexType = type.getIndexType();
3159      assert(llvm::is_contained({16u, 32u}, indexSize) &&
3160             "expected index_size to be 16 or 32");
3161      unsigned idx = (indexSize / 16) - 1;
3164        sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
// Resolve the LDS and global element addresses from memref + indices.
3167    ValueRange ldsIndices = adaptor.getLdsIndices();
3168    Value lds = adaptor.getLds();
3169    auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
3172        rewriter, loc, ldsMemRefType, lds, ldsIndices);
3174    ValueRange globalIndices = adaptor.getGlobalIndices();
3175    Value global = adaptor.getGlobal();
3176    auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
3179        rewriter, loc, globalMemRefType, global, globalIndices);
3181    Type i32 = rewriter.getI32Type();
3182    Type i64 = rewriter.getI64Type();
// sgpr1 = LDS address (fits in 32 bits); sgpr2/sgpr3 = split global address.
3184    sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
3185    Value castForGlobalAddr =
3186        LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
3188    sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
3190    Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
3193    Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
// Mask the high half, then set the flag value consts[2] at bit offset 30.
3196    highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
3198    sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
// Pack the four words into a v4i32, using consts as lane indices.
3200    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3201    assert(v4i32 &&
"expected type conversion to succeed");
3202    Value
result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3204    for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
3206          LLVM::InsertElementOp::create(rewriter, loc,
result, sgpr, constant);
3208    rewriter.replaceOp(op,
result);
3213template <
typename DescriptorOp>
3214struct AMDGPULowerDescriptor :
public ConvertOpToLLVMPattern<DescriptorOp> {
3215 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
3218 AMDGPULowerDescriptor(
const LLVMTypeConverter &converter, Chipset chipset)
3219 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
3222 Value getDGroup0(OpAdaptor adaptor)
const {
return adaptor.getBase(); }
3224 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
3225 ConversionPatternRewriter &rewriter, Location loc,
3226 Value sgpr0)
const {
3227 Value mask = op.getWorkgroupMask();
3231 Type i16 = rewriter.getI16Type();
3232 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
3233 Type i32 = rewriter.getI32Type();
3234 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
3235 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
3238 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
3239 ConversionPatternRewriter &rewriter, Location loc,
3240 Value sgpr0, ArrayRef<Value> consts)
const {
3241 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
3242 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3243 "expected type width to be 8, 16, 32, or 64.");
3244 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3245 Value size = consts[idx];
3246 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
3249 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
3250 ConversionPatternRewriter &rewriter, Location loc,
3251 Value sgpr0, ArrayRef<Value> consts)
const {
3252 if (!adaptor.getAtomicBarrierAddress())
3255 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
3258 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
3259 ConversionPatternRewriter &rewriter, Location loc,
3260 Value sgpr0, ArrayRef<Value> consts)
const {
3261 if (!adaptor.getGlobalIncrement())
3266 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
3269 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
3270 ConversionPatternRewriter &rewriter, Location loc,
3271 Value sgpr0, ArrayRef<Value> consts)
const {
3272 if (!op.getPadAmount())
3275 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
3278 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
3279 ConversionPatternRewriter &rewriter, Location loc,
3280 Value sgpr0, ArrayRef<Value> consts)
const {
3281 if (!op.getWorkgroupMask())
3284 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
3287 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
3288 ConversionPatternRewriter &rewriter, Location loc,
3289 Value sgpr0, ArrayRef<Value> consts)
const {
3290 if (!op.getPadAmount())
3299 IntegerType i32 = rewriter.getI32Type();
3300 Value padInterval = adaptor.getPadInterval();
3301 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
3302 padInterval,
false);
3303 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
3305 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
3308 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
3309 ConversionPatternRewriter &rewriter, Location loc,
3310 Value sgpr0, ArrayRef<Value> consts)
const {
3311 if (!op.getPadAmount())
3320 Value padAmount = adaptor.getPadAmount();
3321 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
3323 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
3326 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
3327 ConversionPatternRewriter &rewriter,
3328 Location loc, Value sgpr1,
3329 ArrayRef<Value> consts)
const {
3330 if (!adaptor.getAtomicBarrierAddress())
3333 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
3334 auto barrierAddressTy =
3335 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
3336 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
3338 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
3339 atomicBarrierIndices);
3340 IntegerType i32 = rewriter.getI32Type();
3346 atomicBarrierAddress =
3347 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
3348 atomicBarrierAddress =
3349 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
3351 atomicBarrierAddress =
3352 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
3353 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
3356 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3357 ConversionPatternRewriter &rewriter,
3358 Location loc, Value sgpr1, Value sgpr2,
3359 ArrayRef<Value> consts, uint64_t dimX,
3360 uint32_t offset)
const {
3361 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3362 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3363 SmallVector<OpFoldResult> mixedGlobalSizes =
3365 if (mixedGlobalSizes.size() <= dimX)
3366 return {sgpr1, sgpr2};
3368 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3375 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3379 IntegerType i32 = rewriter.getI32Type();
3380 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3381 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3384 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
3387 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
3388 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
3389 return {sgpr1, sgpr2};
3392 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
3393 ConversionPatternRewriter &rewriter,
3394 Location loc, Value sgpr1, Value sgpr2,
3395 ArrayRef<Value> consts)
const {
3396 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
3400 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
3401 ConversionPatternRewriter &rewriter,
3402 Location loc, Value sgpr2, Value sgpr3,
3403 ArrayRef<Value> consts)
const {
3404 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
3408 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
3409 ConversionPatternRewriter &rewriter, Location loc,
3410 Value sgpr, ArrayRef<Value> consts,
size_t dimX,
3411 int64_t offset)
const {
3412 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
3413 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
3414 SmallVector<OpFoldResult> mixedSharedSizes =
3416 if (mixedSharedSizes.size() <= dimX)
3419 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
3428 if (
auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
3432 IntegerType i32 = rewriter.getI32Type();
3433 tileDimX = cast<Value>(tileDimXOpFoldResult);
3434 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
3437 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
3440 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
3441 ConversionPatternRewriter &rewriter, Location loc,
3442 Value sgpr3, ArrayRef<Value> consts)
const {
3443 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
3446 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
3447 ConversionPatternRewriter &rewriter, Location loc,
3448 Value sgpr4, ArrayRef<Value> consts)
const {
3449 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
3452 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3453 ConversionPatternRewriter &rewriter, Location loc,
3454 Value sgpr4, ArrayRef<Value> consts)
const {
3455 auto type = cast<VectorType>(op.getIndices().getType());
3456 ArrayRef<int64_t> shape = type.getShape();
3457 assert(shape.size() == 1 &&
"expected shape to be of rank 1.");
3458 unsigned length = shape.back();
3459 assert(0 < length && length <= 16 &&
"expected length to be at most 16.");
3461 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
3464 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
3465 ConversionPatternRewriter &rewriter,
3466 Location loc, Value sgpr4,
3467 ArrayRef<Value> consts)
const {
3468 if constexpr (DescriptorOp::isGather())
3469 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
3470 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
3473 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
3474 ConversionPatternRewriter &rewriter, Location loc,
3475 Value sgpr4, ArrayRef<Value> consts)
const {
3477 if constexpr (DescriptorOp::isGather())
3479 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
3482 std::pair<Value, Value>
3483 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
3484 ConversionPatternRewriter &rewriter, Location loc,
3485 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
3486 size_t dimX, int64_t offset)
const {
3487 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
3488 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
3489 SmallVector<OpFoldResult> mixedGlobalStrides =
3490 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
3492 if (mixedGlobalStrides.size() <= (dimX + 1))
3493 return {sgprY, sgprZ};
3495 OpFoldResult tensorDimXStrideOpFoldResult =
3496 *(mixedGlobalStrides.rbegin() + dimX + 1);
3501 Value tensorDimXStride;
3502 if (
auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
3506 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
3508 constexpr int64_t first48bits = (1ll << 48) - 1;
3511 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
3512 IntegerType i32 = rewriter.getI32Type();
3513 Value tensorDimXStrideLow =
3514 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
3515 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
3517 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
3519 Value tensorDimXStrideHigh =
3520 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
3521 tensorDimXStrideHigh =
3522 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
3523 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
3525 return {sgprY, sgprZ};
3528 std::pair<Value, Value>
3529 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3530 ConversionPatternRewriter &rewriter, Location loc,
3531 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3532 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3536 std::pair<Value, Value>
3537 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3538 ConversionPatternRewriter &rewriter, Location loc,
3539 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3541 if constexpr (DescriptorOp::isGather())
3542 return {sgpr5, sgpr6};
3543 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3547 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
3548 ConversionPatternRewriter &rewriter, Location loc,
3549 ArrayRef<Value> consts)
const {
3551 for (int64_t i = 0; i < 8; ++i) {
3552 sgprs[i] = consts[0];
3555 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
3556 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
3557 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
3558 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3559 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3560 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
3561 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
3562 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
3565 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
3566 std::tie(sgprs[1], sgprs[2]) =
3567 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3568 std::tie(sgprs[2], sgprs[3]) =
3569 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3571 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
3573 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
3574 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
3575 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
3576 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
3577 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
3578 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
3580 IntegerType i32 = rewriter.getI32Type();
3581 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
3582 assert(v8i32 &&
"expected type conversion to succeed");
3583 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
3585 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
3587 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
3593 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3594 ConversionPatternRewriter &rewriter, Location loc,
3595 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3596 int64_t offset)
const {
3597 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3598 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3599 SmallVector<OpFoldResult> mixedGlobalSizes =
3601 if (mixedGlobalSizes.size() <=
static_cast<unsigned long>(dimX))
3604 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3606 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3610 IntegerType i32 = rewriter.getI32Type();
3611 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3612 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3615 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
3618 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3619 ConversionPatternRewriter &rewriter, Location loc,
3620 Value sgpr0, ArrayRef<Value> consts)
const {
3621 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3624 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3625 Location loc, Value accumulator,
3626 Value value, int64_t shift)
const {
3628 IntegerType i32 = rewriter.getI32Type();
3629 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3630 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3633 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3634 ConversionPatternRewriter &rewriter, Location loc,
3635 Value sgpr1, ArrayRef<Value> consts,
3636 int64_t offset)
const {
3637 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3638 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3641 std::pair<Value, Value>
3642 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3643 ConversionPatternRewriter &rewriter, Location loc,
3644 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
3645 int64_t offset)
const {
3646 Value globalAddrIncrement = adaptor.getGlobalIncrement();
3647 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
3648 globalAddrIncrement, offset);
3650 globalAddrIncrement =
3651 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
3652 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
3653 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
3654 globalAddrIncrement, offset + 32);
3656 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
3657 return {sgpr2, sgpr3};
3660 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3661 ConversionPatternRewriter &rewriter,
3662 Location loc, Value sgpr1,
3663 ArrayRef<Value> consts)
const {
3664 Value ldsIncrement = op.getLdsIncrement();
3665 constexpr int64_t dim = 3;
3666 constexpr int64_t offset = 32;
3668 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
3670 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
3674 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3675 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3676 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts)
const {
3677 Value globalIncrement = op.getGlobalIncrement();
3678 constexpr int32_t dim = 2;
3679 constexpr int32_t offset = 64;
3680 if (!globalIncrement)
3681 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3682 consts, dim, offset);
3683 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3687 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3688 ConversionPatternRewriter &rewriter, Location loc,
3689 Value sgpr3, ArrayRef<Value> consts,
3690 int32_t offset)
const {
3691 Value iterationCount = adaptor.getIterationCount();
3692 IntegerType i32 = rewriter.getI32Type();
3699 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
3701 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
3702 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
3705 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
3706 ConversionPatternRewriter &rewriter,
3707 Location loc, Value sgpr3,
3708 ArrayRef<Value> consts)
const {
3709 Value iterateCount = op.getIterationCount();
3710 constexpr int32_t dim = 2;
3711 constexpr int32_t offset = 112;
3713 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
3716 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
3719 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
3720 ConversionPatternRewriter &rewriter, Location loc,
3721 ArrayRef<Value> consts)
const {
3722 if constexpr (DescriptorOp::isGather())
3723 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
3724 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
3727 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
3728 ConversionPatternRewriter &rewriter, Location loc,
3729 ArrayRef<Value> consts)
const {
3730 IntegerType i32 = rewriter.getI32Type();
3731 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3732 assert(v4i32 &&
"expected type conversion to succeed.");
3734 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3735 if (onlyNeedsTwoDescriptors)
3736 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3738 constexpr int64_t sgprlen = 4;
3739 Value sgprs[sgprlen];
3740 for (
int i = 0; i < sgprlen; ++i)
3741 sgprs[i] = consts[0];
3743 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
3744 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
3746 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
3747 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3749 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
3751 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3752 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
3754 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
3759 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
3760 ConversionPatternRewriter &rewriter, Location loc,
3761 ArrayRef<Value> consts,
bool firstHalf)
const {
3762 IntegerType i32 = rewriter.getI32Type();
3763 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3764 assert(v4i32 &&
"expected type conversion to succeed.");
3766 Value
indices = adaptor.getIndices();
3767 auto vectorType = cast<VectorType>(
indices.getType());
3768 unsigned length = vectorType.getShape().back();
3769 Type elementType = vectorType.getElementType();
3770 unsigned maxLength = elementType == i32 ? 4 : 8;
3771 int32_t offset = firstHalf ? 0 : maxLength;
3772 unsigned discountedLength =
3773 std::max(
static_cast<int32_t
>(length - offset), 0);
3775 unsigned targetSize = std::min(maxLength, discountedLength);
3777 SmallVector<Value> indicesVector;
3778 for (
unsigned i = offset; i < targetSize + offset; ++i) {
3780 if (i < consts.size())
3784 Value elem = LLVM::ExtractElementOp::create(rewriter, loc,
indices, idx);
3785 indicesVector.push_back(elem);
3788 SmallVector<Value> indicesI32Vector;
3789 if (elementType == i32) {
3790 indicesI32Vector = indicesVector;
3792 for (
unsigned i = 0; i < targetSize; ++i) {
3793 Value index = indicesVector[i];
3794 indicesI32Vector.push_back(
3795 LLVM::ZExtOp::create(rewriter, loc, i32, index));
3797 if ((targetSize % 2) != 0)
3799 indicesI32Vector.push_back(consts[0]);
3802 SmallVector<Value> indicesToInsert;
3803 if (elementType == i32) {
3804 indicesToInsert = indicesI32Vector;
3806 unsigned size = indicesI32Vector.size() / 2;
3807 for (
unsigned i = 0; i < size; ++i) {
3808 Value first = indicesI32Vector[2 * i];
3809 Value second = indicesI32Vector[2 * i + 1];
3810 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
3811 indicesToInsert.push_back(joined);
3815 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3816 for (
auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
3818 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
3823 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
3824 ConversionPatternRewriter &rewriter, Location loc,
3825 ArrayRef<Value> consts)
const {
3826 return getGatherIndices(op, adaptor, rewriter, loc, consts,
true);
3829 std::pair<Value, Value>
3830 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
3831 ConversionPatternRewriter &rewriter, Location loc,
3832 Value sgpr0, Value sgpr1, ArrayRef<Value> consts)
const {
3833 constexpr int32_t dim = 3;
3834 constexpr int32_t offset = 0;
3835 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
3839 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
3840 ConversionPatternRewriter &rewriter,
3841 Location loc, Value sgpr1, Value sgpr2,
3842 ArrayRef<Value> consts)
const {
3843 constexpr int32_t dim = 4;
3844 constexpr int32_t offset = 48;
3845 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
3849 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
3850 ConversionPatternRewriter &rewriter, Location loc,
3851 Value sgpr2, ArrayRef<Value> consts)
const {
3852 constexpr int32_t dim = 4;
3853 constexpr int32_t offset = 80;
3854 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
3857 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
3858 ConversionPatternRewriter &rewriter, Location loc,
3859 ArrayRef<Value> consts)
const {
3860 if constexpr (DescriptorOp::isGather())
3861 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
3862 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
3865 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
3866 ConversionPatternRewriter &rewriter, Location loc,
3867 ArrayRef<Value> consts)
const {
3868 IntegerType i32 = rewriter.getI32Type();
3869 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3870 assert(v4i32 &&
"expected type conversion to succeed.");
3871 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3872 if (onlyNeedsTwoDescriptors)
3873 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3875 constexpr int32_t sgprlen = 4;
3876 Value sgprs[sgprlen];
3877 for (
int i = 0; i < sgprlen; ++i)
3878 sgprs[i] = consts[0];
3880 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
3881 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
3882 std::tie(sgprs[1], sgprs[2]) =
3883 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3884 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
3886 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3887 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
3889 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
3894 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
3895 ConversionPatternRewriter &rewriter, Location loc,
3896 ArrayRef<Value> consts)
const {
3897 return getGatherIndices(op, adaptor, rewriter, loc, consts,
false);
3901 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
3902 ConversionPatternRewriter &rewriter)
const override {
3904 return op->emitOpError(
3905 "make_dma_descriptor is only supported on gfx1250");
3907 Location loc = op.getLoc();
3909 SmallVector<Value> consts;
3910 for (int64_t i = 0; i < 8; ++i)
3913 Value dgroup0 = this->getDGroup0(adaptor);
3914 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
3915 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
3916 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
3917 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
3918 rewriter.replaceOpWithMultiple(op, {results});
3923template <
typename SourceOp,
typename TargetOp>
3924struct AMDGPUTensorLoadStoreOpLowering
3925 :
public ConvertOpToLLVMPattern<SourceOp> {
3926 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
3928 AMDGPUTensorLoadStoreOpLowering(
const LLVMTypeConverter &converter,
3930 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
3934 matchAndRewrite(SourceOp op, Adaptor adaptor,
3935 ConversionPatternRewriter &rewriter)
const override {
3937 return op->emitOpError(
"is only supported on gfx1250");
3942 auto v8i32 = VectorType::get(8, rewriter.getI32Type());
3943 Value dgroup4 = LLVM::ZeroOp::create(rewriter, op.getLoc(), v8i32);
3944 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
3945 desc[3], dgroup4, 0,
3953struct ConvertAMDGPUToROCDLPass
3954 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
3957 void runOnOperation()
override {
3960 if (
failed(maybeChipset)) {
3961 emitError(UnknownLoc::get(ctx),
"Invalid chipset name: " + chipset);
3962 return signalPassFailure();
3965 RewritePatternSet patterns(ctx);
3966 LLVMTypeConverter converter(ctx);
3969 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
3971 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
3972 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
3973 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
3974 if (
failed(applyPartialConversion(getOperation(),
target,
3975 std::move(patterns))))
3976 signalPassFailure();
3984 typeConverter, [](gpu::AddressSpace space) {
3986 case gpu::AddressSpace::Global:
3987 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
3988 case gpu::AddressSpace::Workgroup:
3989 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
3990 case gpu::AddressSpace::Private:
3991 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
3993 llvm_unreachable(
"unknown address space enum value");
3999 typeConverter.addTypeAttributeConversion(
4001 -> TypeConverter::AttributeConversionResult {
4003 Type i64 = IntegerType::get(ctx, 64);
4004 switch (as.getValue()) {
4005 case amdgpu::AddressSpace::FatRawBuffer:
4006 return IntegerAttr::get(i64, 7);
4007 case amdgpu::AddressSpace::BufferRsrc:
4008 return IntegerAttr::get(i64, 8);
4009 case amdgpu::AddressSpace::FatStructuredBuffer:
4010 return IntegerAttr::get(i64, 9);
4012 return TypeConverter::AttributeConversionResult::abort();
4014 typeConverter.addConversion([&](DsBarrierStateType type) ->
Type {
4015 return IntegerType::get(type.
getContext(), 64);
4017 typeConverter.addConversion([&](TDMBaseType type) ->
Type {
4019 return typeConverter.convertType(VectorType::get(4, i32));
4021 typeConverter.addConversion([&](TDMGatherBaseType type) ->
Type {
4023 return typeConverter.convertType(VectorType::get(4, i32));
4025 typeConverter.addConversion(
4026 [&](TDMDescriptorType type,
4029 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
4030 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
4031 llvm::append_values(
result, v4i32, v8i32, v4i32, v4i32);
4041 if (inputs.size() != 1)
4044 if (!isa<TDMDescriptorType>(inputs[0].
getType()))
4047 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
4048 return cast.getResults();
4051 typeConverter.addTargetMaterialization(addUnrealizedCast);
4059 .
add<FatRawBufferCastLowering,
4060 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
4061 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
4062 RawBufferOpLowering<RawBufferAtomicFaddOp,
4063 ROCDL::RawPtrBufferAtomicFaddOp>,
4064 RawBufferOpLowering<RawBufferAtomicFmaxOp,
4065 ROCDL::RawPtrBufferAtomicFmaxOp>,
4066 RawBufferOpLowering<RawBufferAtomicSmaxOp,
4067 ROCDL::RawPtrBufferAtomicSmaxOp>,
4068 RawBufferOpLowering<RawBufferAtomicUminOp,
4069 ROCDL::RawPtrBufferAtomicUminOp>,
4070 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
4071 ROCDL::RawPtrBufferAtomicCmpSwap>,
4072 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
4073 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
4074 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
4075 SparseWMMAOpLowering, ExtPackedFp8OpLowering,
4076 ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
4077 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
4078 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
4079 TransposeLoadOpLowering, AMDGPUPermlaneLowering,
4080 AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
4081 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
4082 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
4083 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
4084 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
4085 ROCDL::TensorLoadToLDSOp>,
4086 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
4087 ROCDL::TensorStoreFromLDSOp>,
4088 DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
4089 DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering>(converter,
4091 patterns.
add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
4092 DsBarrierStatePendingCountOpLowering,
4093 DsBarrierStateInitCountOpLowering,
4094 DsBarrierStatePhaseParityLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static std::optional< SparseWMMAOpInfo > sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< uint32_t > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static std::optional< uint32_t > smallFloatTypeToFormatCode(Type mlirElemType)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static Value convertSparseVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts sparse MFMA/WMMA (smfmac/swmmac) operands to the expected ROCDL types.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
typename SourceOp::Adaptor OpAdaptor
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
bool hasOcpFp8(const Chipset &chipset)
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Returns the rocdl intrinsic corresponding to a SparseWMMA operation swmmac if one exists.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.