28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/Support/Casting.h"
31#include "llvm/Support/ErrorHandling.h"
35#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
36#include "mlir/Conversion/Passes.h.inc"
52 IntegerType i32 = rewriter.getI32Type();
54 auto valTy = cast<IntegerType>(val.
getType());
57 return valTy.getWidth() > 32
58 ?
Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
59 :
Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
64 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
70 IntegerType i64 = rewriter.getI64Type();
72 auto valTy = cast<IntegerType>(val.
getType());
75 return valTy.getWidth() > 64
76 ?
Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
77 :
Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
82 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
89 IntegerType i32 = rewriter.getI32Type();
91 for (
auto [i, increment, stride] : llvm::enumerate(
indices, strides)) {
94 ShapedType::isDynamic(stride)
96 memRefDescriptor.
stride(rewriter, loc, i))
97 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
98 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
110 MemRefType memrefType,
114 if (memrefType.hasStaticShape() &&
115 !llvm::any_of(strides, ShapedType::isDynamic)) {
116 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
118 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
119 size = std::max(
shape[i] * strides[i], size);
120 size = size * elementByteWidth;
124 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
125 Value size = memrefDescriptor.
size(rewriter, loc, i);
126 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
127 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
129 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
134 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
140 Value cacheSwizzleStride =
nullptr,
141 unsigned addressSpace = 8) {
145 Type i16 = rewriter.getI16Type();
148 Value cacheStrideZext =
149 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
150 Value swizzleBit = LLVM::ConstantOp::create(
151 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
152 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
155 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
156 rewriter.getI16IntegerAttr(0));
173 uint32_t flags = (7 << 12) | (4 << 15);
176 uint32_t oob = boundsCheck ? 3 : 2;
177 flags |= (oob << 28);
181 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
182 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
183 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
188struct FatRawBufferCastLowering
190 FatRawBufferCastLowering(
const LLVMTypeConverter &converter, Chipset chipset)
191 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
197 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
198 ConversionPatternRewriter &rewriter)
const override {
199 Location loc = op.getLoc();
200 Value memRef = adaptor.getSource();
201 Value unconvertedMemref = op.getSource();
202 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
203 MemRefDescriptor descriptor(memRef);
205 DataLayout dataLayout = DataLayout::closest(op);
206 int64_t elementByteWidth =
209 int64_t unusedOffset = 0;
210 SmallVector<int64_t, 5> strideVals;
211 if (
failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
212 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
214 Value numRecords = adaptor.getValidBytes();
216 numRecords =
getNumRecords(rewriter, loc, memrefType, descriptor,
217 strideVals, elementByteWidth);
220 adaptor.getResetOffset()
221 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
223 : descriptor.alignedPtr(rewriter, loc);
225 Value offset = adaptor.getResetOffset()
226 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
227 rewriter.getIndexAttr(0))
228 : descriptor.offset(rewriter, loc);
230 bool hasSizes = memrefType.getRank() > 0;
233 Value sizes = hasSizes
234 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
238 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
243 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
244 chipset, adaptor.getCacheSwizzleStride(), 7);
246 Value
result = MemRefDescriptor::poison(
248 getTypeConverter()->convertType(op.getResult().getType()));
250 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr, pos);
251 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr,
253 result = LLVM::InsertValueOp::create(rewriter, loc,
result, offset,
256 result = LLVM::InsertValueOp::create(rewriter, loc,
result, sizes,
258 result = LLVM::InsertValueOp::create(rewriter, loc,
result, strides,
261 rewriter.replaceOp(op,
result);
267template <
typename GpuOp,
typename Intrinsic>
269 RawBufferOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
270 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
273 static constexpr uint32_t maxVectorOpWidth = 128;
276 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
277 ConversionPatternRewriter &rewriter)
const override {
278 Location loc = gpuOp.getLoc();
279 Value memref = adaptor.getMemref();
280 Value unconvertedMemref = gpuOp.getMemref();
281 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
283 if (chipset.majorVersion < 9)
284 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
286 Value storeData = adaptor.getODSOperands(0)[0];
287 if (storeData == memref)
291 wantedDataType = storeData.
getType();
293 wantedDataType = gpuOp.getODSResults(0)[0].getType();
295 Value atomicCmpData = Value();
298 Value maybeCmpData = adaptor.getODSOperands(1)[0];
299 if (maybeCmpData != memref)
300 atomicCmpData = maybeCmpData;
303 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
305 Type i32 = rewriter.getI32Type();
308 DataLayout dataLayout = DataLayout::closest(gpuOp);
309 int64_t elementByteWidth =
318 Type llvmBufferValType = llvmWantedDataType;
320 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
321 llvmBufferValType = this->getTypeConverter()->convertType(
322 rewriter.getIntegerType(floatType.getWidth()));
324 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
325 uint32_t vecLen = dataVector.getNumElements();
328 uint32_t totalBits = elemBits * vecLen;
330 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
331 if (totalBits > maxVectorOpWidth)
332 return gpuOp.emitOpError(
333 "Total width of loads or stores must be no more than " +
334 Twine(maxVectorOpWidth) +
" bits, but we call for " +
336 " bits. This should've been caught in validation");
337 if (!usePackedFp16 && elemBits < 32) {
338 if (totalBits > 32) {
339 if (totalBits % 32 != 0)
340 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
341 "doesn't fit into words. Can't happen\n");
342 llvmBufferValType = this->typeConverter->convertType(
343 VectorType::get(totalBits / 32, i32));
345 llvmBufferValType = this->typeConverter->convertType(
346 rewriter.getIntegerType(totalBits));
350 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
353 if (vecType.getNumElements() == 1)
354 llvmBufferValType = vecType.getElementType();
357 SmallVector<Value, 6> args;
359 if (llvmBufferValType != llvmWantedDataType) {
360 Value castForStore = LLVM::BitcastOp::create(
361 rewriter, loc, llvmBufferValType, storeData);
362 args.push_back(castForStore);
364 args.push_back(storeData);
369 if (llvmBufferValType != llvmWantedDataType) {
370 Value castForCmp = LLVM::BitcastOp::create(
371 rewriter, loc, llvmBufferValType, atomicCmpData);
372 args.push_back(castForCmp);
374 args.push_back(atomicCmpData);
380 SmallVector<int64_t, 5> strides;
381 if (
failed(memrefType.getStridesAndOffset(strides, offset)))
382 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
384 MemRefDescriptor memrefDescriptor(memref);
386 Value ptr = memrefDescriptor.bufferPtr(
387 rewriter, loc, *this->getTypeConverter(), memrefType);
389 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
391 adaptor.getBoundsCheck(), chipset);
392 args.push_back(resource);
396 adaptor.getIndices(), strides);
397 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
398 indexOffset && *indexOffset > 0) {
400 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
404 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
405 args.push_back(voffset);
408 Value sgprOffset = adaptor.getSgprOffset();
411 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
412 args.push_back(sgprOffset);
419 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
421 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
422 ArrayRef<NamedAttribute>());
425 if (llvmBufferValType != llvmWantedDataType) {
426 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
431 rewriter.eraseOp(gpuOp);
448static FailureOr<unsigned> encodeWaitcnt(
Chipset chipset,
unsigned vmcnt,
449 unsigned expcnt,
unsigned lgkmcnt) {
451 vmcnt = std::min(15u, vmcnt);
452 expcnt = std::min(7u, expcnt);
453 lgkmcnt = std::min(15u, lgkmcnt);
454 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
457 vmcnt = std::min(63u, vmcnt);
458 expcnt = std::min(7u, expcnt);
459 lgkmcnt = std::min(15u, lgkmcnt);
460 unsigned lowBits = vmcnt & 0xF;
461 unsigned highBits = (vmcnt >> 4) << 14;
462 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
463 return lowBits | highBits | otherCnts;
466 vmcnt = std::min(63u, vmcnt);
467 expcnt = std::min(7u, expcnt);
468 lgkmcnt = std::min(63u, lgkmcnt);
469 unsigned lowBits = vmcnt & 0xF;
470 unsigned highBits = (vmcnt >> 4) << 14;
471 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
472 return lowBits | highBits | otherCnts;
475 vmcnt = std::min(63u, vmcnt);
476 expcnt = std::min(7u, expcnt);
477 lgkmcnt = std::min(63u, lgkmcnt);
478 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
483struct MemoryCounterWaitOpLowering
493 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
494 ConversionPatternRewriter &rewriter)
const override {
495 if (
chipset.majorVersion >= 12) {
497 if (std::optional<int> ds = adaptor.getDs())
498 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
500 if (std::optional<int>
load = adaptor.getLoad())
501 ROCDL::WaitLoadcntOp::create(rewriter, loc, *
load);
503 if (std::optional<int> store = adaptor.getStore())
504 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
506 if (std::optional<int> exp = adaptor.getExp())
507 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
509 rewriter.eraseOp(op);
513 auto getVal = [](
Attribute attr) ->
unsigned {
515 return cast<IntegerAttr>(attr).getInt();
520 unsigned ds = getVal(adaptor.getDsAttr());
521 unsigned exp = getVal(adaptor.getExpAttr());
523 unsigned vmcnt = 1024;
525 Attribute store = adaptor.getStoreAttr();
527 vmcnt = getVal(
load) + getVal(store);
529 vmcnt = getVal(
load);
531 vmcnt = getVal(store);
534 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
536 return op.emitOpError(
"unsupported chipset");
538 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
544 LDSBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
545 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
550 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
551 ConversionPatternRewriter &rewriter)
const override {
552 Location loc = op.getLoc();
555 bool requiresInlineAsm = chipset <
kGfx90a;
558 rewriter.getAttr<LLVM::MMRATagAttr>(
"amdgpu-synchronize-as",
"local");
567 StringRef scope =
"workgroup";
569 auto relFence = LLVM::FenceOp::create(rewriter, loc,
570 LLVM::AtomicOrdering::release, scope);
571 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
572 if (requiresInlineAsm) {
573 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
574 LLVM::AsmDialect::AD_ATT);
575 const char *asmStr =
";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
576 const char *constraints =
"";
577 LLVM::InlineAsmOp::create(
580 asmStr, constraints,
true,
581 false, LLVM::TailCallKind::None,
584 }
else if (chipset.majorVersion < 12) {
585 ROCDL::SBarrierOp::create(rewriter, loc);
587 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
588 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
591 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
592 LLVM::AtomicOrdering::acquire, scope);
593 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
594 rewriter.replaceOp(op, acqFence);
600 SchedBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
601 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
606 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
607 ConversionPatternRewriter &rewriter)
const override {
608 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
609 (uint32_t)op.getOpts());
633 bool allowBf16 =
true) {
635 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
636 if (vectorType.getElementType().isBF16() && !allowBf16)
637 return LLVM::BitcastOp::create(
638 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
639 if (vectorType.getElementType().isInteger(8) &&
640 vectorType.getNumElements() <= 8)
641 return LLVM::BitcastOp::create(
643 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
644 if (isa<IntegerType>(vectorType.getElementType()) &&
645 vectorType.getElementTypeBitWidth() <= 8) {
646 int64_t numWords = llvm::divideCeil(
647 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
649 return LLVM::BitcastOp::create(
650 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
670 Type outputType = rewriter.getI32Type();
671 if (
auto intType = dyn_cast<IntegerType>(inputType))
672 return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
673 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
685 ConversionPatternRewriter &rewriter,
Location loc,
690 auto vectorType = dyn_cast<VectorType>(inputType);
692 operands.push_back(llvmInput);
695 Type elemType = vectorType.getElementType();
697 operands.push_back(llvmInput);
704 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
705 bool isInputInteger = mlirInputType.getElementType().isInteger();
706 if (isInputInteger) {
708 bool localIsUnsigned = isUnsigned;
710 localIsUnsigned =
true;
712 localIsUnsigned =
false;
715 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
720 Type i32 = rewriter.getI32Type();
721 Type intrinsicInType = numBits <= 32
722 ? (
Type)rewriter.getIntegerType(numBits)
723 : (
Type)VectorType::get(numBits / 32, i32);
724 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
725 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
726 loc, llvmIntrinsicInType, llvmInput);
731 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
732 operands.push_back(castInput);
745 Value output, int32_t subwordOffset,
749 auto vectorType = dyn_cast<VectorType>(inputType);
750 Type elemType = vectorType.getElementType();
751 operands.push_back(output);
763 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
764 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
770 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
771 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
779 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
780 b = mfma.getBlocks();
785 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
786 if (m == 32 && n == 32 && k == 4 &&
b == 1)
787 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
788 if (m == 16 && n == 16 && k == 8 &&
b == 1)
789 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
791 if (m == 32 && n == 32 && k == 1 &&
b == 2)
792 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
793 if (m == 16 && n == 16 && k == 1 &&
b == 4)
794 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
795 if (m == 4 && n == 4 && k == 1 &&
b == 16)
796 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
797 if (m == 32 && n == 32 && k == 2 &&
b == 1)
798 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
799 if (m == 16 && n == 16 && k == 4 &&
b == 1)
800 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
805 if (m == 32 && n == 32 && k == 16 &&
b == 1)
806 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
807 if (m == 16 && n == 16 && k == 32 &&
b == 1)
808 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
810 if (m == 32 && n == 32 && k == 4 &&
b == 2)
811 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
812 if (m == 16 && n == 16 && k == 4 &&
b == 4)
813 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
814 if (m == 4 && n == 4 && k == 4 &&
b == 16)
815 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
816 if (m == 32 && n == 32 && k == 8 &&
b == 1)
817 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
818 if (m == 16 && n == 16 && k == 16 &&
b == 1)
819 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
824 if (m == 32 && n == 32 && k == 16 &&
b == 1)
825 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
826 if (m == 16 && n == 16 && k == 32 &&
b == 1)
827 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
830 if (m == 32 && n == 32 && k == 4 &&
b == 2)
831 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
832 if (m == 16 && n == 16 && k == 4 &&
b == 4)
833 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
834 if (m == 4 && n == 4 && k == 4 &&
b == 16)
835 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
836 if (m == 32 && n == 32 && k == 8 &&
b == 1)
837 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
838 if (m == 16 && n == 16 && k == 16 &&
b == 1)
839 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
841 if (m == 32 && n == 32 && k == 2 &&
b == 2)
842 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
843 if (m == 16 && n == 16 && k == 2 &&
b == 4)
844 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
845 if (m == 4 && n == 4 && k == 2 &&
b == 16)
846 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
847 if (m == 32 && n == 32 && k == 4 &&
b == 1)
848 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
849 if (m == 16 && n == 16 && k == 8 &&
b == 1)
850 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
855 if (m == 32 && n == 32 && k == 32 &&
b == 1)
856 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
857 if (m == 16 && n == 16 && k == 64 &&
b == 1)
858 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
860 if (m == 32 && n == 32 && k == 4 &&
b == 2)
861 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
862 if (m == 16 && n == 16 && k == 4 &&
b == 4)
863 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
864 if (m == 4 && n == 4 && k == 4 &&
b == 16)
865 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
866 if (m == 32 && n == 32 && k == 8 &&
b == 1)
867 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
868 if (m == 16 && n == 16 && k == 16 &&
b == 1)
869 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
870 if (m == 32 && n == 32 && k == 16 &&
b == 1 && chipset >=
kGfx942)
871 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
872 if (m == 16 && n == 16 && k == 32 &&
b == 1 && chipset >=
kGfx942)
873 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
877 if (m == 16 && n == 16 && k == 4 &&
b == 1)
878 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
879 if (m == 4 && n == 4 && k == 4 &&
b == 4)
880 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
887 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
888 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
890 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
892 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
894 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
896 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
898 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
904 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
905 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
907 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
909 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
911 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
913 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
915 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
924 .Case([](Float8E4M3FNType) {
return 0u; })
925 .Case([](Float8E5M2Type) {
return 1u; })
926 .Case([](Float6E2M3FNType) {
return 2u; })
927 .Case([](Float6E3M2FNType) {
return 3u; })
928 .Case([](Float4E2M1FNType) {
return 4u; })
929 .Default(std::nullopt);
939static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
941 uint32_t n, uint32_t k, uint32_t
b,
Chipset chipset) {
948 if (!isa<Float32Type>(destType))
953 if (!aTypeCode || !bTypeCode)
956 if (m == 32 && n == 32 && k == 64 &&
b == 1)
957 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
958 *aTypeCode, *bTypeCode};
959 if (m == 16 && n == 16 && k == 128 &&
b == 1)
961 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
967static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
970 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
971 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
972 mfma.getBlocks(), chipset);
975static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
978 smfma.getSourceB().getType(),
979 smfma.getDestC().getType(), smfma.getM(),
980 smfma.getN(), smfma.getK(), 1u, chipset);
985static std::optional<StringRef>
987 Type elemDestType, uint32_t k,
bool isRDNA3) {
988 using fp8 = Float8E4M3FNType;
989 using bf8 = Float8E5M2Type;
994 if (elemSourceType.
isF16() && elemDestType.
isF32())
995 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
996 if (elemSourceType.
isBF16() && elemDestType.
isF32())
997 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
998 if (elemSourceType.
isF16() && elemDestType.
isF16())
999 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1001 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1003 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1008 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1009 return std::nullopt;
1013 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1014 elemDestType.
isF32())
1015 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1016 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1017 elemDestType.
isF32())
1018 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1019 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1020 elemDestType.
isF32())
1021 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1022 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1023 elemDestType.
isF32())
1024 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1026 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1028 return std::nullopt;
1032 if (k == 32 && !isRDNA3) {
1034 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1037 return std::nullopt;
1043 Type elemBSourceType,
1046 using fp8 = Float8E4M3FNType;
1047 using bf8 = Float8E5M2Type;
1050 if (elemSourceType.
isF32() && elemDestType.
isF32())
1051 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1053 return std::nullopt;
1057 if (elemSourceType.
isF16() && elemDestType.
isF32())
1058 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1059 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1060 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1061 if (elemSourceType.
isF16() && elemDestType.
isF16())
1062 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1064 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1066 return std::nullopt;
1070 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1071 if (elemDestType.
isF32())
1072 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1073 if (elemDestType.
isF16())
1074 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1076 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1077 if (elemDestType.
isF32())
1078 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1079 if (elemDestType.
isF16())
1080 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1082 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1083 if (elemDestType.
isF32())
1084 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1085 if (elemDestType.
isF16())
1086 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1088 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1089 if (elemDestType.
isF32())
1090 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1091 if (elemDestType.
isF16())
1092 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1095 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1097 return std::nullopt;
1101 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1102 if (elemDestType.
isF32())
1103 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1104 if (elemDestType.
isF16())
1105 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1107 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1108 if (elemDestType.
isF32())
1109 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1110 if (elemDestType.
isF16())
1111 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1113 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1114 if (elemDestType.
isF32())
1115 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1116 if (elemDestType.
isF16())
1117 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1119 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1120 if (elemDestType.
isF32())
1121 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1122 if (elemDestType.
isF16())
1123 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1126 return std::nullopt;
1129 return std::nullopt;
1137 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1138 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1139 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1140 Type elemSourceType = sourceVectorType.getElementType();
1141 Type elemBSourceType = sourceBVectorType.getElementType();
1142 Type elemDestType = destVectorType.getElementType();
1144 const uint32_t k = wmma.getK();
1149 if (isRDNA3 || isRDNA4)
1158 return std::nullopt;
1163 MFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1164 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1169 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1170 ConversionPatternRewriter &rewriter)
const override {
1171 Location loc = op.getLoc();
1172 Type outType = typeConverter->convertType(op.getDestD().getType());
1173 Type intrinsicOutType = outType;
1174 if (
auto outVecType = dyn_cast<VectorType>(outType))
1175 if (outVecType.getElementType().isBF16())
1176 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1178 if (chipset.majorVersion != 9 || chipset <
kGfx908)
1179 return op->emitOpError(
"MFMA only supported on gfx908+");
1180 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
1181 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1183 return op.emitOpError(
"negation unsupported on older than gfx942");
1185 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1188 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1190 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1191 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
1194 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1196 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1197 return op.emitOpError(
1198 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1199 "be scaled as those fields are used for type information");
1202 StringRef intrinsicName =
1203 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1206 bool allowBf16 = [&]() {
1211 return intrinsicName.contains(
"16x16x32.bf16") ||
1212 intrinsicName.contains(
"32x32x16.bf16");
1214 OperationState loweredOp(loc, intrinsicName);
1215 loweredOp.addTypes(intrinsicOutType);
1217 rewriter, loc, adaptor.getSourceA(), allowBf16),
1219 rewriter, loc, adaptor.getSourceB(), allowBf16),
1220 adaptor.getDestC()});
1223 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1233 Value lowered = rewriter.create(loweredOp)->getResult(0);
1234 if (outType != intrinsicOutType)
1235 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1236 rewriter.replaceOp(op, lowered);
1242 ScaledMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1243 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1248 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1249 ConversionPatternRewriter &rewriter)
const override {
1250 Location loc = op.getLoc();
1251 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1253 if (chipset.majorVersion != 9 || chipset <
kGfx950)
1254 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1255 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1257 if (!maybeScaledIntrinsic.has_value())
1258 return op.emitOpError(
1259 "no intrinsic matching scaled MFMA size on given chipset");
1261 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1262 OperationState loweredOp(loc, intrinsicName);
1263 loweredOp.addTypes(intrinsicOutType);
1264 loweredOp.addOperands(
1267 adaptor.getDestC()});
1272 loweredOp.addOperands(
1281 Value lowered = rewriter.create(loweredOp)->getResult(0);
1282 rewriter.replaceOp(op, lowered);
1288 WMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1289 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1294 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1295 ConversionPatternRewriter &rewriter)
const override {
1296 Location loc = op.getLoc();
1298 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1300 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1302 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1303 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1305 bool isGFX1250 = chipset >=
kGfx1250;
1310 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1311 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1312 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1313 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1314 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1315 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1316 bool castOutToI16 = outType.getElementType().
isBF16() && !isGFX1250;
1317 VectorType rawOutType = outType;
1319 rawOutType = outType.clone(rewriter.getI16Type());
1320 Value a = adaptor.getSourceA();
1322 a = LLVM::BitcastOp::create(rewriter, loc,
1323 aType.clone(rewriter.getI16Type()), a);
1324 Value
b = adaptor.getSourceB();
1326 b = LLVM::BitcastOp::create(rewriter, loc,
1327 bType.clone(rewriter.getI16Type()),
b);
1328 Value destC = adaptor.getDestC();
1330 destC = LLVM::BitcastOp::create(
1331 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1335 if (!maybeIntrinsic.has_value())
1336 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
1338 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1339 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1341 SmallVector<Value, 4> operands;
1342 SmallVector<NamedAttribute, 4> attrs;
1344 op.getSourceA(), operands, attrs,
"signA");
1346 op.getSourceB(), operands, attrs,
"signB");
1348 op.getSubwordOffset(), op.getClamp(), operands,
1351 OperationState loweredOp(loc, *maybeIntrinsic);
1352 loweredOp.addTypes(rawOutType);
1353 loweredOp.addOperands(operands);
1354 loweredOp.addAttributes(attrs);
1355 Operation *lowered = rewriter.create(loweredOp);
1357 Operation *maybeCastBack = lowered;
1358 if (rawOutType != outType)
1359 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1361 rewriter.replaceOp(op, maybeCastBack->
getResults());
// Lowers amdgpu.transpose_load to the ROCDL ds_read_tr* family, dispatching
// on the result element bit width (ds_read_tr4_b64 / ds_read_tr6_b96 /
// ds_read_tr8_b64 / ds_read_tr16_b64) and bitcasting the raw i32-vector
// intrinsic result back to the converted result type.
// NOTE(review): this extraction is missing interior lines (base-class clause,
// switch labels, closing braces) -- verify against the complete file.
1367struct TransposeLoadOpLowering
 1369 TransposeLoadOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
 1370 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
 1375 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
 1376 ConversionPatternRewriter &rewriter)
