#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"

#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
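// The GEN_PASS_DEF include pulls in the tablegen-generated base class for the
// pass; the helpers below rely on the standard MLIR LLVM-conversion
// infrastructure (ConversionPatternRewriter, MemRefDescriptor, ROCDL ops).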
  IntegerType i32 = rewriter.getI32Type();
  auto valTy = cast<IntegerType>(val.getType());
  return valTy.getWidth() > 32
             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
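// Example (a sketch of the intent): an i64 index is truncated to i32 while an
// i16 index is zero-extended, since the buffer intrinsics below take 32-bit
// offsets.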
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
  IntegerType i64 = rewriter.getI64Type();
  auto valTy = cast<IntegerType>(val.getType());
  return valTy.getWidth() > 64
             ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));

  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
  IntegerType i32 = rewriter.getI32Type();
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    Value strideValue =
        ShapedType::isDynamic(stride)
            ? convertUnsignedToI32(rewriter, loc,
                                   memRefDescriptor.stride(rewriter, loc, i))
            : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
    increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
                           MemRefType memrefType,
                           MemRefDescriptor &memrefDescriptor,
                           ArrayRef<int64_t> strides,
                           uint32_t elementByteWidth) {
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    ArrayRef<int64_t> shape = memrefType.getShape();
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;
    return createI64Constant(rewriter, loc, size);
  }
  Value maxIndex;
  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
    maxIndex = maxIndex
                   ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
  Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
  return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
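// For example (a sketch, not part of the original): a static memref<8x16xf32>
// with strides [16, 1] gives max(8 * 16, 16 * 1) = 128 elements, so
// num_records = 128 * 4 = 512 valid bytes.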
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  Type i16 = rewriter.getI16Type();

    Value cacheStrideZext =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
                                /*isDisjoint=*/true);

    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));

  uint32_t flags = (7 << 12) | (4 << 15);

  uint32_t oob = boundsCheck ? 3 : 2;
  flags |= (oob << 28);
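// A rough sketch of the flag word as we read it: (7 << 12) | (4 << 15) fills
// the default format-select fields of the buffer descriptor, and the two oob
// bits at 28-29 pick the out-of-bounds mode, where 3 checks accesses against
// num_records and 2 effectively disables the bounds check.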
  Type rsrcType =
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
struct FatRawBufferCastLowering
    : public ConvertOpToLLVMPattern<FatRawBufferCastOp> {
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
    if (!numRecords)
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

    Value basePointer =
        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
                                   memrefType)
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                  rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    Value sizes = hasSizes
                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                     kSizePosInMemRefDescriptor)
                      : Value{};
    Value strides =
        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
                                                kStridePosInMemRefDescriptor)
                 : Value{};

    Value fatPtr = makeBufferRsrc(
        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), /*addressSpace=*/7);

    Value result = MemRefDescriptor::poison(
        rewriter, loc,
        getTypeConverter()->convertType(op.getResult().getType()));
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
                                         kAllocatedPtrPosInMemRefDescriptor);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
                                         kAlignedPtrPosInMemRefDescriptor);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
                                         kOffsetPosInMemRefDescriptor);
    if (hasSizes) {
      result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
                                           kSizePosInMemRefDescriptor);
      result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
                                           kStridePosInMemRefDescriptor);
    }
    rewriter.replaceOp(op, result);
    return success();
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref)
      // No write component to this op.
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      // Buffer intrinsics don't support 1-element vectors.
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
    }
    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }
    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    Value numRecords = getNumRecords(
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
    Value resource = makeBufferRsrc(rewriter, loc, ptr, numRecords,
                                    adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
                                              extraOffsetConst)
                        : extraOffsetConst;
    }
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
                                              replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
  if (chipset.majorVersion < 9) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
  }
  if (chipset.majorVersion == 9) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 10) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;
  }
  if (chipset.majorVersion == 11) {
    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
  }
  return failure();
}
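// Worked example (a sketch): with a gfx10-style encoding, vmcnt = 18
// (0b10010), expcnt = 0, lgkmcnt = 0 splits vmcnt into the low nibble 0b0010
// plus a high bit shifted to position 14, producing 0x4002.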
struct MemoryCounterWaitOpLowering
    : public ConvertOpToLLVMPattern<MemoryCounterWaitOp> {
  MemoryCounterWaitOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
        chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      Location loc = op.getLoc();
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      if (std::optional<int> tensor = adaptor.getTensor())
        ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);

      rewriter.eraseOp(op);
      return success();
    }

    if (adaptor.getTensor())
      return op.emitOpError("unsupported chipset");

    auto getVal = [](Attribute attr) -> unsigned {
      if (attr)
        return cast<IntegerAttr>(attr).getInt();
      // This value will be clamped to the chipset's maximum.
      return 1024;
    };
    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    unsigned vmcnt = 1024;
    Attribute load = adaptor.getLoadAttr();
    Attribute store = adaptor.getStoreAttr();
    if (load && store)
      vmcnt = getVal(load) + getVal(store);
    else if (load)
      vmcnt = getVal(load);
    else if (store)
      vmcnt = getVal(store);

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
    if (failed(waitcnt))
      return op.emitOpError("unsupported chipset");
    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
    return success();
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Chipsets before gfx90a must lower the barrier via inline assembly.
    bool requiresInlineAsm = chipset < kGfx90a;

    auto mmra = rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as",
                                                    "local");
    StringRef scope = "workgroup";

    auto relFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::release, scope);
    relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *constraints = "";
      LLVM::InlineAsmOp::create(
          rewriter, loc,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
          /*asm_dialect=*/asmDialectAttr, /*operand_attrs=*/ArrayAttr());
    } else if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }
    auto acqFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::acquire, scope);
    acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    rewriter.replaceOp(op, acqFence);
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input,
                                      bool allowBf16 = true) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16() && !allowBf16)
      return LLVM::BitcastOp::create(
          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return LLVM::BitcastOp::create(
          rewriter, loc,
          rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return LLVM::BitcastOp::create(
          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
          input);
                                      bool allowBf16 = true) {
  auto vectorType = cast<VectorType>(inputType);

  if (vectorType.getElementType().isBF16() && !allowBf16)
    return LLVM::BitcastOp::create(
        rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);

  if (isa<IntegerType>(vectorType.getElementType()) &&
      vectorType.getElementTypeBitWidth() <= 8) {
    int64_t numWords = llvm::divideCeil(
        vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
    return LLVM::BitcastOp::create(
        rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
      .Case([&](IntegerType) {
        return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
                                    input);
      })
      .Case([&](VectorType vectorType) {
        int64_t numElements = vectorType.getNumElements();
        assert((numElements == 4 || numElements == 8) &&
               "scale operand must be a vector of length 4 or 8");
        IntegerType outputType =
            (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
        return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
      })
      .DefaultUnreachable("unexpected input type for scale operand");

      .Case([](Float8E8M0FNUType) { return 0; })
      .Case([](Float8E4M3FNType) { return 2; })
      .Default(std::nullopt);
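// Per the cases above, an E8M0 scale maps to format code 0 and an E4M3 scale
// to code 2; other scale element types fall through to std::nullopt.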
static std::optional<StringRef>
wmmaScaledOpToIntrinsic(uint32_t m, uint32_t n, uint32_t k, bool isScale16) {
  if (m == 16 && n == 16 && k == 128)
    return isScale16
               ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
               : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();

  if (m == 32 && n == 16 && k == 128)
    return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
                     : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
                                 ConversionPatternRewriter &rewriter,
                                 Location loc,

  auto vectorType = dyn_cast<VectorType>(inputType);
  if (!vectorType) {
    operands.push_back(llvmInput);
    return;
  }
  Type elemType = vectorType.getElementType();
  if (elemType.getIntOrFloatBitWidth() > 8) {
    operands.push_back(llvmInput);
    return;
  }

  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger())
      localIsUnsigned = true;
    else if (elemType.isSignedInteger())
      localIsUnsigned = false;
    attrs.push_back(
        NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
  }

  int64_t numBits =
      vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
  Type i32 = rewriter.getI32Type();
  Type intrinsicInType = numBits <= 32
                             ? (Type)rewriter.getIntegerType(numBits)
                             : (Type)VectorType::get(numBits / 32, i32);
  auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
  Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmIntrinsicInType, llvmInput);
  if (numBits < 32)
    castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
  operands.push_back(castInput);
                                  Value output, int32_t subwordOffset,

  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  operands.push_back(output);
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));

  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();

  if (mfma.getReducePrecision() && chipset >= kGfx942) {
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
  }

  if (m == 32 && n == 32 && k == 1 && b == 2)
    return ROCDL::mfma_f32_32x32x1f32::getOperationName();
  if (m == 16 && n == 16 && k == 1 && b == 4)
    return ROCDL::mfma_f32_16x16x1f32::getOperationName();
  if (m == 4 && n == 4 && k == 1 && b == 16)
    return ROCDL::mfma_f32_4x4x1f32::getOperationName();
  if (m == 32 && n == 32 && k == 2 && b == 1)
    return ROCDL::mfma_f32_32x32x2f32::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 1)
    return ROCDL::mfma_f32_16x16x4f32::getOperationName();

  if (m == 32 && n == 32 && k == 16 && b == 1)
    return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
  if (m == 16 && n == 16 && k == 32 && b == 1)
    return ROCDL::mfma_f32_16x16x32_f16::getOperationName();

  if (m == 32 && n == 32 && k == 4 && b == 2)
    return ROCDL::mfma_f32_32x32x4f16::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 4)
    return ROCDL::mfma_f32_16x16x4f16::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 16)
    return ROCDL::mfma_f32_4x4x4f16::getOperationName();
  if (m == 32 && n == 32 && k == 8 && b == 1)
    return ROCDL::mfma_f32_32x32x8f16::getOperationName();
  if (m == 16 && n == 16 && k == 16 && b == 1)
    return ROCDL::mfma_f32_16x16x16f16::getOperationName();

  if (m == 32 && n == 32 && k == 16 && b == 1)
    return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
  if (m == 16 && n == 16 && k == 32 && b == 1)
    return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();

  if (m == 32 && n == 32 && k == 4 && b == 2)
    return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 4)
    return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 16)
    return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
  if (m == 32 && n == 32 && k == 8 && b == 1)
    return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
  if (m == 16 && n == 16 && k == 16 && b == 1)
    return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();

  if (m == 32 && n == 32 && k == 2 && b == 2)
    return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
  if (m == 16 && n == 16 && k == 2 && b == 4)
    return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
  if (m == 4 && n == 4 && k == 2 && b == 16)
    return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
  if (m == 32 && n == 32 && k == 4 && b == 1)
    return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
  if (m == 16 && n == 16 && k == 8 && b == 1)
    return ROCDL::mfma_f32_16x16x8bf16::getOperationName();

  if (m == 32 && n == 32 && k == 32 && b == 1)
    return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
  if (m == 16 && n == 16 && k == 64 && b == 1)
    return ROCDL::mfma_i32_16x16x64_i8::getOperationName();

  if (m == 32 && n == 32 && k == 4 && b == 2)
    return ROCDL::mfma_i32_32x32x4i8::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 4)
    return ROCDL::mfma_i32_16x16x4i8::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 16)
    return ROCDL::mfma_i32_4x4x4i8::getOperationName();
  if (m == 32 && n == 32 && k == 8 && b == 1)
    return ROCDL::mfma_i32_32x32x8i8::getOperationName();
  if (m == 16 && n == 16 && k == 16 && b == 1)
    return ROCDL::mfma_i32_16x16x16i8::getOperationName();
  if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
    return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
  if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
    return ROCDL::mfma_i32_16x16x32_i8::getOperationName();

  if (m == 16 && n == 16 && k == 4 && b == 1)
    return ROCDL::mfma_f64_16x16x4f64::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 4)
    return ROCDL::mfma_f64_4x4x4f64::getOperationName();
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }

    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (typeIsExpectedBf8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (typeIsExpectedFp8ForChipset(chipset, sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default(std::nullopt);
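// These codes feed the type fields of the scaled MFMA intrinsics below:
// 0 = fp8 (E4M3FN), 1 = bf8 (E5M2), 2 = fp6 (E2M3FN), 3 = bf6 (E3M2FN),
// 4 = fp4 (E2M1FN).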
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  if (chipset < kGfx950)
    return std::nullopt;
  if (!isa<Float32Type>(destType))
    return std::nullopt;

  if (!aTypeCode || !bTypeCode)
    return std::nullopt;

  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
    return std::tuple{
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
        *bTypeCode};

  return std::nullopt;
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                                 smfma.getSourceB().getType(),
                                 smfma.getDestC().getType(), smfma.getM(),
                                 smfma.getN(), smfma.getK(), 1u, chipset);
}
static std::optional<StringRef>
wmmaOpToIntrinsic(Type elemSourceType, Type elemBSourceType,
                  Type elemDestType, uint32_t k, bool isRDNA3) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  if (k == 16) {
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
    if (isRDNA3) {
      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
      return std::nullopt;
    }

    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

    return std::nullopt;
  }

  if (k == 32 && !isRDNA3) {
    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
    return std::nullopt;
  }

  return std::nullopt;
                  Type elemBSourceType,

  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

  if (k == 4) {
    if (elemSourceType.isF32() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
    return std::nullopt;
  }

  if (k == 32) {
    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isBF16())
      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
    return std::nullopt;
  }

  if (k == 64) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
    }
    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
    return std::nullopt;
  }

  if (k == 128) {
    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
    }
    return std::nullopt;
  }

  return std::nullopt;
  bool isGfx950 = chipset >= kGfx950;

  uint32_t m = op.getM(), n = op.getN(), k = op.getK();

  if (m == 16 && n == 16 && k == 32) {
    if (sourceAElem.isF16())
      return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
    if (sourceAElem.isBF16())
      return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
  }

  if (m == 16 && n == 16 && k == 64) {
    if (sourceAElem.isF16())
      return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
    if (sourceAElem.isBF16())
      return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
    if (sourceAElem.isInteger(8))
      return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
  }

  if (m == 16 && n == 16 && k == 128 && isGfx950) {
    if (sourceAElem.isInteger(8))
      return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
  }

  if (m == 32 && n == 32 && k == 16) {
    if (sourceAElem.isF16())
      return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
    if (sourceAElem.isBF16())
      return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
  }

  if (m == 32 && n == 32 && k == 32) {
    if (sourceAElem.isF16())
      return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
    if (sourceAElem.isBF16())
      return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
    if (sourceAElem.isInteger(8))
      return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
  }

  if (m == 32 && n == 32 && k == 64 && isGfx950) {
    if (sourceAElem.isInteger(8))
      return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
    if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
    if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
    if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
    if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32())
      return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
  }

  return std::nullopt;
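// Naming sketch: each smfmac intrinsic encodes MxNxK plus the operand
// formats, e.g. smfmac_f32_16x16x64_fp8_bf8 multiplies an fp8 A against a
// bf8 B, accumulating into f32.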
  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
  Type elemSourceType = sourceVectorType.getElementType();
  Type elemBSourceType = sourceBVectorType.getElementType();
  Type elemDestType = destVectorType.getElementType();

  const uint32_t k = wmma.getK();

  if (isRDNA3 || isRDNA4)
    return wmmaOpToIntrinsic(elemSourceType, elemBSourceType, elemDestType, k,
                             isRDNA3);

  return std::nullopt;
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx942)
        return op.emitOpError("negation unsupported on older than gfx942");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }

    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    bool isScaled =
        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
    if (isScaled &&
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");
    }

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;

    bool allowBf16 = [&]() {

      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");
    }();
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA(),
                                  allowBf16),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB(),
                                  allowBf16),
         adaptor.getDestC()});

    auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;

    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});

    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode)});

    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
  SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType outType =
        typeConverter->convertType<VectorType>(op.getDestC().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 9 || chipset < kGfx942)
      return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+");
    bool isGfx950 = chipset >= kGfx950;

    Value a = convertMFMAVectorOperand(rewriter, loc,
                                       adaptor.getSourceA(), isGfx950);
    Value b = convertMFMAVectorOperand(rewriter, loc,
                                       adaptor.getSourceB(), isGfx950);
    Value c = adaptor.getDestC();

    std::optional<StringRef> maybeIntrinsic = smfmacOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching sparse MFMA on the given chipset");

    Value sparseIdx = LLVM::BitcastOp::create(
        rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());

    OperationState loweredOp(loc, maybeIntrinsic.value());
    loweredOp.addTypes(outType);
    loweredOp.addOperands({a, b, c, sparseIdx,

    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    bool isGFX1250 = chipset >= kGfx1250;

    // Bf16 operands are represented as i16 vectors below gfx1250.
    auto aType = cast<VectorType>(adaptor.getSourceA().getType());
    auto bType = cast<VectorType>(adaptor.getSourceB().getType());
    auto destCType = cast<VectorType>(adaptor.getDestC().getType());
    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
    VectorType rawOutType = outType;
    if (castOutToI16)
      rawOutType = outType.clone(rewriter.getI16Type());
    Value a = adaptor.getSourceA();
    if (castAToI16)
      a = LLVM::BitcastOp::create(rewriter, loc,
                                  aType.clone(rewriter.getI16Type()), a);
    Value b = adaptor.getSourceB();
    if (castBToI16)
      b = LLVM::BitcastOp::create(rewriter, loc,
                                  bType.clone(rewriter.getI16Type()), b);
    Value destC = adaptor.getDestC();
    if (castDestCToI16)
      destC = LLVM::BitcastOp::create(
          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    SmallVector<Value, 4> operands;
    SmallVector<NamedAttribute, 4> attrs;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
                         op.getSourceA(), operands, attrs, "signA");
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
                         op.getSourceB(), operands, attrs, "signB");
    wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
                          op.getSubwordOffset(), op.getClamp(), operands);

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands(operands);
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
                                              lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());
  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset < kGfx1250)
      return op->emitOpError("WMMA scale only supported on gfx1250+");

    int64_t m = op.getM();
    int64_t n = op.getN();
    int64_t k = op.getK();

    if (!aFmtCode || !bFmtCode)
      return op.emitOpError("unsupported element types for scaled_wmma");

    auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
    auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());

    if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
      return op.emitOpError("scaleA and scaleB must have equal vector length");

    Type scaleAElemType = scaleAVecType.getElementType();
    Type scaleBElemType = scaleBVecType.getElementType();

    if (!scaleAFmt || !scaleBFmt)
      return op.emitOpError("unsupported scale element types");

    bool isScale16 = (scaleAVecType.getNumElements() == 8);
    std::optional<StringRef> intrinsicName =
        wmmaScaledOpToIntrinsic(m, n, k, isScale16);
    if (!intrinsicName)
      return op.emitOpError("unsupported scaled_wmma dimensions: ")
             << m << "x" << n << "x" << k;

    SmallVector<NamedAttribute, 8> attrs;

    bool is32x16 = (m == 32 && n == 16 && k == 128);
    if (!is32x16) {
      attrs.emplace_back("fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
      attrs.emplace_back("fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
    }

    attrs.emplace_back("modC", rewriter.getI16IntegerAttr(0));

    attrs.emplace_back(
        "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
    attrs.emplace_back(
        "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
    attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));

    attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
    attrs.emplace_back("reuseB", rewriter.getBoolAttr(false));

    OperationState loweredOp(loc, *intrinsicName);
    loweredOp.addTypes(outType);
    loweredOp.addOperands(
        {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
    loweredOp.addAttributes(attrs);

    Operation *lowered = rewriter.create(loweredOp);
    rewriter.replaceOp(op, lowered->getResults());
struct TransposeLoadOpLowering
    : public ConvertOpToLLVMPattern<TransposeLoadOp> {
  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx950)
      return op.emitOpError("Non-gfx950 chipset not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());

    size_t srcElementSize =
        srcMemRefType.getElementType().getIntOrFloatBitWidth();
    if (srcElementSize < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")
             << srcElementSize;

    auto resultType = cast<VectorType>(op.getResult().getType());
    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));

    size_t numElements = resultType.getNumElements();
    size_t elementTypeSize =
        resultType.getElementType().getIntOrFloatBitWidth();

    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
                                           rewriter.getIntegerType(32));
    Type llvmResultType = typeConverter->convertType(resultType);
1718 Type llvmResultType = typeConverter->convertType(resultType);
1720 switch (elementTypeSize) {
1722 assert(numElements == 16);
1723 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1724 rocdlResultType, srcPtr);
1725 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1729 assert(numElements == 16);
1730 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1731 rocdlResultType, srcPtr);
1732 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1736 assert(numElements == 8);
1737 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1738 rocdlResultType, srcPtr);
1739 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1743 assert(numElements == 4);
1744 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1749 return op.emitOpError(
"Unsupported element size for transpose load");
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    Type transferType = op.getTransferType();
    int loadWidth = [&]() -> int {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return (transferVectorType.getNumElements() *
                transferVectorType.getElementTypeBitWidth()) /
               8;
      }
      return transferType.getIntOrFloatBitWidth() / 8;
    }();

    if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
      return op.emitOpError("chipset unsupported element size");

    if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
      return op.emitOpError("Gather to LDS instructions with 12-byte and "
                            "16-byte load widths are only supported on gfx950");

    Value srcPtr =
        getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
                             (adaptor.getSrcIndices()));
    Value dstPtr =
        getStridedElementPtr(rewriter, loc, dstMemRefType, adaptor.getDst(),
                             (adaptor.getDstIndices()));

    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        rewriter.getI32IntegerAttr(0),
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct ScaledExtPackedMatrixOpLowering final
    : public ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp> {
  ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedMatrixOp op,
                  ScaledExtPackedMatrixOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct ScaledExtPackedOpLowering final
    : public ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp> {
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedScaledTruncOpLowering final
    : public ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp> {
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
                              Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 = getTypeConverter()->convertType(
      VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());

  // Extend to a v4i8 if the source is shorter.
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
  if (resultVecType) {
    Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType))
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
                                                        wordSel);
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
                                                        wordSel);
  } else {
    Value byteSel = createI32Constant(rewriter, loc, op.getIndex());
    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType))
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                      byteSel);
    else
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                      byteSel);
  }
  return success();
}
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
                    int32_t firstScaleByte) {
  assert(llvm::is_contained({16, 32}, blockSize));
  assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));

  const bool isFp8 = bitWidth == 8;
  const bool isBlock16 = blockSize == 16;

  if (isFp8) {
    int32_t bit0 = isBlock16;
    assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
    int32_t bit1 = (firstScaleByte == 2) << 1;
    assert(llvm::is_contained({0, 1}, scaleWaveHalf));
    int32_t bit2 = scaleWaveHalf << 2;
    return bit2 | bit1 | bit0;
  }

  int32_t bit0 = isBlock16;
  assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
  int32_t bits2and1 = firstScaleByte << 1;
  assert(llvm::is_contained({0, 1}, scaleWaveHalf));
  int32_t bit3 = scaleWaveHalf << 3;
  int32_t bits = bit3 | bits2and1 | bit0;
  assert(!llvm::is_contained(
      {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
  return bits;
}
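// Worked example (a sketch): for an fp8 input (bitWidth == 8) with
// blockSize == 32, scaleWaveHalf == 1, and firstScaleByte == 2, the packed
// selector is (1 << 2) | (1 << 1) | 0 = 0b110.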
static std::optional<StringRef>
scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  if (isa<fp4>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
    return std::nullopt;
  }
  if (isa<fp8>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
    return std::nullopt;
  }
  if (isa<bf8>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
    return std::nullopt;
  }
  if (isa<fp6>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
    return std::nullopt;
  }
  if (isa<bf6>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
    return std::nullopt;
  }
  llvm_unreachable("invalid combination of element types for packed conversion "
                   "instructions");
}
LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
    ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc,
        "Scaled fp packed conversion instructions are not available on target "
        "architecture and their emulation is not implemented");

  int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
  int32_t firstScaleByte = op.getFirstScaleByte();
  int32_t blockSize = op.getBlockSize();
  auto sourceType = cast<VectorType>(op.getSource().getType());
  auto srcElemType = cast<FloatType>(sourceType.getElementType());
  unsigned bitWidth = srcElemType.getWidth();

  auto targetType = cast<VectorType>(op.getResult().getType());
  auto destElemType = cast<FloatType>(targetType.getElementType());

  IntegerType i32 = rewriter.getI32Type();
  Value source = adaptor.getSource();
  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
  Type packedType = nullptr;
  if (isa<fp4>(srcElemType)) {
    packedType = i32;
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp8, bf8>(srcElemType)) {
    packedType = VectorType::get(2, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp6, bf6>(srcElemType)) {
    packedType = VectorType::get(3, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else {
    llvm_unreachable("invalid element type for packed scaled ext");
  }

  if (!packedType || !llvmResultType) {
    return rewriter.notifyMatchFailure(op, "type conversion failed");
  }

  std::optional<StringRef> maybeIntrinsic =
      scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
  if (!maybeIntrinsic.has_value())
    return op.emitOpError(
        "no intrinsic matching packed scaled conversion on the given chipset");

  int32_t scaleSel =
      getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
  Value castedScale =
      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
  Value castedSource =
      LLVM::BitcastOp::create(rewriter, loc, packedType, source);

  OperationState loweredOp(loc, *maybeIntrinsic);
  loweredOp.addTypes({llvmResultType});
  loweredOp.addOperands({castedSource, castedScale});

  SmallVector<NamedAttribute, 1> attrs;
  attrs.push_back(
      NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
  loweredOp.addAttributes(attrs);
  Operation *lowered = rewriter.create(loweredOp);
  rewriter.replaceOp(op, lowered);
  return success();
}
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
    ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();

  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();
  VectorType destVecType = cast<VectorType>(op.getResult().getType());
  Type destElemType = destVecType.getElementType();

  VectorType packedVecType;
  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
    VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
    VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
  } else {
    llvm_unreachable("invalid element type for scaled ext");
  }

  // Extend to a packedVecType if the source is shorter.
  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
    Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
          rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
        longVec =
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);

  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset != kGfx950)
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);
  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
  else
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
  if (sourceVecType.getNumElements() < 2) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    VectorType v2 = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, v2);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
  }

  Value sourceA, sourceB;
  if (sourceElemType.isF32()) {
    Value c0 = createI32Constant(rewriter, loc, 0);
    Value c1 = createI32Constant(rewriter, loc, 1);
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
  }

  Value result;
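  // The f32-sourced intrinsics take the two lanes as separate scalar
  // operands; the f16/bf16-sourced ones consume the packed vector directly.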
  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else
    return failure();

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (!(chipset == kGfx942 || hasOcpFp8(chipset)))
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
  else
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

  Value result;
  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());
  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
    result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));
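    // Sub-32-bit operands are widened for the DPP intrinsic: bitcast floats
    // to the same-width integer, insert into lane 0 of a 32-bit-wide vector,
    // then bitcast the whole vector to the 32-bit llvmType.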
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
        operand =
            LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
                                          createI32Constant(rewriter, loc, 0));
        operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);
    // Control-word encodings, taken from llvm/lib/Target/AMDGPU/SIDefines.h.
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };
    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;
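    // quad_perm packs four 2-bit lane selectors into the low byte of the DPP
    // control word; for example [0, 1, 2, 3] encodes the identity shuffle.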
    switch (kind) {
    case DPPPerm::quad_perm: {
      auto quadPermAttr = cast<ArrayAttr>(*permArgument);
      int32_t i = 0;
      for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
        uint32_t num = elem.getInt();
        DppCtrl |= num << (i * 2);
        i++;
      }
      break;
    }
    case DPPPerm::row_shl: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      break;
    }
    case DPPPerm::row_shr: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      break;
    }
    case DPPPerm::row_ror: {
      auto intAttr = cast<IntegerAttr>(*permArgument);
      DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      break;
    }
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    auto dppMovOp =
        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
                                   rowMask, bankMask, boundCtrl);
    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
      }
    }

    rewriter.replaceOp(DppOp, result);
    return success();
  }
};
struct AMDGPUSwizzleBitModeLowering
    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);
    unsigned andMask = op.getAndMask();
    unsigned orMask = op.getOrMask();
    unsigned xorMask = op.getXorMask();
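    // ds_swizzle bit-mode layout: bits [4:0] are the AND mask, [9:5] the OR
    // mask, [14:10] the XOR mask; bit 15 = 0 selects bit-mode. Each lane then
    // reads from lane ((id & and_mask) | or_mask) ^ xor_mask.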
    unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
    Value maskValue = createI32Constant(rewriter, loc, mask);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res =
          ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
      swizzled.emplace_back(res);
    }

    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};
struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    unsigned rowLength = op.getRowLength();
    bool fi = op.getFetchInactive();
    bool boundctrl = op.getBoundCtrl();

    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);

    SmallVector<Value> permuted;
    for (Value v : decomposed) {
      Value res;
      Type i32pair = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v.getType(), v.getType()});

      if (rowLength == 16)
        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else if (rowLength == 32)
        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else
        llvm_unreachable("unsupported row length");
      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});

      Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
                                           LLVM::ICmpPredicate::eq, vdst0, v);
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};
static Value setValueAtOffset(ConversionPatternRewriter &rewriter,
                              Location loc, Value accumulator, Value value,
                              int64_t shift) {
  if (matchPattern(value, m_Zero()))
    return accumulator;

  // Offsets are descriptor-relative bit positions; reduce to the position
  // within the 32-bit word being assembled.
  shift %= 32;
  if (shift != 0) {
    Value shiftAmount = createI32Constant(rewriter, loc, shift);
    value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
  }
  constexpr bool isDisjoint = true;
  return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
}
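// For example, inserting the value 3 at descriptor bit offset 50 emits
// shl(3, 50 % 32) followed by a disjoint-or into the accumulator word, on the
// invariant that the target bits of the accumulator are still zero.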
template <typename BaseOp>
struct AMDGPUMakeDmaBaseLowering : public ConvertOpToLLVMPattern<BaseOp> {
  using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
  using Adaptor = typename BaseOp::Adaptor;

  AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter,
                            Chipset chipset)
      : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(BaseOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx1250)
      return op->emitOpError("make_dma_base is only supported on gfx1250");

    Location loc = op.getLoc();

    constexpr int32_t constlen = 4;
    Value consts[constlen];
    for (int64_t i = 0; i < constlen; ++i)
      consts[i] = createI32Constant(rewriter, loc, i);

    constexpr int32_t sgprslen = constlen;
    Value sgprs[sgprslen];
    for (int64_t i = 0; i < sgprslen; ++i) {
      sgprs[i] = consts[0];
    }

    sgprs[0] = consts[1];
    if constexpr (BaseOp::isGather()) {
      sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);

      auto type = cast<TDMGatherBaseType>(op.getResult().getType());
      Type indexType = type.getIndexType();
      unsigned indexSize = indexType.getIntOrFloatBitWidth();
      assert(llvm::is_contained({16u, 32u}, indexSize) &&
             "expected index_size to be 16 or 32");
      unsigned idx = (indexSize / 16) - 1;
      if (idx)
        sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
    }
    ValueRange ldsIndices = adaptor.getLdsIndices();
    Value lds = adaptor.getLds();
    auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
    Value ldsPtr = this->getStridedElementPtr(
        rewriter, loc, ldsMemRefType, lds, ldsIndices);

    ValueRange globalIndices = adaptor.getGlobalIndices();
    Value global = adaptor.getGlobal();
    auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
    Value globalPtr = this->getStridedElementPtr(
        rewriter, loc, globalMemRefType, global, globalIndices);

    Type i32 = rewriter.getI32Type();
    Type i64 = rewriter.getI64Type();

    sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
    Value castForGlobalAddr =
        LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
    sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);

    Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
                                       createI64Constant(rewriter, loc, 32));
    Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
    // Bits 30-31 of the last word hold the type field inserted just below.
    Value mask = createI32Constant(rewriter, loc, (1 << 30) - 1);
    highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
    sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed");
    Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
      result =
          LLVM::InsertElementOp::create(rewriter, loc, result, sgpr, constant);

    rewriter.replaceOp(op, result);
    return success();
  }
};
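
// Lowers amdgpu.make_dma_descriptor (and its gather variant) into the four
// descriptor groups consumed by the gfx1250 tensor-DMA instructions: dgroup0
// is the v4i32 base built above; dgroup1 (v8i32) packs the control flags,
// tensor/tile dimensions, and the 48-bit global strides; dgroup2 and dgroup3
// (v4i32 each) hold the higher dimensions, increments, and, for gathers, the
// packed index lanes. Bit offsets below are descriptor-relative and reduced
// modulo 32 by setValueAtOffset within each word.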
template <typename DescriptorOp>
struct AMDGPULowerDescriptor : public ConvertOpToLLVMPattern<DescriptorOp> {
  using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
  using OpAdaptor = typename DescriptorOp::Adaptor;

  AMDGPULowerDescriptor(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
  Chipset chipset;

  Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }

  Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0) const {
    Value mask = op.getWorkgroupMask();
    if (!mask)
      return sgpr0;

    Type i16 = rewriter.getI16Type();
    mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
    Type i32 = rewriter.getI32Type();
    Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
    return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
  }
  Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr0, ArrayRef<Value> consts) const {
    unsigned elementTypeWidthInBits = op.getElementTypeWidth();
    assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
           "expected type width to be 8, 16, 32, or 64.");
    int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
    Value size = consts[idx];
    return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
  }
  Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
    if (!adaptor.getAtomicBarrierAddress())
      return sgpr0;
    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
  }

  Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
    if (!adaptor.getGlobalIncrement())
      return sgpr0;
    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
  }

  Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;
    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
  }

  Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getWorkgroupMask())
      return sgpr0;
    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
  }
  Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    IntegerType i32 = rewriter.getI32Type();
    Value padInterval = adaptor.getPadInterval();
    padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
                                                     padInterval, false);
    padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
    return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
  }

  Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getPadAmount())
      return sgpr0;

    Value padAmount = adaptor.getPadAmount();
    padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
    return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
  }
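  // Note: both pad fields above are stored biased by one; the interval is
  // first reduced to its log2 via cttz, which effectively assumes a
  // power-of-two pad interval.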
  Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Location loc, Value sgpr1,
                                ArrayRef<Value> consts) const {
    if (!adaptor.getAtomicBarrierAddress())
      return sgpr1;

    Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
    auto barrierAddressTy =
        cast<MemRefType>(op.getAtomicBarrierAddress().getType());
    ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
    atomicBarrierAddress = this->getStridedElementPtr(
        rewriter, loc, barrierAddressTy, atomicBarrierAddress,
        atomicBarrierIndices);
    IntegerType i32 = rewriter.getI32Type();

    atomicBarrierAddress =
        LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
    atomicBarrierAddress =
        LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
    // The field is 16 bits wide (descriptor bits 32..47, below dim 0).
    Value mask = createI32Constant(rewriter, loc, (1 << 16) - 1);
    atomicBarrierAddress =
        LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
    return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
  }
  std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts, uint64_t dimX,
                                        uint32_t offset) const {
    ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector<OpFoldResult> mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= dimX)
      return {sgpr1, sgpr2};

    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    Value tensorDimX;
    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
      tensorDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast<Value>(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);

    Value c16 = createI32Constant(rewriter, loc, 16);
    Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
    sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
    return {sgpr1, sgpr2};
  }
  std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
                         48);
  }

  std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr2, Value sgpr3,
                                        ArrayRef<Value> consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
                         80);
  }
  Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr, ArrayRef<Value> consts, size_t dimX,
                    int64_t offset) const {
    ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
    ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
    SmallVector<OpFoldResult> mixedSharedSizes =
        getMixedValues(sharedStaticSizes, sharedDynamicSizes, rewriter);
    if (mixedSharedSizes.size() <= dimX)
      return sgpr;

    OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
    Value tileDimX;
    if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
      tileDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tileDimX = cast<Value>(tileDimXOpFoldResult);
      tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
    }

    return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
  }
  Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr3, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
  }

  Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
  }
  Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr4, ArrayRef<Value> consts) const {
    auto type = cast<VectorType>(op.getIndices().getType());
    ArrayRef<int64_t> shape = type.getShape();
    assert(shape.size() == 1 && "expected shape to be of rank 1.");
    unsigned length = shape.back();
    assert(0 < length && length <= 16 && "expected length to be at most 16.");
    Value value = createI32Constant(rewriter, loc, length);
    return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
  }
  Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value sgpr4,
                                  ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
    return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
  }

  Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return sgpr4;
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
  }
  std::pair<Value, Value>
  setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgprY, Value sgprZ, ArrayRef<Value> consts,
                      size_t dimX, int64_t offset) const {
    ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
    ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
    SmallVector<OpFoldResult> mixedGlobalStrides =
        getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);

    if (mixedGlobalStrides.size() <= dimX)
      return {sgprY, sgprZ};

    OpFoldResult tensorDimXStrideOpFoldResult =
        *(mixedGlobalStrides.rbegin() + dimX);
    Value tensorDimXStride;
    if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
      tensorDimXStride =
          createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    else
      tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);

    constexpr int64_t first48bits = (1ll << 48) - 1;
    Value mask = createI64Constant(rewriter, loc, first48bits);
    tensorDimXStride =
        LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
    IntegerType i32 = rewriter.getI32Type();
    Value tensorDimXStrideLow =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
    sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);

    int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
    Value shiftVal = createI64Constant(rewriter, loc, shift);
    Value tensorDimXStrideHigh =
        LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
    tensorDimXStrideHigh =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
    sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
                             offset + shift);
    return {sgprY, sgprZ};
  }
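  // Strides are 48-bit fields packed at 16-bit granularity: the low bits land
  // at `offset` (setValueAtOffset discards anything shifted past the word
  // boundary) and the remaining high bits land in the next word at
  // `offset + shift`.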
  std::pair<Value, Value>
  setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6,
                               consts, 0, 160);
  }

  std::pair<Value, Value>
  setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return {sgpr5, sgpr6};
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6,
                               consts, 1, 208);
  }
  Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    Value sgprs[8];
    for (int64_t i = 0; i < 8; ++i) {
      sgprs[i] = consts[0];
    }

    sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
    sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);

    sgprs[1] =
        setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    std::tie(sgprs[2], sgprs[3]) =
        setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);

    sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
    sgprs[4] =
        setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
    sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
    std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
        op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
    std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
        op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);

    IntegerType i32 = rewriter.getI32Type();
    Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
    assert(v8i32 && "expected type conversion to succeed");
    Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
      dgroup1 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
    }
    return dgroup1;
  }
  Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
                      int64_t offset) const {
    ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
    ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
    SmallVector<OpFoldResult> mixedGlobalSizes =
        getMixedValues(globalStaticSizes, globalDynamicSizes, rewriter);
    if (mixedGlobalSizes.size() <= static_cast<unsigned long>(dimX))
      return sgpr0;

    OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
    Value tensorDimX;
    if (auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
      tensorDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    } else {
      IntegerType i32 = rewriter.getI32Type();
      tensorDimX = cast<Value>(tensorDimXOpFoldResult);
      tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
    }

    return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
  }

  Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, ArrayRef<Value> consts) const {
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
  }
  Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
                                    Location loc, Value accumulator,
                                    Value value, int64_t shift) const {
    IntegerType i32 = rewriter.getI32Type();
    value = LLVM::TruncOp::create(rewriter, loc, i32, value);
    return setValueAtOffset(rewriter, loc, accumulator, value, shift);
  }

  Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            Value sgpr1, ArrayRef<Value> consts,
                            int64_t offset) const {
    Value ldsAddrIncrement = adaptor.getLdsIncrement();
    return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
  }
  std::pair<Value, Value>
  setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
                         int64_t offset) const {
    Value globalAddrIncrement = adaptor.getGlobalIncrement();
    sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
                                        globalAddrIncrement, offset);

    Value shift = createI64Constant(rewriter, loc, 32);
    globalAddrIncrement =
        LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
    constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
    sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
                                        globalAddrIncrement, offset + 32);
    Value mask = createI32Constant(rewriter, loc, first16BitsHigh);
    sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
    return {sgpr2, sgpr3};
  }
  Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1,
                                        ArrayRef<Value> consts) const {
    Value ldsIncrement = op.getLdsIncrement();
    constexpr int64_t dim = 3;
    constexpr int64_t offset = 32;
    if (!ldsIncrement)
      return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
                           offset);
    return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
                               offset);
  }

  std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
      DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
      Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts) const {
    Value globalIncrement = op.getGlobalIncrement();
    constexpr int32_t dim = 2;
    constexpr int32_t offset = 64;
    if (!globalIncrement)
      return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
                                 consts, dim, offset);
    return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
                                  consts, offset);
  }
  Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr3, ArrayRef<Value> consts,
                        int32_t offset) const {
    Value iterationCount = adaptor.getIterationCount();
    IntegerType i32 = rewriter.getI32Type();
    iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
    iterationCount =
        LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
    return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
  }

  Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter,
                                  Location loc, Value sgpr3,
                                  ArrayRef<Value> consts) const {
    Value iterateCount = op.getIterationCount();
    constexpr int32_t dim = 2;
    constexpr int32_t offset = 112;
    if (!iterateCount)
      return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
                         offset);
    return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
  }
  Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
    return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
  }

  Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            ArrayRef<Value> consts) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
    if (onlyNeedsTwoDescriptors)
      return LLVM::ZeroOp::create(rewriter, loc, v4i32);

    constexpr int64_t sgprlen = 4;
    Value sgprs[sgprlen];
    for (int i = 0; i < sgprlen; ++i)
      sgprs[i] = consts[0];

    sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
                                               sgprs[1], consts);
    std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
        op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
    sgprs[3] =
        setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);

    Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
      dgroup2 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
    return dgroup2;
  }
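  // Gather descriptors carry the per-lane indices themselves in dgroup2 and
  // dgroup3: up to four i32 indices per group, or up to eight sub-32-bit
  // indices zero-extended and packed two per word.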
  Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts, bool firstHalf) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");

    Value indices = adaptor.getIndices();
    auto vectorType = cast<VectorType>(indices.getType());
    unsigned length = vectorType.getShape().back();
    Type elementType = vectorType.getElementType();
    unsigned maxLength = elementType == i32 ? 4 : 8;
    int32_t offset = firstHalf ? 0 : maxLength;
    unsigned discountedLength =
        std::max(static_cast<int32_t>(length - offset), 0);
    unsigned targetSize = std::min(maxLength, discountedLength);

    SmallVector<Value> indicesVector;
    for (unsigned i = offset; i < targetSize + offset; ++i) {
      Value idx;
      if (i < consts.size())
        idx = consts[i];
      else
        idx = createI32Constant(rewriter, loc, i);
      Value elem = LLVM::ExtractElementOp::create(rewriter, loc, indices, idx);
      indicesVector.push_back(elem);
    }

    SmallVector<Value> indicesI32Vector;
    if (elementType == i32) {
      indicesI32Vector = indicesVector;
    } else {
      for (unsigned i = 0; i < targetSize; ++i) {
        Value index = indicesVector[i];
        indicesI32Vector.push_back(
            LLVM::ZExtOp::create(rewriter, loc, i32, index));
      }
      if ((targetSize % 2) != 0)
        indicesI32Vector.push_back(consts[0]);
    }

    SmallVector<Value> indicesToInsert;
    if (elementType == i32) {
      indicesToInsert = indicesI32Vector;
    } else {
      unsigned size = indicesI32Vector.size() / 2;
      for (unsigned i = 0; i < size; ++i) {
        Value first = indicesI32Vector[2 * i];
        Value second = indicesI32Vector[2 * i + 1];
        Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
        indicesToInsert.push_back(joined);
      }
    }

    Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
      dgroup =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
    return dgroup;
  }
  Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts) const {
    return getGatherIndices(op, adaptor, rewriter, loc, consts,
                            /*firstHalf=*/true);
  }

  std::pair<Value, Value>
  setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr0, Value sgpr1, ArrayRef<Value> consts) const {
    constexpr int32_t dim = 3;
    constexpr int32_t offset = 0;
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1,
                               consts, dim, offset);
  }

  std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts) const {
    constexpr int32_t dim = 4;
    constexpr int32_t offset = 48;
    return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
                         offset);
  }

  Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr2, ArrayRef<Value> consts) const {
    constexpr int32_t dim = 4;
    constexpr int32_t offset = 80;
    return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
  }
  Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    if constexpr (DescriptorOp::isGather())
      return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
    return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
  }

  Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter, Location loc,
                            ArrayRef<Value> consts) const {
    IntegerType i32 = rewriter.getI32Type();
    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed.");
    bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
    if (onlyNeedsTwoDescriptors)
      return LLVM::ZeroOp::create(rewriter, loc, v4i32);

    constexpr int32_t sgprlen = 4;
    Value sgprs[sgprlen];
    for (int i = 0; i < sgprlen; ++i)
      sgprs[i] = consts[0];

    std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
        op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);

    Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    for (auto [sgpr, constant] : llvm::zip(sgprs, consts))
      dgroup3 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
    return dgroup3;
  }

  Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         ArrayRef<Value> consts) const {
    return getGatherIndices(op, adaptor, rewriter, loc, consts,
                            /*firstHalf=*/false);
  }
  LogicalResult
  matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx1250)
      return op->emitOpError(
          "make_dma_descriptor is only supported on gfx1250");

    Location loc = op.getLoc();

    SmallVector<Value> consts;
    for (int64_t i = 0; i < 8; ++i)
      consts.push_back(createI32Constant(rewriter, loc, i));

    Value dgroup0 = this->getDGroup0(adaptor);
    Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
    Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
    Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
    SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
    rewriter.replaceOpWithMultiple(op, {results});
    return success();
  }
};
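
// The 1->4 replacement above matches the TDMDescriptorType conversion
// registered below (v4i32, v8i32, v4i32, v4i32), and is what the tensor
// load/store lowering consumes through its one-to-many adaptor.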
template <typename SourceOp, typename TargetOp>
struct AMDGPUTensorLoadStoreOpLowering
    : public ConvertOpToLLVMPattern<SourceOp> {
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  using Adaptor = typename ConvertOpToLLVMPattern<SourceOp>::OneToNOpAdaptor;

  AMDGPUTensorLoadStoreOpLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(SourceOp op, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx1250)
      return op->emitOpError("is only supported on gfx1250");

    ValueRange desc = adaptor.getDesc();
    rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
                                          desc[3]);
    return success();
  }
};
struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
void mlir::populateCommonGPUTypeAndAttributeConversions(
    TypeConverter &typeConverter) {
  populateGpuMemorySpaceAttributeConversions(
      typeConverter, [](gpu::AddressSpace space) {
        switch (space) {
        case gpu::AddressSpace::Global:
          return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
        case gpu::AddressSpace::Workgroup:
          return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
        case gpu::AddressSpace::Private:
          return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
        }
        llvm_unreachable("unknown address space enum value");
        return 0u;
      });
}

void mlir::populateAMDGPUTypeAndAttributeConversions(
    TypeConverter &typeConverter) {
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
          -> TypeConverter::AttributeConversionResult {
        MLIRContext *ctx = as.getContext();
        Type i64 = IntegerType::get(ctx, 64);
        switch (as.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          return IntegerAttr::get(i64, 7);
        case amdgpu::AddressSpace::BufferRsrc:
          return IntegerAttr::get(i64, 8);
        case amdgpu::AddressSpace::FatStructuredBuffer:
          return IntegerAttr::get(i64, 9);
        }
        return TypeConverter::AttributeConversionResult::abort();
      });
  typeConverter.addConversion([&](TDMBaseType type) -> Type {
    Type i32 = IntegerType::get(type.getContext(), 32);
    return typeConverter.convertType(VectorType::get(4, i32));
  });
  typeConverter.addConversion([&](TDMGatherBaseType type) -> Type {
    Type i32 = IntegerType::get(type.getContext(), 32);
    return typeConverter.convertType(VectorType::get(4, i32));
  });
  typeConverter.addConversion(
      [&](TDMDescriptorType type, SmallVectorImpl<Type> &result)
          -> std::optional<LogicalResult> {
        Type i32 = IntegerType::get(type.getContext(), 32);
        Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
        Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
        llvm::append_values(result, v4i32, v8i32, v4i32, v4i32);
        return success();
      });

  auto addUnrealizedCast = [](OpBuilder &builder, TypeRange types,
                              ValueRange inputs,
                              Location loc) -> SmallVector<Value> {
    if (inputs.size() != 1)
      return {};
    if (!isa<TDMDescriptorType>(inputs[0].getType()))
      return {};

    auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
    return {cast.getResults().begin(), cast.getResults().end()};
  };
  typeConverter.addTargetMaterialization(addUnrealizedCast);
}
void mlir::populateAMDGPUToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    amdgpu::Chipset chipset) {
  populateAMDGPUTypeAndAttributeConversions(converter);
  patterns
      .add<FatRawBufferCastLowering,
           RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
           SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
           SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
           ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
           ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
           GatherToLDSOpLowering, TransposeLoadOpLowering,
           AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
           AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
           AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
           AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
           AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
                                           ROCDL::TensorLoadToLDSOp>,
           AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
                                           ROCDL::TensorStoreFromLDSOp>>(
          converter, chipset);
  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}