30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/AMDGPUAddrSpace.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/ErrorHandling.h"
39#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
40#include "mlir/Conversion/Passes.h.inc"
56 IntegerType i32 = rewriter.getI32Type();
58 auto valTy = cast<IntegerType>(val.
getType());
61 return valTy.getWidth() > 32
62 ?
Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
63 :
Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
68 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
74 IntegerType i64 = rewriter.getI64Type();
76 auto valTy = cast<IntegerType>(val.
getType());
79 return valTy.getWidth() > 64
80 ?
Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
81 :
Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
86 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
93 IntegerType i32 = rewriter.getI32Type();
95 for (
auto [i, increment, stride] : llvm::enumerate(
indices, strides)) {
98 ShapedType::isDynamic(stride)
100 memRefDescriptor.
stride(rewriter, loc, i))
101 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
102 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
114 MemRefType memrefType,
118 if (chipset >=
kGfx1250 && !boundsCheck) {
119 constexpr int64_t first45bits = (1ll << 45) - 1;
122 if (memrefType.hasStaticShape() &&
123 !llvm::any_of(strides, ShapedType::isDynamic)) {
124 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
126 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
127 size = std::max(
shape[i] * strides[i], size);
128 size = size * elementByteWidth;
132 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
133 Value size = memrefDescriptor.
size(rewriter, loc, i);
134 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
135 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
137 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
142 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
148 Value cacheSwizzleStride =
nullptr,
149 unsigned addressSpace = 8) {
153 Type i16 = rewriter.getI16Type();
156 Value cacheStrideZext =
157 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
158 Value swizzleBit = LLVM::ConstantOp::create(
159 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
160 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
163 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
164 rewriter.getI16IntegerAttr(0));
193 flags |= (7 << 12) | (4 << 15);
196 uint32_t oob = boundsCheck ? 3 : 2;
197 flags |= (oob << 28);
202 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
203 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
204 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
209struct FatRawBufferCastLowering
211 FatRawBufferCastLowering(
const LLVMTypeConverter &converter, Chipset chipset)
212 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
218 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
219 ConversionPatternRewriter &rewriter)
const override {
220 Location loc = op.getLoc();
221 Value memRef = adaptor.getSource();
222 Value unconvertedMemref = op.getSource();
223 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
224 MemRefDescriptor descriptor(memRef);
226 DataLayout dataLayout = DataLayout::closest(op);
227 int64_t elementByteWidth =
230 int64_t unusedOffset = 0;
231 SmallVector<int64_t, 5> strideVals;
232 if (
failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
233 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
235 Value numRecords = adaptor.getValidBytes();
238 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
239 elementByteWidth, chipset, adaptor.getBoundsCheck());
242 adaptor.getResetOffset()
243 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
245 : descriptor.alignedPtr(rewriter, loc);
247 Value offset = adaptor.getResetOffset()
248 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
249 rewriter.getIndexAttr(0))
250 : descriptor.offset(rewriter, loc);
252 bool hasSizes = memrefType.getRank() > 0;
255 Value sizes = hasSizes
256 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
260 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
265 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
266 chipset, adaptor.getCacheSwizzleStride(), 7);
268 Value
result = MemRefDescriptor::poison(
270 getTypeConverter()->convertType(op.getResult().getType()));
272 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr, pos);
273 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr,
275 result = LLVM::InsertValueOp::create(rewriter, loc,
result, offset,
278 result = LLVM::InsertValueOp::create(rewriter, loc,
result, sizes,
280 result = LLVM::InsertValueOp::create(rewriter, loc,
result, strides,
283 rewriter.replaceOp(op,
result);
289template <
typename GpuOp,
typename Intrinsic>
291 RawBufferOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
292 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
295 static constexpr uint32_t maxVectorOpWidth = 128;
298 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
299 ConversionPatternRewriter &rewriter)
const override {
300 Location loc = gpuOp.getLoc();
301 Value memref = adaptor.getMemref();
302 Value unconvertedMemref = gpuOp.getMemref();
303 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
305 if (chipset.majorVersion < 9)
306 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
308 Value storeData = adaptor.getODSOperands(0)[0];
309 if (storeData == memref)
313 wantedDataType = storeData.
getType();
315 wantedDataType = gpuOp.getODSResults(0)[0].getType();
317 Value atomicCmpData = Value();
320 Value maybeCmpData = adaptor.getODSOperands(1)[0];
321 if (maybeCmpData != memref)
322 atomicCmpData = maybeCmpData;
325 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
327 Type i32 = rewriter.getI32Type();
330 DataLayout dataLayout = DataLayout::closest(gpuOp);
331 int64_t elementByteWidth =
340 Type llvmBufferValType = llvmWantedDataType;
342 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
343 llvmBufferValType = this->getTypeConverter()->convertType(
344 rewriter.getIntegerType(floatType.getWidth()));
346 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
347 uint32_t vecLen = dataVector.getNumElements();
350 uint32_t totalBits = elemBits * vecLen;
352 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
353 if (totalBits > maxVectorOpWidth)
354 return gpuOp.emitOpError(
355 "Total width of loads or stores must be no more than " +
356 Twine(maxVectorOpWidth) +
" bits, but we call for " +
358 " bits. This should've been caught in validation");
359 if (!usePackedFp16 && elemBits < 32) {
360 if (totalBits > 32) {
361 if (totalBits % 32 != 0)
362 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
363 "doesn't fit into words. Can't happen\n");
364 llvmBufferValType = this->typeConverter->convertType(
365 VectorType::get(totalBits / 32, i32));
367 llvmBufferValType = this->typeConverter->convertType(
368 rewriter.getIntegerType(totalBits));
372 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
375 if (vecType.getNumElements() == 1)
376 llvmBufferValType = vecType.getElementType();
379 SmallVector<Value, 6> args;
381 if (llvmBufferValType != llvmWantedDataType) {
382 Value castForStore = LLVM::BitcastOp::create(
383 rewriter, loc, llvmBufferValType, storeData);
384 args.push_back(castForStore);
386 args.push_back(storeData);
391 if (llvmBufferValType != llvmWantedDataType) {
392 Value castForCmp = LLVM::BitcastOp::create(
393 rewriter, loc, llvmBufferValType, atomicCmpData);
394 args.push_back(castForCmp);
396 args.push_back(atomicCmpData);
402 SmallVector<int64_t, 5> strides;
403 if (
failed(memrefType.getStridesAndOffset(strides, offset)))
404 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
406 MemRefDescriptor memrefDescriptor(memref);
408 Value ptr = memrefDescriptor.bufferPtr(
409 rewriter, loc, *this->getTypeConverter(), memrefType);
411 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
412 elementByteWidth, chipset, adaptor.getBoundsCheck());
414 adaptor.getBoundsCheck(), chipset);
415 args.push_back(resource);
419 adaptor.getIndices(), strides);
420 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
421 indexOffset && *indexOffset > 0) {
423 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
427 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
428 args.push_back(voffset);
431 Value sgprOffset = adaptor.getSgprOffset();
434 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
435 args.push_back(sgprOffset);
442 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
444 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
445 ArrayRef<NamedAttribute>());
448 if (llvmBufferValType != llvmWantedDataType) {
449 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
454 rewriter.eraseOp(gpuOp);
471static FailureOr<unsigned> encodeWaitcnt(
Chipset chipset,
unsigned vmcnt,
472 unsigned expcnt,
unsigned lgkmcnt) {
474 vmcnt = std::min(15u, vmcnt);
475 expcnt = std::min(7u, expcnt);
476 lgkmcnt = std::min(15u, lgkmcnt);
477 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
480 vmcnt = std::min(63u, vmcnt);
481 expcnt = std::min(7u, expcnt);
482 lgkmcnt = std::min(15u, lgkmcnt);
483 unsigned lowBits = vmcnt & 0xF;
484 unsigned highBits = (vmcnt >> 4) << 14;
485 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
486 return lowBits | highBits | otherCnts;
489 vmcnt = std::min(63u, vmcnt);
490 expcnt = std::min(7u, expcnt);
491 lgkmcnt = std::min(63u, lgkmcnt);
492 unsigned lowBits = vmcnt & 0xF;
493 unsigned highBits = (vmcnt >> 4) << 14;
494 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
495 return lowBits | highBits | otherCnts;
498 vmcnt = std::min(63u, vmcnt);
499 expcnt = std::min(7u, expcnt);
500 lgkmcnt = std::min(63u, lgkmcnt);
501 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
506struct MemoryCounterWaitOpLowering
508 MemoryCounterWaitOpLowering(
const LLVMTypeConverter &converter,
510 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
516 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
517 ConversionPatternRewriter &rewriter)
const override {
518 if (chipset.majorVersion >= 12) {
519 Location loc = op.getLoc();
520 if (std::optional<int> ds = adaptor.getDs())
521 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
523 if (std::optional<int>
load = adaptor.getLoad())
524 ROCDL::WaitLoadcntOp::create(rewriter, loc, *
load);
526 if (std::optional<int> store = adaptor.getStore())
527 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
529 if (std::optional<int> exp = adaptor.getExp())
530 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
532 if (std::optional<int> tensor = adaptor.getTensor())
533 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
535 rewriter.eraseOp(op);
539 if (adaptor.getTensor())
540 return op.emitOpError(
"unsupported chipset");
542 auto getVal = [](Attribute attr) ->
unsigned {
544 return cast<IntegerAttr>(attr).getInt();
549 unsigned ds = getVal(adaptor.getDsAttr());
550 unsigned exp = getVal(adaptor.getExpAttr());
552 unsigned vmcnt = 1024;
553 Attribute
load = adaptor.getLoadAttr();
554 Attribute store = adaptor.getStoreAttr();
556 vmcnt = getVal(
load) + getVal(store);
558 vmcnt = getVal(
load);
560 vmcnt = getVal(store);
563 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
565 return op.emitOpError(
"unsupported chipset");
567 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
573 LDSBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
574 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
579 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
580 ConversionPatternRewriter &rewriter)
const override {
581 Location loc = op.getLoc();
584 bool requiresInlineAsm = chipset <
kGfx90a;
587 rewriter.getAttr<LLVM::MMRATagAttr>(
"amdgpu-synchronize-as",
"local");
596 StringRef scope =
"workgroup";
598 auto relFence = LLVM::FenceOp::create(rewriter, loc,
599 LLVM::AtomicOrdering::release, scope);
600 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
601 if (requiresInlineAsm) {
602 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
603 LLVM::AsmDialect::AD_ATT);
604 const char *asmStr =
";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
605 const char *constraints =
"";
606 LLVM::InlineAsmOp::create(
609 asmStr, constraints,
true,
610 false, LLVM::TailCallKind::None,
613 }
else if (chipset.majorVersion < 12) {
614 ROCDL::SBarrierOp::create(rewriter, loc);
616 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
617 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
620 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
621 LLVM::AtomicOrdering::acquire, scope);
622 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
623 rewriter.replaceOp(op, acqFence);
629 SchedBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
630 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
635 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
636 ConversionPatternRewriter &rewriter)
const override {
637 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
638 (uint32_t)op.getOpts());
662 bool allowBf16 =
true) {
664 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
665 if (vectorType.getElementType().isBF16() && !allowBf16)
666 return LLVM::BitcastOp::create(
667 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
668 if (vectorType.getElementType().isInteger(8) &&
669 vectorType.getNumElements() <= 8)
670 return LLVM::BitcastOp::create(
672 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
673 if (isa<IntegerType>(vectorType.getElementType()) &&
674 vectorType.getElementTypeBitWidth() <= 8) {
675 int64_t numWords = llvm::divideCeil(
676 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
678 return LLVM::BitcastOp::create(
679 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
690 bool allowBf16 =
true) {
692 auto vectorType = cast<VectorType>(inputType);
694 if (vectorType.getElementType().isBF16() && !allowBf16)
695 return LLVM::BitcastOp::create(
696 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
698 if (isa<IntegerType>(vectorType.getElementType()) &&
699 vectorType.getElementTypeBitWidth() <= 8) {
700 int64_t numWords = llvm::divideCeil(
701 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
702 Type castType = (numWords > 1)
703 ?
Type{VectorType::get(numWords, rewriter.getI32Type())}
704 : rewriter.getI32Type();
705 return LLVM::BitcastOp::create(rewriter, loc, castType, input);
723 .Case([&](IntegerType) {
725 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
728 .Case([&](VectorType vectorType) {
730 int64_t numElements = vectorType.getNumElements();
731 assert((numElements == 4 || numElements == 8) &&
732 "scale operand must be a vector of length 4 or 8");
733 IntegerType outputType =
734 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
735 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
737 .DefaultUnreachable(
"unexpected input type for scale operand");
743 .Case([](Float8E8M0FNUType) {
return 0; })
744 .Case([](Float8E4M3FNType) {
return 2; })
745 .Default(std::nullopt);
750static std::optional<StringRef>
752 if (m == 16 && n == 16 && k == 128)
754 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
755 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
757 if (m == 32 && n == 16 && k == 128)
758 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
759 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
773 ConversionPatternRewriter &rewriter,
Location loc,
778 auto vectorType = dyn_cast<VectorType>(inputType);
780 operands.push_back(llvmInput);
783 Type elemType = vectorType.getElementType();
785 operands.push_back(llvmInput);
792 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
793 bool isInputInteger = mlirInputType.getElementType().isInteger();
794 if (isInputInteger) {
796 bool localIsUnsigned = isUnsigned;
798 localIsUnsigned =
true;
800 localIsUnsigned =
false;
803 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
808 Type i32 = rewriter.getI32Type();
809 Type intrinsicInType = numBits <= 32
810 ? (
Type)rewriter.getIntegerType(numBits)
811 : (
Type)VectorType::get(numBits / 32, i32);
812 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
813 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
814 loc, llvmIntrinsicInType, llvmInput);
819 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
820 operands.push_back(castInput);
833 Value output, int32_t subwordOffset,
837 auto vectorType = dyn_cast<VectorType>(inputType);
838 Type elemType = vectorType.getElementType();
839 operands.push_back(output);
851 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
852 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
858 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
859 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
867 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
868 b = mfma.getBlocks();
873 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
874 if (m == 32 && n == 32 && k == 4 &&
b == 1)
875 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
876 if (m == 16 && n == 16 && k == 8 &&
b == 1)
877 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
879 if (m == 32 && n == 32 && k == 1 &&
b == 2)
880 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
881 if (m == 16 && n == 16 && k == 1 &&
b == 4)
882 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
883 if (m == 4 && n == 4 && k == 1 &&
b == 16)
884 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
885 if (m == 32 && n == 32 && k == 2 &&
b == 1)
886 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
887 if (m == 16 && n == 16 && k == 4 &&
b == 1)
888 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
893 if (m == 32 && n == 32 && k == 16 &&
b == 1)
894 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
895 if (m == 16 && n == 16 && k == 32 &&
b == 1)
896 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
898 if (m == 32 && n == 32 && k == 4 &&
b == 2)
899 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
900 if (m == 16 && n == 16 && k == 4 &&
b == 4)
901 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
902 if (m == 4 && n == 4 && k == 4 &&
b == 16)
903 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
904 if (m == 32 && n == 32 && k == 8 &&
b == 1)
905 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
906 if (m == 16 && n == 16 && k == 16 &&
b == 1)
907 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
912 if (m == 32 && n == 32 && k == 16 &&
b == 1)
913 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
914 if (m == 16 && n == 16 && k == 32 &&
b == 1)
915 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
918 if (m == 32 && n == 32 && k == 4 &&
b == 2)
919 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
920 if (m == 16 && n == 16 && k == 4 &&
b == 4)
921 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
922 if (m == 4 && n == 4 && k == 4 &&
b == 16)
923 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
924 if (m == 32 && n == 32 && k == 8 &&
b == 1)
925 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
926 if (m == 16 && n == 16 && k == 16 &&
b == 1)
927 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
929 if (m == 32 && n == 32 && k == 2 &&
b == 2)
930 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
931 if (m == 16 && n == 16 && k == 2 &&
b == 4)
932 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
933 if (m == 4 && n == 4 && k == 2 &&
b == 16)
934 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
935 if (m == 32 && n == 32 && k == 4 &&
b == 1)
936 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
937 if (m == 16 && n == 16 && k == 8 &&
b == 1)
938 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
943 if (m == 32 && n == 32 && k == 32 &&
b == 1)
944 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
945 if (m == 16 && n == 16 && k == 64 &&
b == 1)
946 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
948 if (m == 32 && n == 32 && k == 4 &&
b == 2)
949 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
950 if (m == 16 && n == 16 && k == 4 &&
b == 4)
951 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
952 if (m == 4 && n == 4 && k == 4 &&
b == 16)
953 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
954 if (m == 32 && n == 32 && k == 8 &&
b == 1)
955 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
956 if (m == 16 && n == 16 && k == 16 &&
b == 1)
957 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
958 if (m == 32 && n == 32 && k == 16 &&
b == 1 && chipset >=
kGfx942)
959 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
960 if (m == 16 && n == 16 && k == 32 &&
b == 1 && chipset >=
kGfx942)
961 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
965 if (m == 16 && n == 16 && k == 4 &&
b == 1)
966 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
967 if (m == 4 && n == 4 && k == 4 &&
b == 4)
968 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
975 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
976 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
978 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
980 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
982 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
984 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
986 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
992 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
993 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
995 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
997 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
999 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
1001 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
1003 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
1007 return std::nullopt;
1012 .Case([](Float8E4M3FNType) {
return 0u; })
1013 .Case([](Float8E5M2Type) {
return 1u; })
1014 .Case([](Float6E2M3FNType) {
return 2u; })
1015 .Case([](Float6E3M2FNType) {
return 3u; })
1016 .Case([](Float4E2M1FNType) {
return 4u; })
1017 .Default(std::nullopt);
1027static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1029 uint32_t n, uint32_t k, uint32_t
b,
Chipset chipset) {
1035 return std::nullopt;
1036 if (!isa<Float32Type>(destType))
1037 return std::nullopt;
1041 if (!aTypeCode || !bTypeCode)
1042 return std::nullopt;
1044 if (m == 32 && n == 32 && k == 64 &&
b == 1)
1045 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1046 *aTypeCode, *bTypeCode};
1047 if (m == 16 && n == 16 && k == 128 &&
b == 1)
1049 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1052 return std::nullopt;
1055static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1058 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1059 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1060 mfma.getBlocks(), chipset);
1063static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1066 smfma.getSourceB().getType(),
1067 smfma.getDestC().getType(), smfma.getM(),
1068 smfma.getN(), smfma.getK(), 1u, chipset);
1073static std::optional<StringRef>
1075 Type elemDestType, uint32_t k,
bool isRDNA3) {
1076 using fp8 = Float8E4M3FNType;
1077 using bf8 = Float8E5M2Type;
1082 if (elemSourceType.
isF16() && elemDestType.
isF32())
1083 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1084 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1085 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1086 if (elemSourceType.
isF16() && elemDestType.
isF16())
1087 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1089 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1091 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1096 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1097 return std::nullopt;
1101 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1102 elemDestType.
isF32())
1103 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1104 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1105 elemDestType.
isF32())
1106 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1107 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1108 elemDestType.
isF32())
1109 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1110 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1111 elemDestType.
isF32())
1112 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1114 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1116 return std::nullopt;
1120 if (k == 32 && !isRDNA3) {
1122 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1125 return std::nullopt;
1131 Type elemBSourceType,
1134 using fp8 = Float8E4M3FNType;
1135 using bf8 = Float8E5M2Type;
1138 if (elemSourceType.
isF32() && elemDestType.
isF32())
1139 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1141 return std::nullopt;
1145 if (elemSourceType.
isF16() && elemDestType.
isF32())
1146 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1147 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1148 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1149 if (elemSourceType.
isF16() && elemDestType.
isF16())
1150 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1152 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1154 return std::nullopt;
1158 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1159 if (elemDestType.
isF32())
1160 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1161 if (elemDestType.
isF16())
1162 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1164 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1165 if (elemDestType.
isF32())
1166 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1167 if (elemDestType.
isF16())
1168 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1170 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1171 if (elemDestType.
isF32())
1172 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1173 if (elemDestType.
isF16())
1174 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1176 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1177 if (elemDestType.
isF32())
1178 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1179 if (elemDestType.
isF16())
1180 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1183 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1185 return std::nullopt;
1189 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1190 if (elemDestType.
isF32())
1191 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1192 if (elemDestType.
isF16())
1193 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1195 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1196 if (elemDestType.
isF32())
1197 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1198 if (elemDestType.
isF16())
1199 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1201 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1202 if (elemDestType.
isF32())
1203 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1204 if (elemDestType.
isF16())
1205 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1207 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1208 if (elemDestType.
isF32())
1209 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1210 if (elemDestType.
isF16())
1211 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1214 return std::nullopt;
1217 return std::nullopt;
1225 bool isGfx950 = chipset >=
kGfx950;
1229 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1234 if (m == 16 && n == 16 && k == 32) {
1236 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1238 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1241 if (m == 16 && n == 16 && k == 64) {
1244 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1246 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1250 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1251 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1252 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1253 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1254 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1255 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1256 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1257 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1258 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1261 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1264 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1265 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1266 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1267 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1268 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1269 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1270 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1271 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1272 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1275 if (m == 32 && n == 32 && k == 16) {
1277 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1279 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1282 if (m == 32 && n == 32 && k == 32) {
1285 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1287 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1291 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1292 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1293 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1294 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1295 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1296 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1297 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1298 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1299 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1302 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1305 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1306 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1307 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1308 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1309 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1310 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1311 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1312 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1313 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1316 return std::nullopt;
1324 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1325 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1326 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1327 Type elemSourceType = sourceVectorType.getElementType();
1328 Type elemBSourceType = sourceBVectorType.getElementType();
1329 Type elemDestType = destVectorType.getElementType();
1331 const uint32_t k = wmma.getK();
1336 if (isRDNA3 || isRDNA4)
1345 return std::nullopt;
1358static std::optional<SparseWMMAOpInfo>
1364 uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();
1366 if ((m != 16) || (n != 16))
1367 return std::nullopt;
1374 ROCDL::swmmac_f32_16x16x32_f16::getOperationName(),
false,
false,
1378 ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(),
false,
false,
1382 ROCDL::swmmac_f16_16x16x32_f16::getOperationName(),
false,
false,
1386 ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(),
false,
false,
1391 ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(),
true,
false,
1396 ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(),
true,
false,
1401 ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(),
false,
1406 ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(),
false,
1411 ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(),
false,
1415 ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(),
false,
1422 ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(),
true,
false,
1427 const bool isGFX1250 = chipset ==
kGfx1250;
1428 const bool isWavesize64 = swmmac.getWave64();
1429 if (isGFX1250 && !isWavesize64) {
1433 ROCDL::swmmac_f32_16x16x64_f16::getOperationName(),
true,
true,
1437 ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(),
true,
true,
1441 ROCDL::swmmac_f16_16x16x64_f16::getOperationName(),
true,
true,
1445 ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(),
true,
true,
1452 ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(),
false,
1457 ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(),
false,
1462 ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(),
false,
1466 ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(),
false,
1471 ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(),
false,
1476 ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(),
false,
1481 ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(),
false,
1485 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(),
false,
1490 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(),
false,
1495 ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(),
true,
true,
1500 return std::nullopt;
// Conversion pattern: lowers amdgpu.mfma to the matching ROCDL MFMA (or
// scaled-MFMA) intrinsic, selected by name via OperationState.
// NOTE(review): this extract is lossy -- original lines are missing between
// the numbered fragments below; comments describe only the visible code.
1505 MFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1506 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1511 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1512 ConversionPatternRewriter &rewriter)
const override {
1513 Location loc = op.getLoc();
1514 Type outType = typeConverter->convertType(op.getDestD().getType());
1515 Type intrinsicOutType = outType;
// bf16 results travel through the intrinsic as i16 vectors; they are
// bitcast back to the real out type at the end of this method.
1516 if (
auto outVecType = dyn_cast<VectorType>(outType))
1517 if (outVecType.getElementType().isBF16())
1518 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
// MFMA requires a gfx9-family (CDNA) target, gfx908 or newer.
1520 if (chipset.majorVersion != 9 || chipset <
kGfx908)
1521 return op->emitOpError(
"MFMA only supported on gfx908+");
1522 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
// Negation flags are packed into the blgp field (bit per operand);
// guard that rejects them pre-gfx942 is partially elided in this extract.
1523 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1525 return op.emitOpError(
"negation unsupported on older than gfx942");
1527 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1530 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1532 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1533 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
// When only the scaled intrinsic matches, abid/blgp/cbsz carry type
// information instead of their usual meaning, so user values must be 0.
1536 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1538 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1539 return op.emitOpError(
1540 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1541 "be scaled as those fields are used for type information");
1544 StringRef intrinsicName =
1545 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
// Only the native-bf16 intrinsics (matched by name below) accept bf16
// operands directly; otherwise sources are repacked by the helper calls.
1548 bool allowBf16 = [&]() {
1553 return intrinsicName.contains(
"16x16x32.bf16") ||
1554 intrinsicName.contains(
"32x32x16.bf16");
1556 OperationState loweredOp(loc, intrinsicName);
1557 loweredOp.addTypes(intrinsicOutType);
1559 rewriter, loc, adaptor.getSourceA(), allowBf16),
1561 rewriter, loc, adaptor.getSourceB(), allowBf16),
1562 adaptor.getDestC()});
// Scaled path: two zero scale operands plus type codes in cbsz/blgp.
1565 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1566 loweredOp.addOperands({zero, zero});
1567 loweredOp.addAttributes({{
"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1568 {
"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1569 {
"opselA", rewriter.getI32IntegerAttr(0)},
1570 {
"opselB", rewriter.getI32IntegerAttr(0)}});
// Unscaled path: forward the op's own cbsz/abid plus the (possibly
// negation-augmented) blgp field.
1572 loweredOp.addAttributes(
1573 {{
"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1574 {
"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1575 {
"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
1577 Value lowered = rewriter.create(loweredOp)->getResult(0);
1578 if (outType != intrinsicOutType)
1579 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1580 rewriter.replaceOp(op, lowered);
// Conversion pattern: lowers amdgpu.scaled_mfma to a ROCDL scaled-MFMA
// intrinsic, passing per-operand scale indices via opselA/opselB.
// NOTE(review): lossy extract -- original lines missing between fragments.
1586 ScaledMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1587 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1592 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1593 ConversionPatternRewriter &rewriter)
const override {
1594 Location loc = op.getLoc();
1595 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
// NOTE(review): the guard requires gfx950+ but the diagnostic text says
// "gfx908+" -- check/message mismatch, confirm which is intended.
1597 if (chipset.majorVersion != 9 || chipset <
kGfx950)
1598 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1599 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1601 if (!maybeScaledIntrinsic.has_value())
1602 return op.emitOpError(
1603 "no intrinsic matching scaled MFMA size on given chipset");
1605 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1606 OperationState loweredOp(loc, intrinsicName);
1607 loweredOp.addTypes(intrinsicOutType);
1608 loweredOp.addOperands(
1611 adaptor.getDestC()});
1612 loweredOp.addOperands(
// cbsz/blgp carry the operand element-type codes; opselA/opselB select
// which scale within the packed scale operand applies.
1617 loweredOp.addAttributes(
1618 {{
"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1619 {
"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1620 {
"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1621 {
"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1623 Value lowered = rewriter.create(loweredOp)->getResult(0);
1624 rewriter.replaceOp(op, lowered);
// Conversion pattern: lowers amdgpu.sparse_mfma (smfmac) to the matching
// ROCDL intrinsic, appending the sparsity index as an i32 operand.
// NOTE(review): lossy extract -- original lines missing between fragments.
1630 SparseMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1631 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1636 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1637 ConversionPatternRewriter &rewriter)
const override {
1638 Location loc = op.getLoc();
1640 typeConverter->convertType<VectorType>(op.getDestC().
getType());
1642 return rewriter.notifyMatchFailure(op,
"type conversion failed");
// smfmac exists on gfx942+; gfx950 gains additional variants.
1645 if (chipset.majorVersion != 9 || chipset <
kGfx942)
1646 return op->emitOpError(
"sparse MFMA (smfmac) only supported on gfx942+");
1647 bool isGfx950 = chipset >=
kGfx950;
1653 Value c = adaptor.getDestC();
1656 if (!maybeIntrinsic.has_value())
1657 return op.emitOpError(
1658 "no intrinsic matching sparse MFMA on the given chipset");
// The sparsity-index operand is reinterpreted as a raw i32 for ROCDL.
1661 Value sparseIdx = LLVM::BitcastOp::create(
1662 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1664 OperationState loweredOp(loc, maybeIntrinsic.value());
1665 loweredOp.addTypes(outType);
1666 loweredOp.addOperands({a,
b, c, sparseIdx});
1667 loweredOp.addAttributes(
1668 {{
"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1669 {
"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1670 Value lowered = rewriter.create(loweredOp)->getResult(0);
1671 rewriter.replaceOp(op, lowered);
// Conversion pattern: lowers amdgpu.wmma to a ROCDL WMMA intrinsic on
// gfx11/gfx12, bitcasting bf16 operands/results through i16 vectors on
// targets that lack native bf16 WMMA (anything below gfx1250).
// NOTE(review): lossy extract -- original lines missing between fragments.
1677 WMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1678 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1683 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1684 ConversionPatternRewriter &rewriter)
const override {
1685 Location loc = op.getLoc();
1687 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1689 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1691 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1692 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1694 bool isGFX1250 = chipset >=
kGfx1250;
// Decide, per operand, whether bf16 must be smuggled through i16.
1699 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1700 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1701 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1702 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1703 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1704 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1705 bool castOutToI16 = outType.getElementType().
isBF16() && !isGFX1250;
1706 VectorType rawOutType = outType;
1708 rawOutType = outType.clone(rewriter.getI16Type());
1709 Value a = adaptor.getSourceA();
1711 a = LLVM::BitcastOp::create(rewriter, loc,
1712 aType.clone(rewriter.getI16Type()), a);
1713 Value
b = adaptor.getSourceB();
1715 b = LLVM::BitcastOp::create(rewriter, loc,
1716 bType.clone(rewriter.getI16Type()),
b);
1717 Value destC = adaptor.getDestC();
1719 destC = LLVM::BitcastOp::create(
1720 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1724 if (!maybeIntrinsic.has_value())
1725 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
// gfx12 dropped the subwordOffset control; reject non-zero values.
1727 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1728 return op.emitOpError(
"subwordOffset not supported on gfx12+");
// Helper calls (elided in this extract) append operands plus the
// signA/signB and clamp/subword attributes expected by the intrinsic.
1730 SmallVector<Value, 4> operands;
1731 SmallVector<NamedAttribute, 4> attrs;
1733 op.getSourceA(), operands, attrs,
"signA");
1735 op.getSourceB(), operands, attrs,
"signB");
1737 op.getSubwordOffset(), op.getClamp(), operands,
1740 OperationState loweredOp(loc, *maybeIntrinsic);
1741 loweredOp.addTypes(rawOutType);
1742 loweredOp.addOperands(operands);
1743 loweredOp.addAttributes(attrs);
1744 Operation *lowered = rewriter.create(loweredOp);
// Undo the i16 smuggling on the result if it was applied.
1746 Operation *maybeCastBack = lowered;
1747 if (rawOutType != outType)
1748 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1750 rewriter.replaceOp(op, maybeCastBack->
getResults());
// Conversion pattern: lowers amdgpu.sparse_wmma to a ROCDL sparse-WMMA
// intrinsic described by SparseWMMAOpInfo (name + supported modifiers).
// NOTE(review): lossy extract -- original lines missing between fragments.
1757 SparseWMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1758 : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
1763 matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
1764 ConversionPatternRewriter &rewriter)
const override {
1765 Location loc = op.getLoc();
1767 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1769 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1771 std::optional<SparseWMMAOpInfo> maybeIntrinsic =
1774 if (!maybeIntrinsic.has_value())
1775 return op.emitOpError(
1776 "no intrinsic matching Sparse WMMA on the given chipset");
1777 SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();
1779 SmallVector<NamedAttribute> attrs;
// Each optional modifier (sign / reuse / clamp) is only forwarded when
// the chosen intrinsic advertises support for it; otherwise it is an error.
1781 if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.
useSign)
1782 return op->emitOpError(
"intrinsic doesn't support unsign");
1784 if (
auto attr = op.getUnsignedAAttr())
1785 attrs.push_back({
"signA", attr});
1786 if (
auto attr = op.getUnsignedBAttr())
1787 attrs.push_back({
"signB", attr});
1790 if ((op.getReuseA() || op.getReuseB()) && !intrinsic.
useReuse)
1791 return op->emitOpError(
"intrinsic doesn't support reuse");
1793 if (
auto attr = op.getReuseAAttr())
1794 attrs.push_back({
"reuseA", attr});
1795 if (
auto attr = op.getReuseBAttr())
1796 attrs.push_back({
"reuseB", attr});
1799 if (op.getClamp() && !intrinsic.
useClamp)
1800 return op->emitOpError(
"intrinsic doesn't support clamp");
1801 if (intrinsic.
useClamp && op.getClampAttr())
1802 attrs.push_back({
"clamp", op.getClampAttr()});
1804 const bool isGFX1250orHigher =
1805 chipset.majorVersion == 12 && chipset.minorVersion >= 5;
1810 Value c = adaptor.getDestC();
// Pre-gfx1250 the intrinsic's raw result type follows destC's converted
// type; the final result is bitcast back below when they differ.
1811 VectorType rawOutType = outType;
1812 if (!isGFX1250orHigher) {
1814 rawOutType = cast<VectorType>(c.
getType());
1818 Value sparseIdx = LLVM::BitcastOp::create(
1819 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1821 OperationState loweredOp(loc, intrinsic.
name);
1822 loweredOp.addTypes(rawOutType);
1823 loweredOp.addOperands({a,
b, c, sparseIdx});
1824 loweredOp.addAttributes(attrs);
1825 Operation *lowered = rewriter.create(loweredOp);
1827 Operation *maybeCastBack = lowered;
1828 if (rawOutType != outType)
1829 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1831 rewriter.replaceOp(op, maybeCastBack->
getResults());
// Conversion pattern: lowers amdgpu.scaled_wmma (gfx1250+) to a ROCDL
// scaled-WMMA intrinsic, encoding operand/scale formats as attributes.
// NOTE(review): lossy extract -- original lines missing between fragments.
1838 ScaledWMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1839 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
1844 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
1845 ConversionPatternRewriter &rewriter)
const override {
1846 Location loc = op.getLoc();
1848 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1850 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1853 return op->emitOpError(
"WMMA scale only supported on gfx1250+");
1855 int64_t m = op.getM();
1856 int64_t n = op.getN();
1857 int64_t k = op.getK();
// aFmtCode/bFmtCode (computed in elided lines) encode the A/B element
// formats for the fmtA/fmtB attributes below.
1865 if (!aFmtCode || !bFmtCode)
1866 return op.emitOpError(
"unsupported element types for scaled_wmma");
1869 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
1870 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
1872 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
1873 return op.emitOpError(
"scaleA and scaleB must have equal vector length");
1876 Type scaleAElemType = scaleAVecType.getElementType();
1877 Type scaleBElemType = scaleBVecType.getElementType();
1882 if (!scaleAFmt || !scaleBFmt)
1883 return op.emitOpError(
"unsupported scale element types");
// 8-element scale vectors select the 16-lane-scale intrinsic variant.
1886 bool isScale16 = (scaleAVecType.getNumElements() == 8);
1887 std::optional<StringRef> intrinsicName =
1890 return op.emitOpError(
"unsupported scaled_wmma dimensions: ")
1891 << m <<
"x" << n <<
"x" << k;
1893 SmallVector<NamedAttribute, 8> attrs;
1896 bool is32x16 = (m == 32 && n == 16 && k == 128);
1898 attrs.emplace_back(
"fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
1899 attrs.emplace_back(
"fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
1903 attrs.emplace_back(
"modC", rewriter.getI16IntegerAttr(0));
// First-scale-lane (0 or 16) is encoded as a wave-half selector (/16).
1908 "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
1909 attrs.emplace_back(
"fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
1911 "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
1912 attrs.emplace_back(
"fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
1915 attrs.emplace_back(
"reuseA", rewriter.getBoolAttr(
false));
1916 attrs.emplace_back(
"reuseB", rewriter.getBoolAttr(
false));
1929 OperationState loweredOp(loc, *intrinsicName);
1930 loweredOp.addTypes(outType);
1931 loweredOp.addOperands(
1932 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
1933 loweredOp.addAttributes(attrs);
1935 Operation *lowered = rewriter.create(loweredOp);
1936 rewriter.replaceOp(op, lowered->
getResults());
// Conversion pattern: lowers amdgpu.transpose_load to the gfx950
// ds_read_tr{4,6,8,16} LDS transpose-load intrinsics, selected by element
// bit width, then bitcasts the raw i32-vector result to the converted type.
// NOTE(review): lossy extract -- switch case labels and some lines are
// missing between the numbered fragments below.
1942struct TransposeLoadOpLowering
1944 TransposeLoadOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1945 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1950 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1951 ConversionPatternRewriter &rewriter)
const override {
1953 return op.emitOpError(
"Non-gfx950 chipset not supported");
1955 Location loc = op.getLoc();
1956 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1960 size_t srcElementSize =
1961 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1962 if (srcElementSize < 8)
1963 return op.emitOpError(
"Expect source memref to have at least 8 bits "
1964 "element size, got ")
1967 auto resultType = cast<VectorType>(op.getResult().getType());
1970 (adaptor.getSrcIndices()));
1972 size_t numElements = resultType.getNumElements();
1973 size_t elementTypeSize =
1974 resultType.getElementType().getIntOrFloatBitWidth();
// The hardware intrinsics return packed i32 vectors; size them so the
// total bit count matches the logical result vector.
1978 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1979 rewriter.getIntegerType(32));
1980 Type llvmResultType = typeConverter->convertType(resultType);
1982 switch (elementTypeSize) {
1984 assert(numElements == 16);
1985 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1986 rocdlResultType, srcPtr);
1987 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1991 assert(numElements == 16);
1992 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1993 rocdlResultType, srcPtr);
1994 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1998 assert(numElements == 8);
1999 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
2000 rocdlResultType, srcPtr);
2001 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
// 16-bit elements need no bitcast; the intrinsic result is used directly.
2005 assert(numElements == 4);
2006 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
2011 return op.emitOpError(
"Unsupported element size for transpose load");
// Conversion pattern: lowers amdgpu.gather_to_lds to ROCDL LoadToLDSOp /
// LoadAsyncToLDSOp, computing the per-lane load width from the transfer type.
// NOTE(review): lossy extract -- original lines missing between fragments.
2018 GatherToLDSOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2019 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
2024 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
2025 ConversionPatternRewriter &rewriter)
const override {
2026 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
2027 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
2029 Location loc = op.getLoc();
2031 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2032 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
// loadWidth is in bytes: vector transfers take total bits / 8 (divisor
// elided in this extract); scalar case handled in elided lines.
2037 Type transferType = op.getTransferType();
2038 int loadWidth = [&]() ->
int {
2039 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
2040 return (transferVectorType.getNumElements() *
2041 transferVectorType.getElementTypeBitWidth()) /
2048 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
2049 return op.emitOpError(
"chipset unsupported element size");
// Wide (12/16-byte) variants are a gfx950-only capability.
2051 if (chipset !=
kGfx950 && llvm::is_contained({12, 16}, loadWidth))
2052 return op.emitOpError(
"Gather to LDS instructions with 12-byte and "
2053 "16-byte load widths are only supported on gfx950");
2057 (adaptor.getSrcIndices()));
2060 (adaptor.getDstIndices()));
2062 if (op.getAsync()) {
2063 rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
2064 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2065 rewriter.getI32IntegerAttr(0),
2069 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
2070 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2071 rewriter.getI32IntegerAttr(0),
// Conversion pattern: lowers amdgpu.global_load_async_to_lds (gfx1250+) to
// the width-specific ROCDL GlobalLoadAsyncToLDSB{8,32,64,128} ops. A false
// mask is implemented by redirecting the LDS destination to the null pointer.
// NOTE(review): lossy extract -- switch case labels and some lines are
// missing between the numbered fragments below.
2080struct GlobalLoadAsyncToLDSOpLowering
2082 GlobalLoadAsyncToLDSOpLowering(
const LLVMTypeConverter &converter,
2084 : ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp>(converter),
2090 matchAndRewrite(GlobalLoadAsyncToLDSOp op,
2091 GlobalLoadAsyncToLDSOpAdaptor adaptor,
2092 ConversionPatternRewriter &rewriter)
const override {
2094 return op.emitOpError(
2095 "global_load_async_to_lds is only supported on gfx1250+");
2097 Location loc = op.getLoc();
2098 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2099 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2101 Type transferType = op.getTransferType();
2103 isa<VectorType>(transferType)
2104 ? cast<VectorType>(transferType).getNumElements() *
2105 cast<VectorType>(transferType).getElementTypeBitWidth()
2110 adaptor.getSrcIndices());
2113 adaptor.getDstIndices());
// Masked-off lanes write to the LDS null pointer (a hardware-defined
// discard address) instead of the real destination.
2116 Value mask = adaptor.getMask();
2117 int64_t nullptrVal =
2118 llvm::AMDGPU::getNullPointerValue(llvm::AMDGPUAS::LOCAL_ADDRESS);
2122 LLVM::IntToPtrOp::create(rewriter, loc, dstPtr.
getType(), nullInt);
2123 dstPtr = LLVM::SelectOp::create(rewriter, loc, mask, dstPtr, nullPtr);
2126 auto offset = rewriter.getI32IntegerAttr(0);
2127 auto aux = rewriter.getI32IntegerAttr(0);
2129 switch (transferBits) {
2131 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB8Op>(
2136 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB32Op>(
2141 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB64Op>(
2146 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB128Op>(
2151 return op.emitOpError(
"unsupported transfer width");
// Pattern declaration for amdgpu.ext_packed_fp8; matchAndRewrite is
// defined out-of-line further down in this file.
2158struct ExtPackedFp8OpLowering final
2160 ExtPackedFp8OpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2161 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
2166 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2167 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration for amdgpu.scaled_ext_packed_matrix; matchAndRewrite
// is defined out-of-line further down in this file.
2170struct ScaledExtPackedMatrixOpLowering final
2172 ScaledExtPackedMatrixOpLowering(
const LLVMTypeConverter &converter,
2174 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
2179 matchAndRewrite(ScaledExtPackedMatrixOp op,
2180 ScaledExtPackedMatrixOpAdaptor adaptor,
2181 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration for amdgpu.packed_trunc_2xfp8; matchAndRewrite is
// defined out-of-line elsewhere in this file.
2184struct PackedTrunc2xFp8OpLowering final
2186 PackedTrunc2xFp8OpLowering(
const LLVMTypeConverter &converter,
2188 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
2193 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2194 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration for amdgpu.packed_stoch_round_fp8; matchAndRewrite
// is defined out-of-line elsewhere in this file.
2197struct PackedStochRoundFp8OpLowering final
2199 PackedStochRoundFp8OpLowering(
const LLVMTypeConverter &converter,
2201 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
2206 matchAndRewrite(PackedStochRoundFp8Op op,
2207 PackedStochRoundFp8OpAdaptor adaptor,
2208 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration for amdgpu.scaled_ext_packed; matchAndRewrite is
// defined out-of-line further down in this file.
2211struct ScaledExtPackedOpLowering final
2213 ScaledExtPackedOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2214 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
2219 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2220 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration for amdgpu.packed_scaled_trunc; matchAndRewrite is
// defined out-of-line further down in this file.
2223struct PackedScaledTruncOpLowering final
2225 PackedScaledTruncOpLowering(
const LLVMTypeConverter &converter,
2227 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
2232 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2233 ConversionPatternRewriter &rewriter)
const override;
// Lowers amdgpu.ext_packed_fp8: widens the fp8 source to a v4i8 word,
// bitcasts it to i32, then emits the packed or scalar ROCDL cvt intrinsic
// depending on whether the result is a vector.
// NOTE(review): lossy extract -- original lines missing between fragments.
2238LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
2239 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2240 ConversionPatternRewriter &rewriter)
const {
2241 Location loc = op.getLoc();
2243 return rewriter.notifyMatchFailure(
2244 loc,
"Fp8 conversion instructions are not available on target "
2245 "architecture and their emulation is not implemented");
2247 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
2248 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2249 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
2251 Value source = adaptor.getSource();
2252 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
2253 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
// Pad scalar or short-vector sources out to a full 4-byte word so the
// i32 bitcast below is valid.
2256 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
2257 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
2258 if (!sourceVecType) {
2259 longVec = LLVM::InsertElementOp::create(
2262 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2264 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2266 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2271 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
// Vector result -> packed (pair) conversion; scalar -> single-element.
// Bf8 vs Fp8 dispatch conditions are elided in this extract.
2272 if (resultVecType) {
2274 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
2277 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
2282 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
2285 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
// Encodes the scaleSel immediate for the packed scaled-conversion
// intrinsics from the block size (16/32), source bit width (4/6/8),
// wave half (0/1), and the byte offset of the first scale.
// NOTE(review): lossy extract -- the branch structure (if/else on isFp8)
// and the final return of the non-fp8 path are elided between fragments.
2292int32_t getScaleSel(int32_t blockSize,
unsigned bitWidth, int32_t scaleWaveHalf,
2293 int32_t firstScaleByte) {
2299 assert(llvm::is_contained({16, 32}, blockSize));
2300 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
2302 const bool isFp8 = bitWidth == 8;
2303 const bool isBlock16 = blockSize == 16;
// fp8 encoding: bit0 = block16, bit1 = (firstScaleByte == 2), bit2 = half.
2306 int32_t bit0 = isBlock16;
2307 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
2308 int32_t bit1 = (firstScaleByte == 2) << 1;
2309 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2310 int32_t bit2 = scaleWaveHalf << 2;
2311 return bit2 | bit1 | bit0;
// fp4/fp6 encoding: bit0 = block16, bits2:1 = firstScaleByte, bit3 = half;
// the assert below rejects combinations the hardware does not define.
2314 int32_t bit0 = isBlock16;
2316 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
2317 int32_t bits2and1 = firstScaleByte << 1;
2318 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2319 int32_t bit3 = scaleWaveHalf << 3;
2320 int32_t bits = bit3 | bits2and1 | bit0;
2322 assert(!llvm::is_contained(
2323 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
// Maps a (source fp4/fp6/fp8/bf8 element type, destination f16/bf16/f32
// element type) pair to the corresponding ROCDL CvtPkScalePk{8,16} intrinsic
// name; returns nullopt for an unsupported destination type.
2327static std::optional<StringRef>
2328scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2329 using fp4 = Float4E2M1FNType;
2330 using fp8 = Float8E4M3FNType;
2331 using bf8 = Float8E5M2Type;
2332 using fp6 = Float6E2M3FNType;
2333 using bf6 = Float6E3M2FNType;
// 4-bit sources use the Pk8 (8-element) variants.
2334 if (isa<fp4>(srcElemType)) {
2335 if (destElemType.
isF16())
2336 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2337 if (destElemType.
isBF16())
2338 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2339 if (destElemType.
isF32())
2340 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2341 return std::nullopt;
2343 if (isa<fp8>(srcElemType)) {
2344 if (destElemType.
isF16())
2345 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2346 if (destElemType.
isBF16())
2347 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2348 if (destElemType.
isF32())
2349 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2350 return std::nullopt;
2352 if (isa<bf8>(srcElemType)) {
2353 if (destElemType.
isF16())
2354 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2355 if (destElemType.
isBF16())
2356 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2357 if (destElemType.
isF32())
2358 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2359 return std::nullopt;
// 6-bit sources use the Pk16 (16-element) variants.
2361 if (isa<fp6>(srcElemType)) {
2362 if (destElemType.
isF16())
2363 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2364 if (destElemType.
isBF16())
2365 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2366 if (destElemType.
isF32())
2367 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2368 return std::nullopt;
2370 if (isa<bf6>(srcElemType)) {
2371 if (destElemType.
isF16())
2372 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2373 if (destElemType.
isBF16())
2374 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2375 if (destElemType.
isF32())
2376 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2377 return std::nullopt;
// All supported source types are handled above; anything else is a
// caller bug.
2379 llvm_unreachable(
"invalid combination of element types for packed conversion "
// Lowers amdgpu.scaled_ext_packed_matrix: bitcasts the packed source to the
// i32-vector shape the intrinsic expects (1/2/3 words by element width),
// bitcasts the scale to i32, and emits the CvtPkScalePk{8,16} intrinsic
// with the scaleSel immediate computed by getScaleSel.
// NOTE(review): lossy extract -- original lines missing between fragments.
2383LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2384 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2385 ConversionPatternRewriter &rewriter)
const {
2386 using fp4 = Float4E2M1FNType;
2387 using fp8 = Float8E4M3FNType;
2388 using bf8 = Float8E5M2Type;
2389 using fp6 = Float6E2M3FNType;
2390 using bf6 = Float6E3M2FNType;
2391 Location loc = op.getLoc();
2393 return rewriter.notifyMatchFailure(
2395 "Scaled fp packed conversion instructions are not available on target "
2396 "architecture and their emulation is not implemented");
// firstScaleLane (0 or 16) is reduced to a wave-half selector.
2400 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2401 int32_t firstScaleByte = op.getFirstScaleByte();
2402 int32_t blockSize = op.getBlockSize();
2403 auto sourceType = cast<VectorType>(op.getSource().getType());
2404 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2405 unsigned bitWidth = srcElemType.getWidth();
2407 auto targetType = cast<VectorType>(op.getResult().getType());
2408 auto destElemType = cast<FloatType>(targetType.getElementType());
2410 IntegerType i32 = rewriter.getI32Type();
2411 Value source = adaptor.getSource();
2412 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
// Packed operand shape: fp4 -> 1 word (elided), fp8/bf8 -> 2 words,
// fp6/bf6 -> 3 words.
2413 Type packedType =
nullptr;
2414 if (isa<fp4>(srcElemType)) {
2416 packedType = getTypeConverter()->convertType(packedType);
2417 }
else if (isa<fp8, bf8>(srcElemType)) {
2418 packedType = VectorType::get(2, i32);
2419 packedType = getTypeConverter()->convertType(packedType);
2420 }
else if (isa<fp6, bf6>(srcElemType)) {
2421 packedType = VectorType::get(3, i32);
2422 packedType = getTypeConverter()->convertType(packedType);
2424 llvm_unreachable(
"invalid element type for packed scaled ext");
2427 if (!packedType || !llvmResultType) {
2428 return rewriter.notifyMatchFailure(op,
"type conversion failed");
2431 std::optional<StringRef> maybeIntrinsic =
2432 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2433 if (!maybeIntrinsic.has_value())
2434 return op.emitOpError(
2435 "no intrinsic matching packed scaled conversion on the given chipset");
2438 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2440 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2441 Value castedSource =
2442 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2444 OperationState loweredOp(loc, *maybeIntrinsic);
2445 loweredOp.addTypes({llvmResultType});
2446 loweredOp.addOperands({castedSource, castedScale});
2448 SmallVector<NamedAttribute, 1> attrs;
2450 NamedAttribute(
"scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2452 loweredOp.addAttributes(attrs);
2453 Operation *lowered = rewriter.create(loweredOp);
2454 rewriter.replaceOp(op, lowered);
// Lowers amdgpu.scaled_ext_packed: pads the fp8/bf8/fp4 source up to one
// packed 32-bit word, then dispatches on (source element type, destination
// element type) to the matching ROCDL CvtScaleF32Pk* intrinsic.
// NOTE(review): lossy extract -- original lines missing between fragments;
// the "!sourceVecType" test below operates on a value produced by cast<>
// in this extract, which looks like a garbling artifact -- confirm against
// the original file.
2459LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2460 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2461 ConversionPatternRewriter &rewriter)
const {
2462 Location loc = op.getLoc();
2464 return rewriter.notifyMatchFailure(
2465 loc,
"Scaled fp conversion instructions are not available on target "
2466 "architecture and their emulation is not implemented");
2467 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2469 Value source = adaptor.getSource();
2470 Value scale = adaptor.getScale();
2472 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2473 Type sourceElemType = sourceVecType.getElementType();
2474 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2475 Type destElemType = destVecType.getElementType();
// The packed word holds 4 x 8-bit or 8 x 4-bit source elements.
2477 VectorType packedVecType;
2478 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2479 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2480 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2481 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
2482 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2483 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2485 llvm_unreachable(
"invalid element type for scaled ext");
// Zero-pad short sources element-by-element up to the packed width.
2489 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2490 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2491 if (!sourceVecType) {
2492 longVec = LLVM::InsertElementOp::create(
2495 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2497 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2499 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2504 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
// Dispatch table: (bf8|fp8|fp4 source) x (f32|f16|bf16 destination).
2506 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
2507 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2508 op, destVecType, i32Source, scale, op.getIndex());
2509 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
2510 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2511 op, destVecType, i32Source, scale, op.getIndex());
2512 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
2513 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2514 op, destVecType, i32Source, scale, op.getIndex());
2515 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
2516 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2517 op, destVecType, i32Source, scale, op.getIndex());
2518 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
2519 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2520 op, destVecType, i32Source, scale, op.getIndex());
2521 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
2522 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2523 op, destVecType, i32Source, scale, op.getIndex());
2524 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
2525 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2526 op, destVecType, i32Source, scale, op.getIndex());
2527 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
2528 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2529 op, destVecType, i32Source, scale, op.getIndex());
2530 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
2531 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2532 op, destVecType, i32Source, scale, op.getIndex());
// Lowers amdgpu packed scaled-truncate to the ROCDL CvtScaleF32Pk{Bf8,Fp8,Fp4}
// intrinsics: two f32/f16/bf16 source lanes are truncated into a packed
// 8-bit/4-bit float word at the lane selected by op.getIndex().
// NOTE(review): this listing is an extraction with dropped lines (the embedded
// original line numbers jump); several guard conditions — e.g. the chipset
// check in front of notifyMatchFailure and the if/else around the "existing"
// handling — are not visible here. Comments below hedge accordingly.
2539LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2540 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2541 ConversionPatternRewriter &rewriter)
const {
2542 Location loc = op.getLoc();
// Reject the match (presumably when the target chipset lacks scaled fp
// conversion instructions; the condition line is missing from this listing).
2544 return rewriter.notifyMatchFailure(
2545 loc,
"Scaled fp conversion instructions are not available on target "
2546 "architecture and their emulation is not implemented");
// The packed result is represented as v2i16 for 8-bit float results and as
// i32 for fp4 results (see intResultType below).
2547 Type v2i16 = getTypeConverter()->convertType(
2548 VectorType::get(2, rewriter.getI16Type()));
2549 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2551 Type resultType = op.getResult().getType();
2553 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2554 Type sourceElemType = sourceVecType.getElementType();
2556 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
2558 Value source = adaptor.getSource();
2559 Value scale = adaptor.getScale();
2560 Value existing = adaptor.getExisting();
// Reuse the caller-provided "existing" word when present, otherwise start
// from zero so untouched lanes have a defined value.
2562 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2564 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
// The intrinsics consume a pair of lanes; widen a single-element source
// vector to two lanes (lane 1 zero-filled).
2566 if (sourceVecType.getNumElements() < 2) {
2568 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2569 VectorType v2 = VectorType::get(2, sourceElemType);
2570 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2571 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
// The f32 flavors of the intrinsics take two scalar operands rather than a
// packed vector, so split the source into lanes here.
2574 Value sourceA, sourceB;
2575 if (sourceElemType.
isF32()) {
2578 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2579 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
// Dispatch on (source element type, result element type) to the matching
// ROCDL scaled truncation intrinsic.
2583 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
2584 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2585 existing, sourceA, sourceB,
2586 scale, op.getIndex());
2587 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
2588 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2589 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2590 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
2591 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2592 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2593 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
2594 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2595 existing, sourceA, sourceB,
2596 scale, op.getIndex());
2597 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
2598 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2599 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2600 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
2601 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2602 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2603 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
2604 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2605 existing, sourceA, sourceB,
2606 scale, op.getIndex());
2607 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
2608 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2609 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2610 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
2611 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
2612 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
// Bitcast the packed integer word back to the op's declared result type and
// replace the original op.
2616 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2617 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu.packed_trunc_2xfp8 to ROCDL CvtPk{Bf8,Fp8}F32: truncates two
// f32 operands into one 16-bit half (selected by getWordIndex()) of an i32.
// NOTE(review): extraction gaps — the chipset guard before notifyMatchFailure
// and the if/else conditions around sourceB/existing/result are not visible.
2621LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2622 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2623 ConversionPatternRewriter &rewriter)
const {
2624 Location loc = op.getLoc();
2626 return rewriter.notifyMatchFailure(
2627 loc,
"Fp8 conversion instructions are not available on target "
2628 "architecture and their emulation is not implemented");
2629 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2631 Type resultType = op.getResult().getType();
2634 Value sourceA = adaptor.getSourceA();
2635 Value sourceB = adaptor.getSourceB();
// When the second operand is absent, feed an undef of the same type — the
// corresponding result byte is unused.
2637 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.
getType());
2638 Value existing = adaptor.getExisting();
// Preserve the untouched half of the word when "existing" is provided;
// otherwise it may start as undef.
2640 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2642 existing = LLVM::UndefOp::create(rewriter, loc, i32);
// Select the bf8 or fp8 intrinsic depending on the result element type
// (the discriminating condition is on lines missing from this listing).
2646 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2647 existing, op.getWordIndex());
2649 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2650 existing, op.getWordIndex());
2652 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2653 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu packed stochastic-rounding fp8 truncation to ROCDL
// CvtSr{Bf8,Fp8}F32: one f32 source plus a stochastic-rounding seed is written
// into the byte of an i32 selected by getStoreIndex().
// NOTE(review): extraction gaps — chipset guard and the if/else conditions
// around existing/result selection are not visible in this listing.
2657LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2658 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2659 ConversionPatternRewriter &rewriter)
const {
2660 Location loc = op.getLoc();
2662 return rewriter.notifyMatchFailure(
2663 loc,
"Fp8 conversion instructions are not available on target "
2664 "architecture and their emulation is not implemented");
2665 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2667 Type resultType = op.getResult().getType();
2670 Value source = adaptor.getSource();
2671 Value stoch = adaptor.getStochiasticParam();
2672 Value existing = adaptor.getExisting();
// Keep the other three bytes of the word when "existing" is provided.
2674 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2676 existing = LLVM::UndefOp::create(rewriter, loc, i32);
// bf8 vs fp8 intrinsic chosen by the result element type (condition lines
// missing from this listing).
2680 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2681 existing, op.getStoreIndex());
2683 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2684 existing, op.getStoreIndex());
2686 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2687 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu.dpp (data-parallel primitives cross-lane move) to
// ROCDL::DPPUpdateOp. Sub-32-bit operands are widened/bitcast to a 32-bit
// value first, the DPP control word is encoded from the permutation kind, and
// the result is narrowed/bitcast back to the source type.
// NOTE(review): extraction gaps — several condition lines, enum members, the
// switch header, and break statements are missing from this listing.
2693struct AMDGPUDPPLowering :
public ConvertOpToLLVMPattern<DPPOp> {
2694 AMDGPUDPPLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2695 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2699 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2700 ConversionPatternRewriter &rewriter)
const override {
2703 Location loc = DppOp.getLoc();
2704 Value src = adaptor.getSrc();
2705 Value old = adaptor.getOld();
// Choose the 32- or 64-bit LLVM type the DPP intrinsic will operate on,
// based on the source element kind and width (width conditions are on
// missing lines).
2708 Type llvmType =
nullptr;
2710 llvmType = rewriter.getI32Type();
2711 }
else if (isa<FloatType>(srcType)) {
2713 ? rewriter.getF32Type()
2714 : rewriter.getF64Type();
2715 }
else if (isa<IntegerType>(srcType)) {
2717 ? rewriter.getI32Type()
2718 : rewriter.getI64Type();
2720 auto llvmSrcIntType = typeConverter->convertType(
// Widen a <=16-bit operand to 32 bits: bitcast floats to int, insert into a
// vector that fills 32 bits, then bitcast the vector to llvmType.
2724 auto convertOperand = [&](Value operand, Type operandType) {
2725 if (operandType.getIntOrFloatBitWidth() <= 16) {
2726 if (llvm::isa<FloatType>(operandType)) {
2728 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2730 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2731 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2732 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2734 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2736 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2741 src = convertOperand(src, srcType);
2742 old = convertOperand(old, oldType);
// Hardware DPP control encodings (subset visible; other enumerators are on
// lines missing from this listing).
2745 enum DppCtrl :
unsigned {
2754 ROW_HALF_MIRROR = 0x141,
// Encode the dpp_ctrl immediate from the op's permutation kind/argument.
2759 auto kind = DppOp.getKind();
2760 auto permArgument = DppOp.getPermArgument();
2761 uint32_t DppCtrl = 0;
2765 case DPPPerm::quad_perm: {
2766 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
// Pack four 2-bit lane selectors into the low byte of the control word.
2768 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2769 uint32_t num = elem.getInt();
2770 DppCtrl |= num << (i * 2);
2775 case DPPPerm::row_shl: {
2776 auto intAttr = cast<IntegerAttr>(*permArgument);
2777 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2780 case DPPPerm::row_shr: {
2781 auto intAttr = cast<IntegerAttr>(*permArgument);
2782 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2785 case DPPPerm::row_ror: {
2786 auto intAttr = cast<IntegerAttr>(*permArgument);
2787 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
2790 case DPPPerm::wave_shl:
2791 DppCtrl = DppCtrl::WAVE_SHL1;
2793 case DPPPerm::wave_shr:
2794 DppCtrl = DppCtrl::WAVE_SHR1;
2796 case DPPPerm::wave_rol:
2797 DppCtrl = DppCtrl::WAVE_ROL1;
2799 case DPPPerm::wave_ror:
2800 DppCtrl = DppCtrl::WAVE_ROR1;
2802 case DPPPerm::row_mirror:
2803 DppCtrl = DppCtrl::ROW_MIRROR;
2805 case DPPPerm::row_half_mirror:
2806 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
2808 case DPPPerm::row_bcast_15:
2809 DppCtrl = DppCtrl::BCAST15;
2811 case DPPPerm::row_bcast_31:
2812 DppCtrl = DppCtrl::BCAST31;
// Remaining DPP modifiers come from op attributes.
2818 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
2819 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
2820 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>(
"bound_ctrl").getValue();
2824 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
2825 rowMask, bankMask, boundCtrl);
// Narrow/bitcast the 32/64-bit intrinsic result back to the original
// source type before replacing the op.
2827 Value
result = dppMovOp.getRes();
2829 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType,
result);
2830 if (!llvm::isa<IntegerType>(srcType)) {
2831 result = LLVM::BitcastOp::create(rewriter, loc, srcType,
result);
// Lowers amdgpu.swizzle_bitmode to ROCDL ds_swizzle: the source is decomposed
// into i32 words, each word is swizzled with a bit-mode mask built from the
// op's and/or/xor masks, and the words are recomposed into the original type.
2842struct AMDGPUSwizzleBitModeLowering
2843 :
public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2847 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2848 ConversionPatternRewriter &rewriter)
const override {
2849 Location loc = op.getLoc();
2850 Type i32 = rewriter.getI32Type();
2851 Value src = adaptor.getSrc();
2852 SmallVector<Value> decomposed;
2853 if (
failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
2854 return rewriter.notifyMatchFailure(op,
2855 "failed to decompose value to i32");
2856 unsigned andMask = op.getAndMask();
2857 unsigned orMask = op.getOrMask();
2858 unsigned xorMask = op.getXorMask();
// ds_swizzle bit-mode offset encoding: and[4:0] | or[9:5] | xor[14:10].
2862 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2864 SmallVector<Value> swizzled;
2865 for (Value v : decomposed) {
2867 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2868 swizzled.emplace_back(res);
// Reassemble the swizzled i32 words back into the source's original type.
2871 Value
result = LLVM::composeValue(rewriter, loc, swizzled, src.
getType());
2872 rewriter.replaceOp(op,
result);
// Lowers amdgpu.permlane_swap to ROCDL permlane16/32 swap intrinsics
// (gfx950+). The value is decomposed into i32 words; each word is swapped
// across rows and the relevant half of the returned {i32, i32} pair is chosen.
2877struct AMDGPUPermlaneLowering :
public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2880 AMDGPUPermlaneLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2881 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2885 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2886 ConversionPatternRewriter &rewriter)
const override {
// Hard error (not a match failure): the op itself is invalid pre-gfx950.
// The guarding chipset comparison is on a line missing from this listing.
2888 return op->emitOpError(
"permlane_swap is only supported on gfx950+");
2890 Location loc = op.getLoc();
2891 Type i32 = rewriter.getI32Type();
2892 Value src = adaptor.getSrc();
2893 unsigned rowLength = op.getRowLength();
2894 bool fi = op.getFetchInactive();
2895 bool boundctrl = op.getBoundCtrl();
2897 SmallVector<Value> decomposed;
2898 if (
failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
2899 return rewriter.notifyMatchFailure(op,
2900 "failed to decompose value to i32");
2902 SmallVector<Value> permuted;
2903 for (Value v : decomposed) {
// The permlane*_swap intrinsics return a literal struct of two i32s
// (vdst_old, vdst_new).
2905 Type i32pair = LLVM::LLVMStructType::getLiteral(
2906 rewriter.getContext(), {v.getType(), v.getType()});
2908 if (rowLength == 16)
2909 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2911 else if (rowLength == 32)
2912 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2915 llvm_unreachable(
"unsupported row length");
2917 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2918 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
// Pick whichever struct member differs from the input: if element 0 still
// equals the source, the swapped value landed in element 1.
2920 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2921 LLVM::ICmpPredicate::eq, vdst0, v);
2926 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2927 permuted.emplace_back(vdstNew);
2930 Value
result = LLVM::composeValue(rewriter, loc, permuted, src.
getType());
2931 rewriter.replaceOp(op,
result);
// Bit layout of the 64-bit gfx1250 DS barrier state word used by the
// DsBarrier* lowerings below: bits [28:0] pending count, bit [29] phase,
// bits [63:32] init count.
// NOTE(review): layout inferred from how these constants are consumed by the
// lowerings in this file — confirm against the gfx1250 ISA documentation.
2944constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
2945constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
2946constexpr int32_t kDsBarrierInitCountPos = 32;
2947constexpr int32_t kDsBarrierPendingCountMask =
2948 (1 << kDsBarrierPendingCountBitWidth) - 1;
// Lowers amdgpu ds_barrier init (gfx1250+): computes the initial 64-bit
// barrier state word from the participant count and stores it to the barrier's
// LDS address with release ordering.
2950struct DsBarrierInitOpLowering
2951 :
public ConvertOpToLLVMPattern<DsBarrierInitOp> {
2954 DsBarrierInitOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2955 : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}
2958 matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
2959 ConversionPatternRewriter &rewriter)
const override {
// Chipset guard (comparison line missing from this listing).
2961 return op->emitOpError(
"only supported on gfx1250+");
2963 Location loc = op.getLoc();
2964 Type i64 = rewriter.getI64Type();
// Resolve the barrier's LDS pointer from the memref base + indices.
2966 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2968 adaptor.getBase(), adaptor.getIndices());
2975 LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
// Mask the count into the pending-count field, then replicate it into the
// init-count field (upper 32 bits) so the state word holds both.
2982 Value maskedCount32 =
2983 LLVM::AndOp::create(rewriter, loc, initCount, countMask);
2984 Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);
2986 Value initCountShifted = LLVM::ShlOp::create(
2987 rewriter, loc, maskedCount,
2989 Value barrierState =
2990 LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);
// 8-byte-aligned atomic store with release semantics publishes the state.
2992 LLVM::StoreOp::create(
2993 rewriter, loc, barrierState, ptr, 8,
false,
2995 false, LLVM::AtomicOrdering::release,
2998 rewriter.eraseOp(op);
// Lowers amdgpu ds_barrier state poll (gfx1250+): an 8-byte-aligned atomic
// acquire load of the 64-bit barrier state word from LDS.
3003struct DsBarrierPollStateOpLowering
3004 :
public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
3007 DsBarrierPollStateOpLowering(
const LLVMTypeConverter &converter,
3009 : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
3013 matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
3014 ConversionPatternRewriter &rewriter)
const override {
// Chipset guard (comparison line missing from this listing).
3016 return op->emitOpError(
"only supported on gfx1250+");
3018 Location loc = op.getLoc();
3019 Type i64 = rewriter.getI64Type();
// Resolve the barrier's LDS pointer from the memref base + indices.
3021 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3023 adaptor.getBase(), adaptor.getIndices());
3027 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
3028 op, i64, ptr, 8,
false,
3030 false, LLVM::AtomicOrdering::acquire,
// Lowers amdgpu async ds_barrier arrive (gfx1250+) to the ROCDL
// ds_atomic_async_barrier_arrive intrinsic on the barrier's LDS address.
3036struct DsAsyncBarrierArriveOpLowering
3037 :
public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
3040 DsAsyncBarrierArriveOpLowering(
const LLVMTypeConverter &converter,
3042 : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
3046 matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
3047 ConversionPatternRewriter &rewriter)
const override {
// Chipset guard (comparison line missing from this listing).
3049 return op->emitOpError(
"only supported on gfx1250+");
3051 Location loc = op.getLoc();
// Resolve the barrier's LDS pointer from the memref base + indices.
3053 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3055 adaptor.getBase(), adaptor.getIndices());
3057 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
3058 op, ptr,
nullptr,
nullptr,
// Lowers amdgpu ds_barrier arrive (gfx1250+) to the ROCDL
// ds_atomic_barrier_arrive_rtn intrinsic, returning the 64-bit prior state.
3064struct DsBarrierArriveOpLowering
3065 :
public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
3068 DsBarrierArriveOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3069 : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
3073 matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
3074 ConversionPatternRewriter &rewriter)
const override {
// Chipset guard (comparison line missing from this listing).
3076 return op->emitOpError(
"only supported on gfx1250+");
3078 Location loc = op.getLoc();
3079 Type i64 = rewriter.getI64Type();
// Resolve the barrier's LDS pointer from the memref base + indices.
3081 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3083 adaptor.getBase(), adaptor.getIndices());
3085 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
3086 op, i64, ptr, adaptor.getCount(),
nullptr,
// Extracts the phase bit from a 64-bit barrier state word: truncate to the
// low 32 bits, then shift right by the phase position (bit 29).
3092struct DsBarrierStatePhaseOpLowering
3093 :
public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
3097 matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
3098 ConversionPatternRewriter &rewriter)
const override {
3099 Location loc = op.getLoc();
3100 Type i32 = rewriter.getI32Type();
3102 Value state = adaptor.getState();
// Low half of the state word holds pending count + phase (no init count).
3104 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3105 Value phase = LLVM::LShrOp::create(
3106 rewriter, loc, noInitCount,
3109 rewriter.replaceOp(op, phase);
// Extracts the pending count (low 29 bits) from a 64-bit barrier state word.
3114struct DsBarrierStatePendingCountOpLowering
3115 :
public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
3119 matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
3120 ConversionPatternRewriter &rewriter)
const override {
3121 Location loc = op.getLoc();
3122 Type i32 = rewriter.getI32Type();
3124 Value state = adaptor.getState();
// Truncate away the init count (upper 32 bits), then mask off the phase bit.
3126 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3127 Value pendingCount = LLVM::AndOp::create(
3128 rewriter, loc, noInitCount,
3130 static_cast<uint32_t
>(kDsBarrierPendingCountMask)));
3132 rewriter.replaceOp(op, pendingCount);
// Extracts the init count (upper 32 bits) from a 64-bit barrier state word.
3137struct DsBarrierStateInitCountOpLowering
3138 :
public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
3142 matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
3143 ConversionPatternRewriter &rewriter)
const override {
3144 Location loc = op.getLoc();
3145 Type i32 = rewriter.getI32Type();
3147 Value state = adaptor.getState();
// Shift the init-count field down (by kDsBarrierInitCountPos) and truncate.
3149 Value initCountI64 = LLVM::LShrOp::create(
3150 rewriter, loc, state,
3152 Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
3154 rewriter.replaceOp(op, initCount);
// Extracts the phase bit of a barrier state word as an i1 parity value.
3159struct DsBarrierStatePhaseParityLowering
3160 :
public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
3164 matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
3165 ConversionPatternRewriter &rewriter)
const override {
3166 Location loc = op.getLoc();
3167 Type i1 = rewriter.getI1Type();
3169 Value state = adaptor.getState();
// Same extraction as the phase lowering, then truncate to a single bit.
3172 LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
3173 Value phase = LLVM::LShrOp::create(
3174 rewriter, loc, noInitCount,
3176 Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
3178 rewriter.replaceOp(op, parity);
// Shifts `value` left by `shift` bits and ORs it into `accumulator`. The OR is
// marked disjoint: callers must guarantee the target bit range is still zero.
// NOTE(review): this listing omits lines between the signature and the shift
// (presumably a shift==0 fast path and the shift-amount constant creation).
3187static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
3188 Value accumulator, Value value, int64_t shift) {
3193 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
3199 constexpr bool isDisjoint =
true;
3200 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
// Lowers amdgpu make_dma_base ops (gfx1250 TDM) to a v4i32 descriptor "base"
// group: sgpr0 holds mode/flag bits (gather + index size), sgpr1 the LDS
// address, sgpr2/sgpr3 the low/high halves of the global address plus flags.
// NOTE(review): extraction gaps — the chipset guard condition, constant
// initializers, and some flag-setting branches are on missing lines.
3203template <
typename BaseOp>
3204struct AMDGPUMakeDmaBaseLowering :
public ConvertOpToLLVMPattern<BaseOp> {
3205 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
3208 AMDGPUMakeDmaBaseLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3209 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
3213 matchAndRewrite(BaseOp op, Adaptor adaptor,
3214 ConversionPatternRewriter &rewriter)
const override {
3216 return op->emitOpError(
"make_dma_base is only supported on gfx1250");
3218 Location loc = op.getLoc();
// Small i32 constants 0..3, reused both as values and as insert indices.
3220 constexpr int32_t constlen = 4;
3221 Value consts[constlen];
3222 for (int64_t i = 0; i < constlen; ++i)
// The four sgpr lanes of the descriptor, zero-initialized.
3225 constexpr int32_t sgprslen = constlen;
3226 Value sgprs[sgprslen];
3227 for (int64_t i = 0; i < sgprslen; ++i) {
3228 sgprs[i] = consts[0];
3231 sgprs[0] = consts[1];
// Gather variants additionally encode a gather flag (bit 30) and the index
// element size (16- vs 32-bit) in sgpr0.
3233 if constexpr (BaseOp::isGather()) {
3234 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
3236 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
3237 Type indexType = type.getIndexType();
3239 assert(llvm::is_contained({16u, 32u}, indexSize) &&
3240 "expected index_size to be 16 or 32");
3241 unsigned idx = (indexSize / 16) - 1;
3244 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
// Resolve the LDS and global addresses from their memrefs + indices.
3247 ValueRange ldsIndices = adaptor.getLdsIndices();
3248 Value lds = adaptor.getLds();
3249 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
3252 rewriter, loc, ldsMemRefType, lds, ldsIndices);
3254 ValueRange globalIndices = adaptor.getGlobalIndices();
3255 Value global = adaptor.getGlobal();
3256 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
3259 rewriter, loc, globalMemRefType, global, globalIndices);
3261 Type i32 = rewriter.getI32Type();
3262 Type i64 = rewriter.getI64Type();
// sgpr1 = LDS address; sgpr2/sgpr3 = low/high 32 bits of the global
// address, with flag bits set in the high word via setValueAtOffset.
3264 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
3265 Value castForGlobalAddr =
3266 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
3268 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
3270 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
3273 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
3276 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
3278 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
// Pack the four sgpr words into a v4i32 vector result.
3280 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3281 assert(v4i32 &&
"expected type conversion to succeed");
3282 Value
result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3284 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
3286 LLVM::InsertElementOp::create(rewriter, loc,
result, sgpr, constant);
3288 rewriter.replaceOp(op,
result);
// Lowers amdgpu TDM descriptor ops (gfx1250): the descriptor is materialized
// as two groups — dgroup0 is the base (already lowered), dgroup1 is a v8i32
// built field-by-field by the set* helpers below. (Struct continues past the
// end of this listing.)
3293template <
typename DescriptorOp>
3294struct AMDGPULowerDescriptor :
public ConvertOpToLLVMPattern<DescriptorOp> {
3295 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
3298 AMDGPULowerDescriptor(
const LLVMTypeConverter &converter, Chipset chipset)
3299 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
// Descriptor group 0 is just the already-lowered base operand.
3302 Value getDGroup0(OpAdaptor adaptor)
const {
return adaptor.getBase(); }
// Encodes the optional 16-bit workgroup mask into sgpr0 bits [15:0].
// NOTE(review): the early-return when the mask is absent is on a line
// missing from this listing.
3304 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
3305 ConversionPatternRewriter &rewriter, Location loc,
3306 Value sgpr0)
const {
3307 Value mask = op.getWorkgroupMask();
3311 Type i16 = rewriter.getI16Type();
3312 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
3313 Type i32 = rewriter.getI32Type();
3314 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
3315 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
// Encodes log2(element byte width) — 0..3 for 8/16/32/64-bit elements — into
// sgpr0 at bit offset 16.
3318 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
3319 ConversionPatternRewriter &rewriter, Location loc,
3320 Value sgpr0, ArrayRef<Value> consts)
const {
3321 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
3322 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3323 "expected type width to be 8, 16, 32, or 64.");
3324 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3325 Value size = consts[idx];
3326 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
// Sets the atomic-barrier enable flag (bit 18) when a barrier address is
// supplied; the no-barrier return path is on a line missing from this listing.
3329 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
3330 ConversionPatternRewriter &rewriter, Location loc,
3331 Value sgpr0, ArrayRef<Value> consts)
const {
3332 if (!adaptor.getAtomicBarrierAddress())
3335 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
// Sets the iterate-enable flag (bit 19) when a global increment is supplied.
3338 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
3339 ConversionPatternRewriter &rewriter, Location loc,
3340 Value sgpr0, ArrayRef<Value> consts)
const {
3341 if (!adaptor.getGlobalIncrement())
3346 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
// Sets the pad-enable flag (bit 20) when a pad amount is specified.
3349 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
3350 ConversionPatternRewriter &rewriter, Location loc,
3351 Value sgpr0, ArrayRef<Value> consts)
const {
3352 if (!op.getPadAmount())
3355 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
// Sets the early-timeout flag (bit 21) when a workgroup mask is in use.
3358 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
3359 ConversionPatternRewriter &rewriter, Location loc,
3360 Value sgpr0, ArrayRef<Value> consts)
const {
3361 if (!op.getWorkgroupMask())
3364 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
// Encodes the pad interval at bit offset 22 as log2(interval) - 1, computed
// via count-trailing-zeros (the interval is presumably a power of two —
// TODO confirm against op verifier).
3367 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
3368 ConversionPatternRewriter &rewriter, Location loc,
3369 Value sgpr0, ArrayRef<Value> consts)
const {
3370 if (!op.getPadAmount())
3379 IntegerType i32 = rewriter.getI32Type();
3380 Value padInterval = adaptor.getPadInterval();
3381 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
3382 padInterval,
false);
3383 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
3385 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
// Encodes the pad amount (biased by -1) into sgpr0 at bit offset 25.
3388 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
3389 ConversionPatternRewriter &rewriter, Location loc,
3390 Value sgpr0, ArrayRef<Value> consts)
const {
3391 if (!op.getPadAmount())
3400 Value padAmount = adaptor.getPadAmount();
3401 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
3403 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
// Encodes the atomic-barrier LDS address into sgpr1 (descriptor bit offset
// 32): resolve the pointer, convert to i32, drop the low bits (>> consts[3])
// and mask to the field width.
3406 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
3407 ConversionPatternRewriter &rewriter,
3408 Location loc, Value sgpr1,
3409 ArrayRef<Value> consts)
const {
3410 if (!adaptor.getAtomicBarrierAddress())
3413 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
3414 auto barrierAddressTy =
3415 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
3416 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
3418 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
3419 atomicBarrierIndices);
3420 IntegerType i32 = rewriter.getI32Type();
3426 atomicBarrierAddress =
3427 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
3428 atomicBarrierAddress =
3429 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
3431 atomicBarrierAddress =
3432 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
3433 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
// Encodes tensor dimension `dimX` (counted from the innermost/last size) into
// the descriptor: low 16 bits at `offset` in sgpr1, high bits at offset+16 in
// sgpr2, since the field straddles a 32-bit sgpr boundary.
3436 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3437 ConversionPatternRewriter &rewriter,
3438 Location loc, Value sgpr1, Value sgpr2,
3439 ArrayRef<Value> consts, uint64_t dimX,
3440 uint32_t offset)
const {
3441 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3442 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3443 SmallVector<OpFoldResult> mixedGlobalSizes =
// Rank too small: dimension not present, sgprs unchanged.
3445 if (mixedGlobalSizes.size() <= dimX)
3446 return {sgpr1, sgpr2};
3448 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
// Static sizes become constants (branch body on missing lines); dynamic
// sizes are truncated to i32.
3455 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3459 IntegerType i32 = rewriter.getI32Type();
3460 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3461 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3464 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
3467 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
3468 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
3469 return {sgpr1, sgpr2};
// Innermost tensor dimension (dim 0); field offset is on a missing line.
3472 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
3473 ConversionPatternRewriter &rewriter,
3474 Location loc, Value sgpr1, Value sgpr2,
3475 ArrayRef<Value> consts)
const {
3476 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
// Second tensor dimension (dim 1); field offset is on a missing line.
3480 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
3481 ConversionPatternRewriter &rewriter,
3482 Location loc, Value sgpr2, Value sgpr3,
3483 ArrayRef<Value> consts)
const {
3484 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
// Encodes tile (shared/LDS) dimension `dimX` — counted from the innermost
// size — into `sgpr` at the given bit offset. Unlike tensor dims, a tile dim
// fits within a single sgpr field.
3488 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
3489 ConversionPatternRewriter &rewriter, Location loc,
3490 Value sgpr, ArrayRef<Value> consts,
size_t dimX,
3491 int64_t offset)
const {
3492 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
3493 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
3494 SmallVector<OpFoldResult> mixedSharedSizes =
3496 if (mixedSharedSizes.size() <= dimX)
3499 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
// Static sizes become constants (branch body on missing lines); dynamic
// sizes are truncated to i32.
3508 if (
auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
3512 IntegerType i32 = rewriter.getI32Type();
3513 tileDimX = cast<Value>(tileDimXOpFoldResult);
3514 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
3517 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
// Innermost tile dimension → descriptor bit offset 112 (sgpr3).
3520 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
3521 ConversionPatternRewriter &rewriter, Location loc,
3522 Value sgpr3, ArrayRef<Value> consts)
const {
3523 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
// Second tile dimension → descriptor bit offset 128 (sgpr4).
3526 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
3527 ConversionPatternRewriter &rewriter, Location loc,
3528 Value sgpr4, ArrayRef<Value> consts)
const {
3529 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
// Gather variant: encodes the number of valid gather indices (length of the
// rank-1 index vector, at most 16) at bit offset 128, where the non-gather
// form stores tile dim 1. The encoded `value` is built on missing lines.
3532 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3533 ConversionPatternRewriter &rewriter, Location loc,
3534 Value sgpr4, ArrayRef<Value> consts)
const {
3535 auto type = cast<VectorType>(op.getIndices().getType());
3536 ArrayRef<int64_t> shape = type.getShape();
3537 assert(shape.size() == 1 &&
"expected shape to be of rank 1.");
3538 unsigned length = shape.back();
3539 assert(0 < length && length <= 16 &&
"expected length to be at most 16.");
3541 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
// Compile-time dispatch: bit field 128 holds the valid-index count for gather
// descriptors and tile dim 1 otherwise.
3544 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
3545 ConversionPatternRewriter &rewriter,
3546 Location loc, Value sgpr4,
3547 ArrayRef<Value> consts)
const {
3548 if constexpr (DescriptorOp::isGather())
3549 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
3550 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
// Third tile dimension → bit offset 144; gather descriptors have no dim 2
// (the gather early-return value is on a missing line).
3553 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
3554 ConversionPatternRewriter &rewriter, Location loc,
3555 Value sgpr4, ArrayRef<Value> consts)
const {
3557 if constexpr (DescriptorOp::isGather())
3559 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
// Encodes the 48-bit global stride for dimension dimX+1 (strides are read
// from the next-outer dimension) into two sgprs: the low 32 bits into sgprY
// at `offset`, the remaining high bits into sgprZ at the following field.
3562 std::pair<Value, Value>
3563 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
3564 ConversionPatternRewriter &rewriter, Location loc,
3565 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
3566 size_t dimX, int64_t offset)
const {
3567 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
3568 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
3569 SmallVector<OpFoldResult> mixedGlobalStrides =
3570 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
3572 if (mixedGlobalStrides.size() <= (dimX + 1))
3573 return {sgprY, sgprZ};
3575 OpFoldResult tensorDimXStrideOpFoldResult =
3576 *(mixedGlobalStrides.rbegin() + dimX + 1);
// Static stride → constant (branch body on missing lines); dynamic stride
// is used as-is and masked below.
3581 Value tensorDimXStride;
3582 if (
auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
3586 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
// Hardware stride field is 48 bits wide; mask before splitting.
3588 constexpr int64_t first48bits = (1ll << 48) - 1;
3591 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
3592 IntegerType i32 = rewriter.getI32Type();
3593 Value tensorDimXStrideLow =
3594 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
3595 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
// The high part starts at the next 32-bit boundary relative to `offset`.
3597 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
3599 Value tensorDimXStrideHigh =
3600 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
3601 tensorDimXStrideHigh =
3602 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
3603 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
3605 return {sgprY, sgprZ};
// Stride for tensor dim 0 → sgpr5/sgpr6 (trailing args on a missing line).
3608 std::pair<Value, Value>
3609 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3610 ConversionPatternRewriter &rewriter, Location loc,
3611 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3612 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
// Stride for tensor dim 1; gather descriptors are rank-1 and skip it.
3616 std::pair<Value, Value>
3617 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3618 ConversionPatternRewriter &rewriter, Location loc,
3619 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3621 if constexpr (DescriptorOp::isGather())
3622 return {sgpr5, sgpr6};
3623 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
// Assembles descriptor group 1: eight zero-initialized sgpr words are filled
// field-by-field by the set* helpers above, then packed into a v8i32 vector.
3627 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
3628 ConversionPatternRewriter &rewriter, Location loc,
3629 ArrayRef<Value> consts)
const {
3631 for (int64_t i = 0; i < 8; ++i) {
3632 sgprs[i] = consts[0];
// sgpr0: flags and small scalar fields.
3635 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
3636 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
3637 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
3638 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3639 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3640 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
3641 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
3642 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
// sgpr1..sgpr7: barrier address, tensor/tile dimensions and strides; the
// pair-returning setters straddle adjacent sgpr words.
3645 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
3646 std::tie(sgprs[1], sgprs[2]) =
3647 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3648 std::tie(sgprs[2], sgprs[3]) =
3649 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3651 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
3653 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
3654 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
3655 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
3656 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
3657 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
3658 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
// Pack the eight words into a v8i32 value.
3660 IntegerType i32 = rewriter.getI32Type();
3661 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
3662 assert(v8i32 &&
"expected type conversion to succeed");
3663 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
3665 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
3667 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
3673 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3674 ConversionPatternRewriter &rewriter, Location loc,
3675 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3676 int64_t offset)
const {
3677 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3678 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3679 SmallVector<OpFoldResult> mixedGlobalSizes =
3681 if (mixedGlobalSizes.size() <=
static_cast<unsigned long>(dimX))
3684 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3686 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3690 IntegerType i32 = rewriter.getI32Type();
3691 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3692 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3695 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
3698 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3699 ConversionPatternRewriter &rewriter, Location loc,
3700 Value sgpr0, ArrayRef<Value> consts)
const {
3701 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3704 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3705 Location loc, Value accumulator,
3706 Value value, int64_t shift)
const {
3708 IntegerType i32 = rewriter.getI32Type();
3709 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3710 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3713 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3714 ConversionPatternRewriter &rewriter, Location loc,
3715 Value sgpr1, ArrayRef<Value> consts,
3716 int64_t offset)
const {
3717 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3718 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3721 std::pair<Value, Value>
3722 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3723 ConversionPatternRewriter &rewriter, Location loc,
3724 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
3725 int64_t offset)
const {
3726 Value globalAddrIncrement = adaptor.getGlobalIncrement();
3727 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
3728 globalAddrIncrement, offset);
3730 globalAddrIncrement =
3731 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
3732 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
3733 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
3734 globalAddrIncrement, offset + 32);
3736 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
3737 return {sgpr2, sgpr3};
3740 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3741 ConversionPatternRewriter &rewriter,
3742 Location loc, Value sgpr1,
3743 ArrayRef<Value> consts)
const {
3744 Value ldsIncrement = op.getLdsIncrement();
3745 constexpr int64_t dim = 3;
3746 constexpr int64_t offset = 32;
3748 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
3750 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
3754 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3755 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3756 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts)
const {
3757 Value globalIncrement = op.getGlobalIncrement();
3758 constexpr int32_t dim = 2;
3759 constexpr int32_t offset = 64;
3760 if (!globalIncrement)
3761 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3762 consts, dim, offset);
3763 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3767 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3768 ConversionPatternRewriter &rewriter, Location loc,
3769 Value sgpr3, ArrayRef<Value> consts,
3770 int32_t offset)
const {
3771 Value iterationCount = adaptor.getIterationCount();
3772 IntegerType i32 = rewriter.getI32Type();
3779 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
3781 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
3782 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
3785 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
3786 ConversionPatternRewriter &rewriter,
3787 Location loc, Value sgpr3,
3788 ArrayRef<Value> consts)
const {
3789 Value iterateCount = op.getIterationCount();
3790 constexpr int32_t dim = 2;
3791 constexpr int32_t offset = 112;
3793 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
3796 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
3799 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
3800 ConversionPatternRewriter &rewriter, Location loc,
3801 ArrayRef<Value> consts)
const {
3802 if constexpr (DescriptorOp::isGather())
3803 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
3804 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
3807 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
3808 ConversionPatternRewriter &rewriter, Location loc,
3809 ArrayRef<Value> consts)
const {
3810 IntegerType i32 = rewriter.getI32Type();
3811 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3812 assert(v4i32 &&
"expected type conversion to succeed.");
3814 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3815 if (onlyNeedsTwoDescriptors)
3816 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3818 constexpr int64_t sgprlen = 4;
3819 Value sgprs[sgprlen];
3820 for (
int i = 0; i < sgprlen; ++i)
3821 sgprs[i] = consts[0];
3823 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
3824 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
3826 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
3827 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3829 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
3831 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3832 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
3834 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
3839 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
3840 ConversionPatternRewriter &rewriter, Location loc,
3841 ArrayRef<Value> consts,
bool firstHalf)
const {
3842 IntegerType i32 = rewriter.getI32Type();
3843 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3844 assert(v4i32 &&
"expected type conversion to succeed.");
3846 Value
indices = adaptor.getIndices();
3847 auto vectorType = cast<VectorType>(
indices.getType());
3848 unsigned length = vectorType.getShape().back();
3849 Type elementType = vectorType.getElementType();
3850 unsigned maxLength = elementType == i32 ? 4 : 8;
3851 int32_t offset = firstHalf ? 0 : maxLength;
3852 unsigned discountedLength =
3853 std::max(
static_cast<int32_t
>(length - offset), 0);
3855 unsigned targetSize = std::min(maxLength, discountedLength);
3857 SmallVector<Value> indicesVector;
3858 for (
unsigned i = offset; i < targetSize + offset; ++i) {
3860 if (i < consts.size())
3864 Value elem = LLVM::ExtractElementOp::create(rewriter, loc,
indices, idx);
3865 indicesVector.push_back(elem);
3868 SmallVector<Value> indicesI32Vector;
3869 if (elementType == i32) {
3870 indicesI32Vector = indicesVector;
3872 for (
unsigned i = 0; i < targetSize; ++i) {
3873 Value index = indicesVector[i];
3874 indicesI32Vector.push_back(
3875 LLVM::ZExtOp::create(rewriter, loc, i32, index));
3877 if ((targetSize % 2) != 0)
3879 indicesI32Vector.push_back(consts[0]);
3882 SmallVector<Value> indicesToInsert;
3883 if (elementType == i32) {
3884 indicesToInsert = indicesI32Vector;
3886 unsigned size = indicesI32Vector.size() / 2;
3887 for (
unsigned i = 0; i < size; ++i) {
3888 Value first = indicesI32Vector[2 * i];
3889 Value second = indicesI32Vector[2 * i + 1];
3890 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
3891 indicesToInsert.push_back(joined);
3895 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3896 for (
auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
3898 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
3903 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
3904 ConversionPatternRewriter &rewriter, Location loc,
3905 ArrayRef<Value> consts)
const {
3906 return getGatherIndices(op, adaptor, rewriter, loc, consts,
true);
3909 std::pair<Value, Value>
3910 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
3911 ConversionPatternRewriter &rewriter, Location loc,
3912 Value sgpr0, Value sgpr1, ArrayRef<Value> consts)
const {
3913 constexpr int32_t dim = 3;
3914 constexpr int32_t offset = 0;
3915 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
3919 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
3920 ConversionPatternRewriter &rewriter,
3921 Location loc, Value sgpr1, Value sgpr2,
3922 ArrayRef<Value> consts)
const {
3923 constexpr int32_t dim = 4;
3924 constexpr int32_t offset = 48;
3925 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
3929 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
3930 ConversionPatternRewriter &rewriter, Location loc,
3931 Value sgpr2, ArrayRef<Value> consts)
const {
3932 constexpr int32_t dim = 4;
3933 constexpr int32_t offset = 80;
3934 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
3937 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
3938 ConversionPatternRewriter &rewriter, Location loc,
3939 ArrayRef<Value> consts)
const {
3940 if constexpr (DescriptorOp::isGather())
3941 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
3942 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
3945 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
3946 ConversionPatternRewriter &rewriter, Location loc,
3947 ArrayRef<Value> consts)
const {
3948 IntegerType i32 = rewriter.getI32Type();
3949 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3950 assert(v4i32 &&
"expected type conversion to succeed.");
3951 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3952 if (onlyNeedsTwoDescriptors)
3953 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
3955 constexpr int32_t sgprlen = 4;
3956 Value sgprs[sgprlen];
3957 for (
int i = 0; i < sgprlen; ++i)
3958 sgprs[i] = consts[0];
3960 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
3961 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
3962 std::tie(sgprs[1], sgprs[2]) =
3963 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3964 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
3966 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3967 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
3969 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
3974 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
3975 ConversionPatternRewriter &rewriter, Location loc,
3976 ArrayRef<Value> consts)
const {
3977 return getGatherIndices(op, adaptor, rewriter, loc, consts,
false);
3981 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
3982 ConversionPatternRewriter &rewriter)
const override {
3984 return op->emitOpError(
3985 "make_dma_descriptor is only supported on gfx1250");
3987 Location loc = op.getLoc();
3989 SmallVector<Value> consts;
3990 for (int64_t i = 0; i < 8; ++i)
3993 Value dgroup0 = this->getDGroup0(adaptor);
3994 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
3995 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
3996 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
3997 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
3998 rewriter.replaceOpWithMultiple(op, {results});
4003template <
typename SourceOp,
typename TargetOp>
4004struct AMDGPUTensorLoadStoreOpLowering
4005 :
public ConvertOpToLLVMPattern<SourceOp> {
4006 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
4008 AMDGPUTensorLoadStoreOpLowering(
const LLVMTypeConverter &converter,
4010 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
4014 matchAndRewrite(SourceOp op, Adaptor adaptor,
4015 ConversionPatternRewriter &rewriter)
const override {
4017 return op->emitOpError(
"is only supported on gfx1250");
4022 auto v8i32 = VectorType::get(8, rewriter.getI32Type());
4023 Value dgroup4 = LLVM::ZeroOp::create(rewriter, op.getLoc(), v8i32);
4024 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
4025 desc[3], dgroup4, 0,
4033struct GlobalPrefetchOpLowering
4034 :
public ConvertOpToLLVMPattern<GlobalPrefetchOp> {
4035 GlobalPrefetchOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
4036 : ConvertOpToLLVMPattern<GlobalPrefetchOp>(converter), chipset(chipset) {}
4039 matchAndRewrite(GlobalPrefetchOp op, GlobalPrefetchOpAdaptor adaptor,
4040 ConversionPatternRewriter &rewriter)
const override {
4042 return op->emitOpError(
"is only supported on gfx1250+");
4044 const bool isSpeculative = op.getSpeculative();
4046 op.getTemporalHint(), op.getCacheScope(), isSpeculative);
4047 IntegerAttr immArgAttr = rewriter.getI32IntegerAttr(immArgValue);
4050 Value memRef = adaptor.getSrc();
4051 MemRefDescriptor descriptor(memRef);
4052 MemRefType memRefType = op.getSrc().getType();
4053 Location loc = op->getLoc();
4054 auto inboundsFlags = isSpeculative ? LLVM::GEPNoWrapFlags::none
4055 : LLVM::GEPNoWrapFlags::inbounds |
4056 LLVM::GEPNoWrapFlags::nuw;
4058 rewriter, loc, memRefType, descriptor,
indices, inboundsFlags);
4060 rewriter.replaceOpWithNewOp<ROCDL::GlobalPrefetchOp>(
4061 op, prefetchPtr, immArgAttr, mlir::ArrayAttr{}, mlir::ArrayAttr{},
4070struct ConvertAMDGPUToROCDLPass
4071 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
4074 void runOnOperation()
override {
4077 if (
failed(maybeChipset)) {
4078 emitError(UnknownLoc::get(ctx),
"Invalid chipset name: " + chipset);
4079 return signalPassFailure();
4082 RewritePatternSet patterns(ctx);
4083 LLVMTypeConverter converter(ctx);
4086 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
4088 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
4089 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
4090 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
4091 if (
failed(applyPartialConversion(getOperation(),
target,
4092 std::move(patterns))))
4093 signalPassFailure();
4101 typeConverter, [](gpu::AddressSpace space) {
4103 case gpu::AddressSpace::Global:
4104 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
4105 case gpu::AddressSpace::Workgroup:
4106 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
4107 case gpu::AddressSpace::Private:
4108 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
4109 case gpu::AddressSpace::Constant:
4110 return ROCDL::ROCDLDialect::kConstantMemoryAddressSpace;
4112 llvm_unreachable(
"unknown address space enum value");
4118 typeConverter.addTypeAttributeConversion(
4120 -> TypeConverter::AttributeConversionResult {
4122 Type i64 = IntegerType::get(ctx, 64);
4123 switch (as.getValue()) {
4124 case amdgpu::AddressSpace::FatRawBuffer:
4125 return IntegerAttr::get(i64, 7);
4126 case amdgpu::AddressSpace::BufferRsrc:
4127 return IntegerAttr::get(i64, 8);
4128 case amdgpu::AddressSpace::FatStructuredBuffer:
4129 return IntegerAttr::get(i64, 9);
4131 return TypeConverter::AttributeConversionResult::abort();
4133 typeConverter.addConversion([&](DsBarrierStateType type) ->
Type {
4134 return IntegerType::get(type.
getContext(), 64);
4136 typeConverter.addConversion([&](TDMBaseType type) ->
Type {
4138 return typeConverter.convertType(VectorType::get(4, i32));
4140 typeConverter.addConversion([&](TDMGatherBaseType type) ->
Type {
4142 return typeConverter.convertType(VectorType::get(4, i32));
4144 typeConverter.addConversion(
4145 [&](TDMDescriptorType type,
4148 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
4149 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
4150 llvm::append_values(
result, v4i32, v8i32, v4i32, v4i32);
4160 if (inputs.size() != 1)
4163 if (!isa<TDMDescriptorType>(inputs[0].
getType()))
4166 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
4167 return cast.getResults();
4170 typeConverter.addTargetMaterialization(addUnrealizedCast);
4178 .
add<FatRawBufferCastLowering,
4179 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
4180 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
4181 RawBufferOpLowering<RawBufferAtomicFaddOp,
4182 ROCDL::RawPtrBufferAtomicFaddOp>,
4183 RawBufferOpLowering<RawBufferAtomicFmaxOp,
4184 ROCDL::RawPtrBufferAtomicFmaxOp>,
4185 RawBufferOpLowering<RawBufferAtomicSmaxOp,
4186 ROCDL::RawPtrBufferAtomicSmaxOp>,
4187 RawBufferOpLowering<RawBufferAtomicUminOp,
4188 ROCDL::RawPtrBufferAtomicUminOp>,
4189 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
4190 ROCDL::RawPtrBufferAtomicCmpSwap>,
4191 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
4192 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
4193 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
4194 SparseWMMAOpLowering, ExtPackedFp8OpLowering,
4195 ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
4196 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
4197 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
4198 GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
4199 AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
4200 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
4201 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
4202 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
4203 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
4204 ROCDL::TensorLoadToLDSOp>,
4205 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
4206 ROCDL::TensorStoreFromLDSOp>,
4207 DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
4208 DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering,
4209 GlobalPrefetchOpLowering>(converter, chipset);
4210 patterns.
add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
4211 DsBarrierStatePendingCountOpLowering,
4212 DsBarrierStateInitCountOpLowering,
4213 DsBarrierStatePhaseParityLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static std::optional< SparseWMMAOpInfo > sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< uint32_t > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static std::optional< uint32_t > smallFloatTypeToFormatCode(Type mlirElemType)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static Value convertSparseVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts sparse MFMA/WMMA (smfmac/swmmac) operands to the expected ROCDL types.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
typename SourceOp::Adaptor OpAdaptor
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
int32_t getGlobalPrefetchLLVMEncoding(amdgpu::LoadTemporalHint hint, amdgpu::Scope scope, bool isSpeculative)
bool hasOcpFp8(const Chipset &chipset)
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Returns the rocdl intrinsic corresponding to a SparseWMMA operation swmmac if one exists.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.