const override {
// Rejected below for chipsets other than gfx950 (per the emitted message).
 1378 return op.emitOpError(
"Non-gfx950 chipset not supported");
 1380 Location loc = op.getLoc();
 1381 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
// Sub-byte source element types are not addressable by these intrinsics.
 1385 size_t srcElementSize =
 1386 srcMemRefType.getElementType().getIntOrFloatBitWidth();
 1387 if (srcElementSize < 8)
 1388 return op.emitOpError(
"Expect source memref to have at least 8 bits "
 1389 "element size, got ")
 1392 auto resultType = cast<VectorType>(op.getResult().getType());
 1395 (adaptor.getSrcIndices()));
 1397 size_t numElements = resultType.getNumElements();
 1398 size_t elementTypeSize =
 1399 resultType.getElementType().getIntOrFloatBitWidth();
// The ROCDL intrinsics return packed i32 lanes; total bits / 32 lanes.
 1403 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
 1404 rewriter.getIntegerType(32));
 1405 Type llvmResultType = typeConverter->convertType(resultType);
// Dispatch on element bit width; each case asserts the expected lane count.
 1407 switch (elementTypeSize) {
 1409 assert(numElements == 16);
 1410 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
 1411 rocdlResultType, srcPtr);
 1412 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
 1416 assert(numElements == 16);
 1417 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
 1418 rocdlResultType, srcPtr);
 1419 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
 1423 assert(numElements == 8);
 1424 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
 1425 rocdlResultType, srcPtr);
 1426 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
 1430 assert(numElements == 4);
// 16-bit case replaces directly without a bitcast (types already match).
 1431 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
 1436 return op.emitOpError(
"Unsupported element size for transpose load");
// Lowers amdgpu.gather_to_lds to ROCDL::LoadToLDSOp. Validates the chipset
// (gfx9/gfx10 family only, per the emitted error) and the per-lane load
// width; 12/16-byte widths are gated to gfx950.
// NOTE(review): the enclosing `struct ... : ConvertOpToLLVMPattern<...>` line
// and several interior lines were dropped in this extraction -- confirm
// against the complete file.
 1443 GatherToLDSOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
 1444 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
 1449 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
 1450 ConversionPatternRewriter &rewriter)
const override {
 1451 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
 1452 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
 1454 Location loc = op.getLoc();
 1456 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
 1457 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
// Load width in bytes, derived from the transfer type (vector or scalar).
 1462 Type transferType = op.getTransferType();
 1463 int loadWidth = [&]() ->
int {
 1464 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
 1465 return (transferVectorType.getNumElements() *
 1466 transferVectorType.getElementTypeBitWidth()) /
// Hardware supports only these per-lane widths.
 1473 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
 1474 return op.emitOpError(
"chipset unsupported element size");
 1476 if (chipset !=
kGfx950 && llvm::is_contained({12, 16}, loadWidth))
 1477 return op.emitOpError(
"Gather to LDS instructions with 12-byte and "
 1478 "16-byte load widths are only supported on gfx950");
 1482 (adaptor.getSrcIndices()));
 1485 (adaptor.getDstIndices()));
 1487 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
 1488 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
 1489 rewriter.getI32IntegerAttr(0),
// Pattern declaration: lowers amdgpu.ext_packed_fp8 (fp8 -> f32 unpack);
// matchAndRewrite is defined out-of-line below.
1498struct ExtPackedFp8OpLowering final
 1500 ExtPackedFp8OpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
 1501 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
 1506 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
 1507 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers amdgpu.scaled_ext_packed_matrix;
// matchAndRewrite is defined out-of-line below.
1510struct ScaledExtPackedMatrixOpLowering final
 1512 ScaledExtPackedMatrixOpLowering(
const LLVMTypeConverter &converter,
 1514 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
 1519 matchAndRewrite(ScaledExtPackedMatrixOp op,
 1520 ScaledExtPackedMatrixOpAdaptor adaptor,
 1521 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers amdgpu.packed_trunc_2xfp8 (two f32 -> packed
// fp8); matchAndRewrite is defined out-of-line below.
1524struct PackedTrunc2xFp8OpLowering final
 1526 PackedTrunc2xFp8OpLowering(
const LLVMTypeConverter &converter,
 1528 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
 1533 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
 1534 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers amdgpu.packed_stoch_round_fp8 (stochastic
// rounding f32 -> fp8); matchAndRewrite is defined out-of-line below.
1537struct PackedStochRoundFp8OpLowering final
 1539 PackedStochRoundFp8OpLowering(
const LLVMTypeConverter &converter,
 1541 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
 1546 matchAndRewrite(PackedStochRoundFp8Op op,
 1547 PackedStochRoundFp8OpAdaptor adaptor,
 1548 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers amdgpu.scaled_ext_packed;
// matchAndRewrite is defined out-of-line below.
1551struct ScaledExtPackedOpLowering final
 1553 ScaledExtPackedOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
 1554 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
 1559 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
 1560 ConversionPatternRewriter &rewriter)
const override;
// Pattern declaration: lowers amdgpu.packed_scaled_trunc;
// matchAndRewrite is defined out-of-line below.
1563struct PackedScaledTruncOpLowering final
 1565 PackedScaledTruncOpLowering(
const LLVMTypeConverter &converter,
 1567 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
 1572 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
 1573 ConversionPatternRewriter &rewriter)
const override;
// Lowers amdgpu.ext_packed_fp8: widens a short fp8 source to a 4-byte vector
// if needed, bitcasts it to i32, then selects a ROCDL cvt intrinsic --
// packed (CvtPkF32Bf8Op / CvtPkF32Fp8Op) when the result is a vector,
// scalar (CvtF32Bf8Op / CvtF32Fp8Op) otherwise.
// NOTE(review): the chipset-gate condition and several branch/brace lines
// are missing from this extraction -- verify against the complete file.
1578LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
 1579 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
 1580 ConversionPatternRewriter &rewriter)
const {
 1581 Location loc = op.getLoc();
 1583 return rewriter.notifyMatchFailure(
 1584 loc,
"Fp8 conversion instructions are not available on target "
 1585 "architecture and their emulation is not implemented");
 1587 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
 1588 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
 1589 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
 1591 Value source = adaptor.getSource();
 1592 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
 1593 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
// Pad scalar or short-vector sources up to a full v4i8 so the bitcast to
// i32 below is well-formed.
 1596 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
 1597 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
 1598 if (!sourceVecType) {
 1599 longVec = LLVM::InsertElementOp::create(
 1602 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
 1604 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
 1606 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
 1611 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
// Vector result -> packed conversion; scalar result -> single-lane variant.
 1612 if (resultVecType) {
 1614 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
 1617 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
 1622 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
 1625 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
// Computes the scaleSel immediate for the packed scaled-conversion
// intrinsics from the block size (16 or 32), the source bit width
// (4, 6, or 8), the wave half holding the scales, and the first scale byte.
// Bit layout differs between the fp8 encoding and the fp4/fp6 encoding
// (the latter uses an extra firstScaleByte bit and excludes the reserved
// combinations asserted at the end).
// NOTE(review): the if/else keywords selecting between the two encodings
// were dropped in this extraction -- verify against the complete file.
1632int32_t getScaleSel(int32_t blockSize,
unsigned bitWidth, int32_t scaleWaveHalf,
 1633 int32_t firstScaleByte) {
 1639 assert(llvm::is_contained({16, 32}, blockSize));
 1640 assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
 1642 const bool isFp8 = bitWidth == 8;
 1643 const bool isBlock16 = blockSize == 16;
// fp8 encoding: bit0 = block16, bit1 = (firstScaleByte == 2), bit2 = half.
 1646 int32_t bit0 = isBlock16;
 1647 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
 1648 int32_t bit1 = (firstScaleByte == 2) << 1;
 1649 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
 1650 int32_t bit2 = scaleWaveHalf << 2;
 1651 return bit2 | bit1 | bit0;
// fp4/fp6 encoding: bit0 = block16, bits2..1 = firstScaleByte, bit3 = half.
 1654 int32_t bit0 = isBlock16;
 1656 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
 1657 int32_t bits2and1 = firstScaleByte << 1;
 1658 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
 1659 int32_t bit3 = scaleWaveHalf << 3;
 1660 int32_t bits = bit3 | bits2and1 | bit0;
// These bit patterns are invalid/reserved combinations.
 1662 assert(!llvm::is_contained(
 1663 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
// Maps a (source element type, destination element type) pair to the name of
// the corresponding ROCDL CvtPkScalePk{8,16}* intrinsic: fp4/fp8/bf8 sources
// use the Pk8 variants, fp6/bf6 sources use the Pk16 variants; destinations
// may be f16, bf16, or f32. Returns std::nullopt for an unsupported
// destination; unreachable for any other source element type.
1667static std::optional<StringRef>
1668scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
 1669 using fp4 = Float4E2M1FNType;
 1670 using fp8 = Float8E4M3FNType;
 1671 using bf8 = Float8E5M2Type;
 1672 using fp6 = Float6E2M3FNType;
 1673 using bf6 = Float6E3M2FNType;
 1674 if (isa<fp4>(srcElemType)) {
 1675 if (destElemType.
isF16())
 1676 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
 1677 if (destElemType.
isBF16())
 1678 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
 1679 if (destElemType.
isF32())
 1680 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
 1681 return std::nullopt;
 1683 if (isa<fp8>(srcElemType)) {
 1684 if (destElemType.
isF16())
 1685 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
 1686 if (destElemType.
isBF16())
 1687 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
 1688 if (destElemType.
isF32())
 1689 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
 1690 return std::nullopt;
 1692 if (isa<bf8>(srcElemType)) {
 1693 if (destElemType.
isF16())
 1694 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
 1695 if (destElemType.
isBF16())
 1696 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
 1697 if (destElemType.
isF32())
 1698 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
 1699 return std::nullopt;
 1701 if (isa<fp6>(srcElemType)) {
 1702 if (destElemType.
isF16())
 1703 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
 1704 if (destElemType.
isBF16())
 1705 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
 1706 if (destElemType.
isF32())
 1707 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
 1708 return std::nullopt;
 1710 if (isa<bf6>(srcElemType)) {
 1711 if (destElemType.
isF16())
 1712 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
 1713 if (destElemType.
isBF16())
 1714 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
 1715 if (destElemType.
isF32())
 1716 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
 1717 return std::nullopt;
 1719 llvm_unreachable(
"invalid combination of element types for packed conversion "
// Lowers amdgpu.scaled_ext_packed_matrix: computes the scaleSel immediate
// via getScaleSel, bitcasts source and scale to the packed i32 form the
// intrinsic expects, then builds the intrinsic chosen by
// scaledExtPacked816ToIntrinsic via a generic OperationState.
// NOTE(review): a chipset-gate condition and some brace lines are missing
// from this extraction -- verify against the complete file.
1723LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
 1724 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
 1725 ConversionPatternRewriter &rewriter)
const {
 1726 using fp4 = Float4E2M1FNType;
 1727 using fp8 = Float8E4M3FNType;
 1728 using bf8 = Float8E5M2Type;
 1729 using fp6 = Float6E2M3FNType;
 1730 using bf6 = Float6E3M2FNType;
 1731 Location loc = op.getLoc();
 1733 return rewriter.notifyMatchFailure(
 1735 "Scaled fp packed conversion instructions are not available on target "
 1736 "architecture and their emulation is not implemented");
// firstScaleLane is divided into 16-lane wave halves for scaleSel.
 1740 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
 1741 int32_t firstScaleByte = op.getFirstScaleByte();
 1742 int32_t blockSize = op.getBlockSize();
 1743 auto sourceType = cast<VectorType>(op.getSource().getType());
 1744 auto srcElemType = cast<FloatType>(sourceType.getElementType());
 1745 unsigned bitWidth = srcElemType.getWidth();
 1747 auto targetType = cast<VectorType>(op.getResult().getType());
 1748 auto destElemType = cast<FloatType>(targetType.getElementType());
 1750 IntegerType i32 = rewriter.getI32Type();
 1751 Value source = adaptor.getSource();
 1752 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
// Packed form of the source: fp8/bf8 -> v2i32, fp6/bf6 -> v3i32
// (the fp4 case's vector line is missing from this extraction).
 1753 Type packedType =
nullptr;
 1754 if (isa<fp4>(srcElemType)) {
 1756 packedType = getTypeConverter()->convertType(packedType);
 1757 }
else if (isa<fp8, bf8>(srcElemType)) {
 1758 packedType = VectorType::get(2, i32);
 1759 packedType = getTypeConverter()->convertType(packedType);
 1760 }
else if (isa<fp6, bf6>(srcElemType)) {
 1761 packedType = VectorType::get(3, i32);
 1762 packedType = getTypeConverter()->convertType(packedType);
 1764 llvm_unreachable(
"invalid element type for packed scaled ext");
 1767 if (!packedType || !llvmResultType) {
 1768 return rewriter.notifyMatchFailure(op,
"type conversion failed");
 1771 std::optional<StringRef> maybeIntrinsic =
 1772 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
 1773 if (!maybeIntrinsic.has_value())
 1774 return op.emitOpError(
 1775 "no intrinsic matching packed scaled conversion on the given chipset");
 1778 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
 1780 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
 1781 Value castedSource =
 1782 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
// Build the intrinsic by name since the op class varies with the types.
 1784 OperationState loweredOp(loc, *maybeIntrinsic);
 1785 loweredOp.addTypes({llvmResultType});
 1786 loweredOp.addOperands({castedSource, castedScale});
 1788 SmallVector<NamedAttribute, 1> attrs;
 1790 NamedAttribute(
"scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
 1792 loweredOp.addAttributes(attrs);
 1793 Operation *lowered = rewriter.create(loweredOp);
 1794 rewriter.replaceOp(op, lowered);
// Lowers amdgpu.scaled_ext_packed: pads the source to the full packed vector
// (v4i8 for fp8/bf8, v8i4 for fp4), bitcasts to i32, then selects one of the
// nine ROCDL CvtScaleF32Pk* intrinsics by (source type x dest type).
// NOTE(review): the chipset gate and some brace lines were dropped in this
// extraction; the `if (!sourceVecType)` branch after the numElements check
// looks unreachable as shown -- verify against the complete file.
1799LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
 1800 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
 1801 ConversionPatternRewriter &rewriter)
const {
 1802 Location loc = op.getLoc();
 1804 return rewriter.notifyMatchFailure(
 1805 loc,
"Scaled fp conversion instructions are not available on target "
 1806 "architecture and their emulation is not implemented");
 1807 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
 1809 Value source = adaptor.getSource();
 1810 Value scale = adaptor.getScale();
 1812 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
 1813 Type sourceElemType = sourceVecType.getElementType();
 1814 VectorType destVecType = cast<VectorType>(op.getResult().getType());
 1815 Type destElemType = destVecType.getElementType();
// Packed register shape expected by the intrinsics.
 1817 VectorType packedVecType;
 1818 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
 1819 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
 1820 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
 1821 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
 1822 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
 1823 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
 1825 llvm_unreachable(
"invalid element type for scaled ext");
// Zero-pad short sources up to the packed width.
 1829 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
 1830 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
 1831 if (!sourceVecType) {
 1832 longVec = LLVM::InsertElementOp::create(
 1835 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
 1837 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
 1839 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
 1844 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
// Dispatch table: (bf8|fp8|fp4 source) x (f32|f16|bf16 destination).
 1846 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
 1847 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
 1848 op, destVecType, i32Source, scale, op.getIndex());
 1849 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
 1850 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
 1851 op, destVecType, i32Source, scale, op.getIndex());
 1852 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
 1853 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
 1854 op, destVecType, i32Source, scale, op.getIndex());
 1855 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
 1856 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
 1857 op, destVecType, i32Source, scale, op.getIndex());
 1858 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
 1859 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
 1860 op, destVecType, i32Source, scale, op.getIndex());
 1861 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
 1862 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
 1863 op, destVecType, i32Source, scale, op.getIndex());
 1864 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
 1865 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
 1866 op, destVecType, i32Source, scale, op.getIndex());
 1867 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
 1868 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
 1869 op, destVecType, i32Source, scale, op.getIndex());
 1870 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
 1871 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
 1872 op, destVecType, i32Source, scale, op.getIndex());
// Lowers amdgpu.packed_scaled_trunc: picks an integer carrier type for the
// packed result (i32 for fp4 results, v2i16 otherwise), bitcasts/zero-inits
// the `existing` accumulator, pads one-element f32-ish sources to two lanes,
// then selects the ROCDL CvtScaleF32Pk{Bf8,Fp8,Fp4}{F32,F16,Bf16}Op
// intrinsic and bitcasts the packed result back to the converted type.
// NOTE(review): chipset gate, some brace lines, and the resultElemType
// definition are missing from this extraction -- verify against the full
// file.
1879LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
 1880 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
 1881 ConversionPatternRewriter &rewriter)
const {
 1882 Location loc = op.getLoc();
 1884 return rewriter.notifyMatchFailure(
 1885 loc,
"Scaled fp conversion instructions are not available on target "
 1886 "architecture and their emulation is not implemented");
 1887 Type v2i16 = getTypeConverter()->convertType(
 1888 VectorType::get(2, rewriter.getI16Type()));
 1889 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
 1891 Type resultType = op.getResult().getType();
 1893 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
 1894 Type sourceElemType = sourceVecType.getElementType();
 1896 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
 1898 Value source = adaptor.getSource();
 1899 Value scale = adaptor.getScale();
 1900 Value existing = adaptor.getExisting();
// Reuse the caller-provided packed word when present, else start from zero.
 1902 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
 1904 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
// Single-element sources are widened to a two-lane vector.
 1906 if (sourceVecType.getNumElements() < 2) {
 1908 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
 1909 VectorType v2 = VectorType::get(2, sourceElemType);
 1910 source = LLVM::ZeroOp::create(rewriter, loc, v2);
 1911 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
// f32 sources are passed to the intrinsics as two scalar operands.
 1914 Value sourceA, sourceB;
 1915 if (sourceElemType.
isF32()) {
 1918 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
 1919 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
 1923 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
 1924 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
 1925 existing, sourceA, sourceB,
 1926 scale, op.getIndex());
 1927 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
 1928 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
 1929 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
 1930 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
 1931 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
 1932 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
 1933 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
 1934 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
 1935 existing, sourceA, sourceB,
 1936 scale, op.getIndex());
 1937 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
 1938 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
 1939 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
 1940 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
 1941 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
 1942 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
 1943 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
 1944 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
 1945 existing, sourceA, sourceB,
 1946 scale, op.getIndex());
 1947 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
 1948 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
 1949 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
 1950 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
 1951 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
 1952 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
// Cast the packed integer word back to the op's converted result type.
 1956 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
 1957 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu.packed_trunc_2xfp8 to ROCDL::CvtPkBf8F32Op or
// CvtPkFp8F32Op: packs sourceA/sourceB into the word selected by
// getWordIndex() within `existing` (undef when absent), then bitcasts the
// i32 result back to the converted result type.
// NOTE(review): chipset gate and the branch condition choosing Bf8 vs Fp8
// are missing from this extraction -- verify against the complete file.
1961LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
 1962 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
 1963 ConversionPatternRewriter &rewriter)
const {
 1964 Location loc = op.getLoc();
 1966 return rewriter.notifyMatchFailure(
 1967 loc,
"Fp8 conversion instructions are not available on target "
 1968 "architecture and their emulation is not implemented");
 1969 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
 1971 Type resultType = op.getResult().getType();
 1974 Value sourceA = adaptor.getSourceA();
 1975 Value sourceB = adaptor.getSourceB();
// A missing second operand is filled with undef of the same type.
 1977 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.
getType());
 1978 Value existing = adaptor.getExisting();
 1980 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
 1982 existing = LLVM::UndefOp::create(rewriter, loc, i32);
 1986 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
 1987 existing, op.getWordIndex());
 1989 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
 1990 existing, op.getWordIndex());
 1992 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
 1993 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu.packed_stoch_round_fp8 to ROCDL::CvtSrBf8F32Op or
// CvtSrFp8F32Op: stochastic-rounds `source` using the stochastic parameter,
// stores into the byte selected by getStoreIndex() within `existing`
// (undef when absent), and bitcasts back to the converted result type.
// NOTE(review): chipset gate and the Bf8-vs-Fp8 branch condition are
// missing from this extraction -- verify against the complete file.
1997LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
 1998 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
 1999 ConversionPatternRewriter &rewriter)
const {
 2000 Location loc = op.getLoc();
 2002 return rewriter.notifyMatchFailure(
 2003 loc,
"Fp8 conversion instructions are not available on target "
 2004 "architecture and their emulation is not implemented");
 2005 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
 2007 Type resultType = op.getResult().getType();
 2010 Value source = adaptor.getSource();
 2011 Value stoch = adaptor.getStochiasticParam();
 2012 Value existing = adaptor.getExisting();
 2014 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
 2016 existing = LLVM::UndefOp::create(rewriter, loc, i32);
 2020 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
 2021 existing, op.getStoreIndex());
 2023 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
 2024 existing, op.getStoreIndex());
 2026 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
 2027 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu.dpp to ROCDL::DPPUpdateOp. Widens sub-32-bit operands into
// a 32/64-bit carrier type, encodes the permutation kind + argument into a
// dpp_ctrl immediate, emits the DPP move with row/bank masks and bound_ctrl,
// then truncates/bitcasts the result back to the source type.
// NOTE(review): many interior lines (switch header, break statements,
// several enum values, brace lines) were dropped in this extraction --
// verify against the complete file.
2033struct AMDGPUDPPLowering :
public ConvertOpToLLVMPattern<DPPOp> {
 2034 AMDGPUDPPLowering(
const LLVMTypeConverter &converter, Chipset chipset)
 2035 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
 2039 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
 2040 ConversionPatternRewriter &rewriter)
const override {
 2043 Location loc = DppOp.getLoc();
 2044 Value src = adaptor.getSrc();
 2045 Value old = adaptor.getOld();
// Carrier type: 32-bit for narrow operands, 32/64-bit by element width
// for float and integer sources.
 2048 Type llvmType =
nullptr;
 2050 llvmType = rewriter.getI32Type();
 2051 }
else if (isa<FloatType>(srcType)) {
 2053 ? rewriter.getF32Type()
 2054 : rewriter.getF64Type();
 2055 }
else if (isa<IntegerType>(srcType)) {
 2057 ? rewriter.getI32Type()
 2058 : rewriter.getI64Type();
 2060 auto llvmSrcIntType = typeConverter->convertType(
// Widen a <=16-bit operand: floats bitcast to int first, then the value is
// inserted into an undef vector and bitcast to the 32-bit carrier type.
 2064 auto convertOperand = [&](Value operand, Type operandType) {
 2065 if (operandType.getIntOrFloatBitWidth() <= 16) {
 2066 if (llvm::isa<FloatType>(operandType)) {
 2068 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
 2070 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
 2071 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
 2072 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
 2074 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
 2076 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
 2081 src = convertOperand(src, srcType);
 2082 old = convertOperand(old, oldType);
// Hardware dpp_ctrl encodings (other enumerators dropped in extraction).
 2085 enum DppCtrl :
unsigned {
 2094 ROW_HALF_MIRROR = 0x141,
 2099 auto kind = DppOp.getKind();
 2100 auto permArgument = DppOp.getPermArgument();
 2101 uint32_t DppCtrl = 0;
// quad_perm packs four 2-bit lane selectors into the low byte.
 2105 case DPPPerm::quad_perm:
 2106 if (
auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
 2108 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
 2109 uint32_t num = elem.getInt();
 2110 DppCtrl |= num << (i * 2);
 2115 case DPPPerm::row_shl:
 2116 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
 2117 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
 2120 case DPPPerm::row_shr:
 2121 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
 2122 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
 2125 case DPPPerm::row_ror:
 2126 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
 2127 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
 2130 case DPPPerm::wave_shl:
 2131 DppCtrl = DppCtrl::WAVE_SHL1;
 2133 case DPPPerm::wave_shr:
 2134 DppCtrl = DppCtrl::WAVE_SHR1;
 2136 case DPPPerm::wave_rol:
 2137 DppCtrl = DppCtrl::WAVE_ROL1;
 2139 case DPPPerm::wave_ror:
 2140 DppCtrl = DppCtrl::WAVE_ROR1;
 2142 case DPPPerm::row_mirror:
 2143 DppCtrl = DppCtrl::ROW_MIRROR;
 2145 case DPPPerm::row_half_mirror:
 2146 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
 2148 case DPPPerm::row_bcast_15:
 2149 DppCtrl = DppCtrl::BCAST15;
 2151 case DPPPerm::row_bcast_31:
 2152 DppCtrl = DppCtrl::BCAST31;
 2158 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
 2159 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
 2160 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>(
"bound_ctrl").getValue();
 2164 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
 2165 rowMask, bankMask, boundCtrl);
// Undo the operand widening: truncate, then bitcast back for non-integers.
 2167 Value
result = dppMovOp.getRes();
 2169 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType,
result);
 2170 if (!llvm::isa<IntegerType>(srcType)) {
 2171 result = LLVM::BitcastOp::create(rewriter, loc, srcType,
result);
// Lowers amdgpu.swizzle_bitmode to ROCDL::DsSwizzleOp: decomposes the source
// into 32-bit pieces, encodes (andMask | orMask<<5 | xorMask<<10) as the
// ds_swizzle offset, swizzles each piece, and recomposes the result.
2182struct AMDGPUSwizzleBitModeLowering
 2183 :
public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
 2187 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
 2188 ConversionPatternRewriter &rewriter)
const override {
 2189 Location loc = op.getLoc();
 2190 Type i32 = rewriter.getI32Type();
 2191 Value src = adaptor.getSrc();
 2192 SmallVector<Value> decomposed =
 2194 unsigned andMask = op.getAndMask();
 2195 unsigned orMask = op.getOrMask();
 2196 unsigned xorMask = op.getXorMask();
// Bitmode swizzle offset layout: and[4:0], or[9:5], xor[14:10].
 2200 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
 2202 SmallVector<Value> swizzled;
 2203 for (Value v : decomposed) {
 2205 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
 2206 swizzled.emplace_back(res);
 2210 rewriter.replaceOp(op,
result);
// Lowers amdgpu.permlane_swap (gfx950+, per the emitted error) to
// ROCDL::Permlane16SwapOp / Permlane32SwapOp. Each 32-bit piece is swapped
// (the intrinsic returns a {i32, i32} struct); the select keeps whichever
// half actually changed relative to the input.
2215struct AMDGPUPermlaneLowering :
public ConvertOpToLLVMPattern<PermlaneSwapOp> {
 2218 AMDGPUPermlaneLowering(
const LLVMTypeConverter &converter, Chipset chipset)
 2219 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
 2223 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
 2224 ConversionPatternRewriter &rewriter)
const override {
 2226 return op->emitOpError(
"permlane_swap is only supported on gfx950+");
 2228 Location loc = op.getLoc();
 2229 Type i32 = rewriter.getI32Type();
 2230 Value src = adaptor.getSrc();
 2231 unsigned rowLength = op.getRowLength();
 2232 bool fi = op.getFetchInactive();
 2233 bool boundctrl = op.getBoundCtrl();
 2235 SmallVector<Value> decomposed =
 2238 SmallVector<Value> permuted;
 2239 for (Value v : decomposed) {
 2241 Type i32pair = LLVM::LLVMStructType::getLiteral(
 2242 rewriter.getContext(), {v.getType(), v.getType()});
 2244 if (rowLength == 16)
 2245 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
 2247 else if (rowLength == 32)
 2248 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
 2251 llvm_unreachable(
"unsupported row length");
 2253 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
 2254 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
// If the first result equals the input, the swapped value is in vdst1.
 2256 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
 2257 LLVM::ICmpPredicate::eq, vdst0, v);
 2262 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
 2263 permuted.emplace_back(vdstNew);
 2267 rewriter.replaceOp(op,
result);
// Lowers amdgpu.make_dma_base (gfx1250 only, per the emitted error): builds
// a v4i32 descriptor base from the LDS address (as i32) and the 64-bit
// global address split into low/high halves, with a type field OR-ed into
// the masked high half.
// NOTE(review): several constant definitions (c0..c3, mask, typeField) are
// missing from this extraction -- verify against the complete file.
2272struct AMDGPUMakeDmaBaseLowering
 2273 :
public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
 2276 AMDGPUMakeDmaBaseLowering(
const LLVMTypeConverter &converter, Chipset chipset)
 2277 : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
 2281 matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
 2282 ConversionPatternRewriter &rewriter)
const override {
 2284 return op->emitOpError(
"make_dma_base is only supported on gfx1250");
 2286 Location loc = op.getLoc();
 2288 ValueRange ldsIndices = adaptor.getLdsIndices();
 2289 Value lds = adaptor.getLds();
 2290 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
 2295 ValueRange globalIndices = adaptor.getGlobalIndices();
 2296 Value global = adaptor.getGlobal();
 2297 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
 2300 global, globalIndices);
 2302 Type i32 = rewriter.getI32Type();
 2303 Type i64 = rewriter.getI64Type();
// LDS pointers fit in 32 bits; global pointers are 64-bit and split below.
 2305 Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
 2306 Value castForGlobalAddr =
 2307 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
 2310 LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
 2312 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
 2315 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
 2318 Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
 2321 Value highHalfPlusType =
 2322 LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);
// Assemble the four descriptor words into a v4i32.
 2329 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
 2330 assert(v4i32 &&
"expected type conversion to succeed");
 2331 Value
result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
 2332 result = LLVM::InsertElementOp::create(rewriter, loc,
result, c1, c0);
 2333 result = LLVM::InsertElementOp::create(rewriter, loc,
result,
 2334 castForLdsAddr, c1);
 2335 result = LLVM::InsertElementOp::create(rewriter, loc,
result, lowHalf, c2);
 2336 result = LLVM::InsertElementOp::create(rewriter, loc,
result,
 2337 highHalfPlusType, c3);
 2339 rewriter.replaceOp(op,
result);
2344struct AMDGPUMakeDmaDescriptorLowering
2345 :
public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> {
2348 AMDGPUMakeDmaDescriptorLowering(
const LLVMTypeConverter &converter,
2350 : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter),
2354 Value getDGroup0(OpAdaptor adaptor)
const {
return adaptor.getBase(); }
2356 Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
2357 Value accumulator, Value value, int64_t shift)
const {
2362 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
2368 return LLVM::OrOp::create(rewriter, loc, accumulator, value);
2371 Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2372 ConversionPatternRewriter &rewriter, Location loc,
2373 Value sgpr0, ArrayRef<Value> consts)
const {
2375 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
2377 llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidthInBits) &&
2378 "expected type width to be 8, 16, 32, or 64.");
2379 int64_t dataSize = llvm::Log2_32(elementTypeWidthInBits / 8);
2383 Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2384 ConversionPatternRewriter &rewriter, Location loc,
2385 Value sgpr0, ArrayRef<Value> consts)
const {
2386 bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() !=
nullptr;
2387 if (!atomic_barrier_enable)
2390 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
2393 Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2394 ConversionPatternRewriter &rewriter, Location loc,
2395 Value sgpr0, ArrayRef<Value> consts)
const {
2396 bool iterate_enable = adaptor.getGlobalIncrement() !=
nullptr;
2397 if (!iterate_enable)
2401 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
2404 Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2405 ConversionPatternRewriter &rewriter, Location loc,
2406 Value sgpr0, ArrayRef<Value> consts)
const {
2407 bool pad_enable = op.getPadAmount() !=
nullptr;
2411 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
2414 Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2415 ConversionPatternRewriter &rewriter, Location loc,
2416 Value sgpr0, ArrayRef<Value> consts)
const {
2417 bool pad_enable = op.getPadAmount() !=
nullptr;
2421 IntegerType i32 = rewriter.getI32Type();
2422 Value padInterval = adaptor.getPadInterval();
2424 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
2425 padInterval,
false);
2426 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
2428 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
2431 Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2432 ConversionPatternRewriter &rewriter, Location loc,
2433 Value sgpr0, ArrayRef<Value> consts)
const {
2434 bool pad_enable = op.getPadAmount() !=
nullptr;
2438 Value padAmount = adaptor.getPadAmount();
2440 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
2442 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
2445 Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2446 ConversionPatternRewriter &rewriter,
2447 Location loc, Value sgpr1,
2448 ArrayRef<Value> consts)
const {
2449 bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() !=
nullptr;
2450 if (!atomic_barrier_enable)
2453 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
2454 auto barrierAddressTy =
2455 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
2456 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
2457 atomicBarrierAddress =
2459 atomicBarrierAddress, atomicBarrierIndices);
2460 IntegerType i32 = rewriter.getI32Type();
2463 atomicBarrierAddress =
2464 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
2465 atomicBarrierAddress =
2466 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
2468 atomicBarrierAddress =
2469 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
2470 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
2473 std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op,
2475 ConversionPatternRewriter &rewriter,
2476 Location loc, Value sgpr1, Value sgpr2,
2477 ArrayRef<Value> consts)
const {
2478 SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
2479 OpFoldResult tensorDim0OpFoldResult = mixedGlobalSizes.back();
2481 if (
auto attr = dyn_cast<Attribute>(tensorDim0OpFoldResult))
2485 tensorDim0 = cast<Value>(tensorDim0OpFoldResult);
2488 Value tensorDim0High = LLVM::LShrOp::create(rewriter, loc, tensorDim0, c16);
2489 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDim0, 48);
2490 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim0High, 48 + 16);
2491 return {sgpr1, sgpr2};
2494 std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op,
2496 ConversionPatternRewriter &rewriter,
2497 Location loc, Value sgpr2, Value sgpr3,
2498 ArrayRef<Value> consts)
const {
2500 SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
2501 OpFoldResult tensorDim1OpFoldResult = *(mixedGlobalSizes.rbegin() + 1);
2503 if (
auto attr = dyn_cast<Attribute>(tensorDim1OpFoldResult))
2507 tensorDim1 = cast<Value>(tensorDim1OpFoldResult);
2510 Value tensorDim1High = LLVM::LShrOp::create(rewriter, loc, tensorDim1, c16);
2511 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim1, 80);
2512 sgpr3 = setValueAtOffset(rewriter, loc, sgpr3, tensorDim1High, 80 + 16);
2513 return {sgpr2, sgpr3};
2516 Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2517 ConversionPatternRewriter &rewriter, Location loc,
2518 Value sgpr, ArrayRef<Value> consts,
size_t dimX,
2519 int64_t offset)
const {
2520 SmallVector<OpFoldResult> mixedSharedSizes = op.getMixedSharedSizes();
2522 if (mixedSharedSizes.size() <= dimX)
2525 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
2527 if (
auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult))
2531 tileDimX = cast<Value>(tileDimXOpFoldResult);
2533 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
2536 Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2537 ConversionPatternRewriter &rewriter, Location loc,
2538 Value sgpr3, ArrayRef<Value> consts)
const {
2539 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
2542 Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2543 ConversionPatternRewriter &rewriter, Location loc,
2544 Value sgpr4, ArrayRef<Value> consts)
const {
2545 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
2548 Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2549 ConversionPatternRewriter &rewriter, Location loc,
2550 Value sgpr4, ArrayRef<Value> consts)
const {
2551 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
2554 std::pair<Value, Value>
2555 setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2556 ConversionPatternRewriter &rewriter, Location loc,
2557 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
2558 size_t dimX, int64_t offset)
const {
2559 SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides();
2561 if (mixedGlobalStrides.size() <= dimX)
2562 return {sgprY, sgprZ};
2564 OpFoldResult tensorDimXStrideOpFoldResult =
2565 *(mixedGlobalStrides.rbegin() + dimX);
2566 Value tensorDimXStride;
2567 if (
auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
2571 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
2573 constexpr int64_t first48bits = (1ll << 48) - 1;
2576 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
2577 IntegerType i32 = rewriter.getI32Type();
2578 Value tensorDimXStrideLow =
2579 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
2581 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
2583 Value tensorDimXStrideHigh =
2584 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
2585 tensorDimXStrideHigh =
2586 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
2588 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
2589 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
2591 return {sgprY, sgprZ};
2594 std::pair<Value, Value>
2595 setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2596 ConversionPatternRewriter &rewriter, Location loc,
2597 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
2598 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
2602 std::pair<Value, Value>
2603 setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2604 ConversionPatternRewriter &rewriter, Location loc,
2605 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
2606 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
2610 Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2611 ConversionPatternRewriter &rewriter, Location loc,
2612 ArrayRef<Value> consts)
const {
2614 for (int64_t i = 0; i < 8; i++) {
2615 sgprs[i] = consts[0];
2618 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
2619 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
2620 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
2621 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
2622 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
2623 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
2626 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
2627 std::tie(sgprs[1], sgprs[2]) =
2628 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
2629 std::tie(sgprs[2], sgprs[3]) =
2630 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
2632 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
2633 sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts);
2634 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
2635 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
2636 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
2637 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
2638 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
2640 IntegerType i32 = rewriter.getI32Type();
2641 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
2642 assert(v8i32 &&
"expected type conversion to succeed");
2643 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
2645 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
2647 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
2654 matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor,
2655 ConversionPatternRewriter &rewriter)
const override {
2657 return op->emitOpError(
2658 "make_dma_descriptor is only supported on gfx1250");
2660 if (op.getRank() > 2)
2661 return op->emitOpError(
"unimplemented");
2663 Location loc = op.getLoc();
2665 IntegerType i32 = rewriter.getI32Type();
2666 [[maybe_unused]] Type v4i32 =
2667 this->typeConverter->convertType(VectorType::get(4, i32));
2668 assert(v4i32 &&
"expected type conversion to succeed");
2670 SmallVector<Value> consts;
2671 for (int64_t i = 0; i < 8; i++)
2674 Value dgroup0 = this->getDGroup0(adaptor);
2675 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
2677 SmallVector<Value> results = {dgroup0, dgroup1};
2678 rewriter.replaceOpWithMultiple(op, {results});
2683struct ConvertAMDGPUToROCDLPass
2684 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
2687 void runOnOperation()
override {
2690 if (
failed(maybeChipset)) {
2691 emitError(UnknownLoc::get(ctx),
"Invalid chipset name: " + chipset);
2692 return signalPassFailure();
2696 LLVMTypeConverter converter(ctx);
2697 converter.addConversion([&](TDMBaseType type) -> Type {
2698 Type i32 = IntegerType::get(type.getContext(), 32);
2699 return converter.convertType(VectorType::get(4, i32));
2704 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
2705 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
2706 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
2707 if (
failed(applyPartialConversion(getOperation(),
target,
2709 signalPassFailure();
2716 typeConverter.addTypeAttributeConversion(
2718 -> TypeConverter::AttributeConversionResult {
2720 Type i64 = IntegerType::get(ctx, 64);
2721 switch (as.getValue()) {
2722 case amdgpu::AddressSpace::FatRawBuffer:
2723 return IntegerAttr::get(i64, 7);
2724 case amdgpu::AddressSpace::BufferRsrc:
2725 return IntegerAttr::get(i64, 8);
2726 case amdgpu::AddressSpace::FatStructuredBuffer:
2727 return IntegerAttr::get(i64, 9);
2729 return TypeConverter::AttributeConversionResult::abort();
2738 FatRawBufferCastLowering,
2739 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2740 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2741 RawBufferOpLowering<RawBufferAtomicFaddOp,
2742 ROCDL::RawPtrBufferAtomicFaddOp>,
2743 RawBufferOpLowering<RawBufferAtomicFmaxOp,
2744 ROCDL::RawPtrBufferAtomicFmaxOp>,
2745 RawBufferOpLowering<RawBufferAtomicSmaxOp,
2746 ROCDL::RawPtrBufferAtomicSmaxOp>,
2747 RawBufferOpLowering<RawBufferAtomicUminOp,
2748 ROCDL::RawPtrBufferAtomicUminOp>,
2749 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2750 ROCDL::RawPtrBufferAtomicCmpSwap>,
2751 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2752 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2753 WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
2754 ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
2755 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
2756 GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
2757 AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
2759 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
constexpr Chipset kGfx90a
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to ROCDL ...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static std::optional< uint32_t > mfmaTypeSelectCode(Type mlirElemType)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
::mlir::Pass::Option< std::string > chipset
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
bool hasOcpFp8(const Chipset &chipset)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.