#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"

#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
#include "mlir/Conversion/Passes.h.inc"
  IntegerType i32 = rewriter.getI32Type();
  auto valTy = cast<IntegerType>(val.getType());
  return valTy.getWidth() > 32
             ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));

  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
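// Illustrative sketch of the two helpers above (an unsigned-to-i32 narrowing
// and an i32 constant builder): an i64 value lowers to `llvm.trunc ... : i64
// to i32`, an i16 value to `llvm.zext ... : i16 to i32`, presumably with an
// early return (elided here) for values that are already i32.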
  IntegerType i64 = rewriter.getI64Type();
  auto valTy = cast<IntegerType>(val.getType());
  return valTy.getWidth() > 64
             ? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
             : Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));

  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
  IntegerType i32 = rewriter.getI32Type();
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    Value strideValue =
        ShapedType::isDynamic(stride)
            ? convertUnsignedToI32(rewriter, loc,
                                   memRefDescriptor.stride(rewriter, loc, i))
            : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
    increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
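    // Illustrative sketch: for indices (%i, %j) into a memref with strides
    // [s0, s1], each iteration scales one index by its stride, producing the
    // terms %i * s0 and %j * s1 that are then accumulated into a single
    // linear offset. Dynamic strides are read from the memref descriptor
    // instead of being materialized as constants.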
                           MemRefType memrefType,
  if (memrefType.hasStaticShape() &&
      !llvm::any_of(strides, ShapedType::isDynamic)) {
    int64_t size = memrefType.getRank() == 0 ? 1 : 0;
    for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
      size = std::max(shape[i] * strides[i], size);
    size = size * elementByteWidth;

  for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
    Value size = memrefDescriptor.size(rewriter, loc, i);
    Value stride = memrefDescriptor.stride(rewriter, loc, i);
    Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
    maxIndex = maxIndex
                   ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
                   : maxThisDim;
  }
  return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
                            Value cacheSwizzleStride = nullptr,
                            unsigned addressSpace = 8) {
  Type i16 = rewriter.getI16Type();
    Value cacheStrideZext =
        LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
    Value swizzleBit = LLVM::ConstantOp::create(
        rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
    stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
                                /*isDisjoint=*/true);
    stride = LLVM::ConstantOp::create(rewriter, loc, i16,
                                      rewriter.getI16IntegerAttr(0));

  uint32_t flags = (7 << 12) | (4 << 15);
    uint32_t oob = boundsCheck ? 3 : 2;
    flags |= (oob << 28);
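  // Worked example: the base flag word is (7 << 12) | (4 << 15) = 0x27000,
  // i.e. the two data-format fields set to their required nonzero values;
  // with bounds checking enabled the out-of-bounds mode adds (3u << 28).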
      LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
  Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
      loc, rsrcType, basePointer, stride, numRecords, flagsConst);
struct FatRawBufferCastLowering
  FatRawBufferCastLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),

  matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value memRef = adaptor.getSource();
    Value unconvertedMemref = op.getSource();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());
    MemRefDescriptor descriptor(memRef);

    DataLayout dataLayout = DataLayout::closest(op);
    int64_t elementByteWidth =

    int64_t unusedOffset = 0;
    SmallVector<int64_t, 5> strideVals;
    if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
      return op.emitOpError("Can't lower non-stride-offset memrefs");

    Value numRecords = adaptor.getValidBytes();
      numRecords = getNumRecords(rewriter, loc, memrefType, descriptor,
                                 strideVals, elementByteWidth);

        adaptor.getResetOffset()
            ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
            : descriptor.alignedPtr(rewriter, loc);

    Value offset = adaptor.getResetOffset()
                       ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
                                                  rewriter.getIndexAttr(0))
                       : descriptor.offset(rewriter, loc);

    bool hasSizes = memrefType.getRank() > 0;
    Value sizes = hasSizes
                      ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
        hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,

        rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
        chipset, adaptor.getCacheSwizzleStride(), 7);

    Value result = MemRefDescriptor::poison(
        getTypeConverter()->convertType(op.getResult().getType()));
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
    result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
    result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
    result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
    result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
    rewriter.replaceOp(op, result);
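    // In sketch form: the cast preserves the memref's shape metadata but
    // swaps both pointer fields of the descriptor for the fat buffer pointer
    // (address space 7 is passed to the resource builder above), optionally
    // folding the offset to zero when reset_offset is set.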
template <typename GpuOp, typename Intrinsic>
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  static constexpr uint32_t maxVectorOpWidth = 128;

  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref)
      wantedDataType = storeData.getType();
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();

    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =

    Type llvmBufferValType = llvmWantedDataType;
    if (auto floatType = dyn_cast<FloatType>(wantedDataType))
      llvmBufferValType = this->getTypeConverter()->convertType(
          rewriter.getIntegerType(floatType.getWidth()));
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t totalBits = elemBits * vecLen;
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));

    if (auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
      if (vecType.getNumElements() == 1)
        llvmBufferValType = vecType.getElementType();
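    // Worked example of the legalization above: a load of vector<8xi8>
    // (totalBits = 64) is performed as a buffer op on vector<2xi32>, while a
    // vector<2xf16> access (totalBits = 32) that doesn't take the
    // packed-fp16 path is performed on a single i32, with bitcasts inserted
    // on either side.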
    SmallVector<Value, 6> args;
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
        args.push_back(storeData);

      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = LLVM::BitcastOp::create(
            rewriter, loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
        args.push_back(atomicCmpData);

    SmallVector<int64_t, 5> strides;
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
        rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
        adaptor.getBoundsCheck(), chipset);
    args.push_back(resource);

        adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,

    voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
    Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
                                           ArrayRef<NamedAttribute>());
    if (llvmBufferValType != llvmWantedDataType) {
      replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
    rewriter.eraseOp(gpuOp);
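    // Sketch of the argument order assembled above: optional store/compare
    // data first, then the buffer resource descriptor, the per-lane voffset
    // (scaled to bytes), the sgprOffset (also in bytes), and the remaining
    // constant arguments of the intrinsic.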
static FailureOr<unsigned> encodeWaitcnt(Chipset chipset, unsigned vmcnt,
                                         unsigned expcnt, unsigned lgkmcnt) {
    vmcnt = std::min(15u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    return vmcnt | (expcnt << 4) | (lgkmcnt << 8);

    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(15u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;

    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    unsigned lowBits = vmcnt & 0xF;
    unsigned highBits = (vmcnt >> 4) << 14;
    unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
    return lowBits | highBits | otherCnts;

    vmcnt = std::min(63u, vmcnt);
    expcnt = std::min(7u, expcnt);
    lgkmcnt = std::min(63u, lgkmcnt);
    return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
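// Worked example, assuming the elided guards dispatch on chipset generation:
// in the split encoding above, vmcnt = 63 contributes lowBits = 0xF and
// highBits = 3 << 14, while waiting for all counters to drain
// (vmcnt = expcnt = lgkmcnt = 0) always encodes as 0.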
struct MemoryCounterWaitOpLowering

  matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion >= 12) {
      if (std::optional<int> ds = adaptor.getDs())
        ROCDL::WaitDscntOp::create(rewriter, loc, *ds);

      if (std::optional<int> load = adaptor.getLoad())
        ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);

      if (std::optional<int> store = adaptor.getStore())
        ROCDL::WaitStorecntOp::create(rewriter, loc, *store);

      if (std::optional<int> exp = adaptor.getExp())
        ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);

      if (std::optional<int> tensor = adaptor.getTensor())
        ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);

      rewriter.eraseOp(op);

    if (adaptor.getTensor())
      return op.emitOpError("unsupported chipset");

    auto getVal = [](Attribute attr) -> unsigned {
      return cast<IntegerAttr>(attr).getInt();

    unsigned ds = getVal(adaptor.getDsAttr());
    unsigned exp = getVal(adaptor.getExpAttr());

    unsigned vmcnt = 1024;
    Attribute store = adaptor.getStoreAttr();
      vmcnt = getVal(load) + getVal(store);
      vmcnt = getVal(load);
      vmcnt = getVal(store);

    FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
      return op.emitOpError("unsupported chipset");
    rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
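    // In effect: on gfx12+ each requested counter lowers to its own
    // dedicated ROCDL wait op, while older generations fold all counters
    // into a single s_waitcnt immediate via encodeWaitcnt.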
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    bool requiresInlineAsm = chipset < kGfx90a;

        rewriter.getAttr<LLVM::MMRATagAttr>("amdgpu-synchronize-as", "local");

    StringRef scope = "workgroup";

    auto relFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::release, scope);
    relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr = ";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
      const char *constraints = "";
      LLVM::InlineAsmOp::create(
          asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, LLVM::TailCallKind::None,
    } else if (chipset.majorVersion < 12) {
      ROCDL::SBarrierOp::create(rewriter, loc);
    } else {
      ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
      ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
    }

    auto acqFence = LLVM::FenceOp::create(rewriter, loc,
                                          LLVM::AtomicOrdering::acquire, scope);
    acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
    rewriter.replaceOp(op, acqFence);
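    // The barrier is bracketed by release/acquire fences whose MMRA tag
    // ("amdgpu-synchronize-as" = "local") scopes the synchronization to LDS,
    // letting the backend relax ordering for other address spaces across
    // the barrier.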
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
                                     bool allowBf16 = true) {
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16() && !allowBf16)
      return LLVM::BitcastOp::create(
          rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8) &&
        vectorType.getNumElements() <= 8)
      return LLVM::BitcastOp::create(
          rewriter, loc,
          rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    if (isa<IntegerType>(vectorType.getElementType()) &&
        vectorType.getElementTypeBitWidth() <= 8) {
      int64_t numWords = llvm::divideCeil(
          vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
          32);
      return LLVM::BitcastOp::create(
          rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
          input);

  Type outputType = rewriter.getI32Type();
  if (auto intType = dyn_cast<IntegerType>(inputType))
    return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
  return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
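// For instance, a vector<4xi8> MFMA operand is reinterpreted as a single
// i32 before being handed to the intrinsic, sub-8-bit element vectors are
// packed into however many i32 words they need, and scalars are
// zero-extended (integers) or bitcast to i32 otherwise.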
                                 ConversionPatternRewriter &rewriter,
                                 Location loc,
  auto vectorType = dyn_cast<VectorType>(inputType);
    operands.push_back(llvmInput);
  Type elemType = vectorType.getElementType();
    operands.push_back(llvmInput);

  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInteger = mlirInputType.getElementType().isInteger();
  if (isInputInteger) {
    bool localIsUnsigned = isUnsigned;
      localIsUnsigned = true;
      localIsUnsigned = false;
        NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));

    Type i32 = rewriter.getI32Type();
    Type intrinsicInType = numBits <= 32
                               ? (Type)rewriter.getIntegerType(numBits)
                               : (Type)VectorType::get(numBits / 32, i32);
    auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
    Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
        loc, llvmIntrinsicInType, llvmInput);
      castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
    operands.push_back(castInput);
                                  Value output, int32_t subwordOffset,
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  operands.push_back(output);
  return (chipset == kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));

  return (chipset == kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
         (hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();

  if (mfma.getReducePrecision() && chipset >= kGfx942) {
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();

  if (m == 32 && n == 32 && k == 1 && b == 2)
    return ROCDL::mfma_f32_32x32x1f32::getOperationName();
  if (m == 16 && n == 16 && k == 1 && b == 4)
    return ROCDL::mfma_f32_16x16x1f32::getOperationName();
  if (m == 4 && n == 4 && k == 1 && b == 16)
    return ROCDL::mfma_f32_4x4x1f32::getOperationName();
  if (m == 32 && n == 32 && k == 2 && b == 1)
    return ROCDL::mfma_f32_32x32x2f32::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 1)
    return ROCDL::mfma_f32_16x16x4f32::getOperationName();

  if (m == 32 && n == 32 && k == 16 && b == 1)
    return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
  if (m == 16 && n == 16 && k == 32 && b == 1)
    return ROCDL::mfma_f32_16x16x32_f16::getOperationName();

  if (m == 32 && n == 32 && k == 4 && b == 2)
    return ROCDL::mfma_f32_32x32x4f16::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 4)
    return ROCDL::mfma_f32_16x16x4f16::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 16)
    return ROCDL::mfma_f32_4x4x4f16::getOperationName();
  if (m == 32 && n == 32 && k == 8 && b == 1)
    return ROCDL::mfma_f32_32x32x8f16::getOperationName();
  if (m == 16 && n == 16 && k == 16 && b == 1)
    return ROCDL::mfma_f32_16x16x16f16::getOperationName();

  if (m == 32 && n == 32 && k == 16 && b == 1)
    return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
  if (m == 16 && n == 16 && k == 32 && b == 1)
    return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();

  if (m == 32 && n == 32 && k == 4 && b == 2)
    return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 4)
    return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 16)
    return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
  if (m == 32 && n == 32 && k == 8 && b == 1)
    return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
  if (m == 16 && n == 16 && k == 16 && b == 1)
    return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();

  if (m == 32 && n == 32 && k == 2 && b == 2)
    return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
  if (m == 16 && n == 16 && k == 2 && b == 4)
    return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
  if (m == 4 && n == 4 && k == 2 && b == 16)
    return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
  if (m == 32 && n == 32 && k == 4 && b == 1)
    return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
  if (m == 16 && n == 16 && k == 8 && b == 1)
    return ROCDL::mfma_f32_16x16x8bf16::getOperationName();

  if (m == 32 && n == 32 && k == 32 && b == 1)
    return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
  if (m == 16 && n == 16 && k == 64 && b == 1)
    return ROCDL::mfma_i32_16x16x64_i8::getOperationName();

  if (m == 32 && n == 32 && k == 4 && b == 2)
    return ROCDL::mfma_i32_32x32x4i8::getOperationName();
  if (m == 16 && n == 16 && k == 4 && b == 4)
    return ROCDL::mfma_i32_16x16x4i8::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 16)
    return ROCDL::mfma_i32_4x4x4i8::getOperationName();
  if (m == 32 && n == 32 && k == 8 && b == 1)
    return ROCDL::mfma_i32_32x32x8i8::getOperationName();
  if (m == 16 && n == 16 && k == 16 && b == 1)
    return ROCDL::mfma_i32_16x16x16i8::getOperationName();
  if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx942)
    return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
  if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx942)
    return ROCDL::mfma_i32_16x16x32_i8::getOperationName();

  if (m == 16 && n == 16 && k == 4 && b == 1)
    return ROCDL::mfma_f64_16x16x4f64::getOperationName();
  if (m == 4 && n == 4 && k == 4 && b == 4)
    return ROCDL::mfma_f64_4x4x4f64::getOperationName();

      cast<VectorType>(mfma.getSourceB().getType()).getElementType();
  if (m == 16 && n == 16 && k == 32 && b == 1) {
      return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
  if (m == 32 && n == 32 && k == 16 && b == 1) {
      return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();

      cast<VectorType>(mfma.getSourceB().getType()).getElementType();
  if (m == 16 && n == 16 && k == 32 && b == 1) {
      return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
  if (m == 32 && n == 32 && k == 16 && b == 1) {
      return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
      .Case([](Float8E4M3FNType) { return 0u; })
      .Case([](Float8E5M2Type) { return 1u; })
      .Case([](Float6E2M3FNType) { return 2u; })
      .Case([](Float6E3M2FNType) { return 3u; })
      .Case([](Float4E2M1FNType) { return 4u; })
      .Default(std::nullopt);
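// These codes feed the cbsz/blgp operands of the scaled MFMA intrinsics:
// fp8 (e4m3) = 0, bf8 (e5m2) = 1, fp6 (e2m3) = 2, bf6 (e3m2) = 3,
// fp4 (e2m1) = 4.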
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
                        uint32_t n, uint32_t k, uint32_t b, Chipset chipset) {
  if (!isa<Float32Type>(destType))
  if (!aTypeCode || !bTypeCode)
  if (m == 32 && n == 32 && k == 64 && b == 1)
    return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
                      *aTypeCode, *bTypeCode};
  if (m == 16 && n == 16 && k == 128 && b == 1)
        ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
      mfma.getSourceA().getType(), mfma.getSourceB().getType(),
      mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
      mfma.getBlocks(), chipset);

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
                           smfma.getSourceB().getType(),
                           smfma.getDestC().getType(), smfma.getM(),
                           smfma.getN(), smfma.getK(), 1u, chipset);
static std::optional<StringRef>
                   Type elemDestType, uint32_t k, bool isRDNA3) {
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    return std::nullopt;

    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
        elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
    return std::nullopt;

  if (k == 32 && !isRDNA3) {
      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
    return std::nullopt;
                            Type elemBSourceType,
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;

    if (elemSourceType.isF32() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
    return std::nullopt;

    if (elemSourceType.isF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
    if (elemSourceType.isBF16() && elemDestType.isF32())
      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
    if (elemSourceType.isF16() && elemDestType.isF16())
      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
    return std::nullopt;

    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
    }
      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
    return std::nullopt;

    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
    }
    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
    }
    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
      if (elemDestType.isF32())
        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
      if (elemDestType.isF16())
        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
    }
    return std::nullopt;

  return std::nullopt;
  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
  Type elemSourceType = sourceVectorType.getElementType();
  Type elemBSourceType = sourceBVectorType.getElementType();
  Type elemDestType = destVectorType.getElementType();

  const uint32_t k = wmma.getK();

  if (isRDNA3 || isRDNA4)

  return std::nullopt;
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
        return op.emitOpError("negation unsupported on older than gfx942");
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);

    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
    if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

        !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
        (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
      return op.emitOpError(
          "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
          "be scaled as those fields are used for type information");

    StringRef intrinsicName =
        isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;

    bool allowBf16 = [&]() {
      return intrinsicName.contains("16x16x32.bf16") ||
             intrinsicName.contains("32x32x16.bf16");

    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
                          rewriter, loc, adaptor.getSourceA(), allowBf16),
                          rewriter, loc, adaptor.getSourceB(), allowBf16),
                      adaptor.getDestC()});
      auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;

    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
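    // Note: several MFMA intrinsics take bf16 operands and results as i16
    // vectors; intrinsicOutType and the final bitcast above undo that
    // reinterpretation so the op keeps its MLIR-level bf16 type.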
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        adaptor.getDestC()});
    loweredOp.addOperands(
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
        typeConverter->convertType<VectorType>(op.getDestD().getType());
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    bool isGFX1250 = chipset >= kGfx1250;

    auto aType = cast<VectorType>(adaptor.getSourceA().getType());
    auto bType = cast<VectorType>(adaptor.getSourceB().getType());
    auto destCType = cast<VectorType>(adaptor.getDestC().getType());
    bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
    bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
    bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
    bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
    VectorType rawOutType = outType;
      rawOutType = outType.clone(rewriter.getI16Type());
    Value a = adaptor.getSourceA();
      a = LLVM::BitcastOp::create(rewriter, loc,
                                  aType.clone(rewriter.getI16Type()), a);
    Value b = adaptor.getSourceB();
      b = LLVM::BitcastOp::create(rewriter, loc,
                                  bType.clone(rewriter.getI16Type()), b);
    Value destC = adaptor.getDestC();
      destC = LLVM::BitcastOp::create(
          rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
      return op.emitOpError("subwordOffset not supported on gfx12+");

    SmallVector<Value, 4> operands;
    SmallVector<NamedAttribute, 4> attrs;
                         op.getSourceA(), operands, attrs, "signA");
                         op.getSourceB(), operands, attrs, "signB");
                          op.getSubwordOffset(), op.getClamp(), operands,

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);
    loweredOp.addOperands(operands);
    loweredOp.addAttributes(attrs);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
    rewriter.replaceOp(op, maybeCastBack->getResults());
struct TransposeLoadOpLowering
  TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}

  matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
      return op.emitOpError("Non-gfx950 chipset not supported");

    Location loc = op.getLoc();
    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());

    size_t srcElementSize =
        srcMemRefType.getElementType().getIntOrFloatBitWidth();
    if (srcElementSize < 8)
      return op.emitOpError("Expect source memref to have at least 8 bits "
                            "element size, got ")

    auto resultType = cast<VectorType>(op.getResult().getType());
        (adaptor.getSrcIndices()));

    size_t numElements = resultType.getNumElements();
    size_t elementTypeSize =
        resultType.getElementType().getIntOrFloatBitWidth();

    Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
                                           rewriter.getIntegerType(32));
    Type llvmResultType = typeConverter->convertType(resultType);

    switch (elementTypeSize) {
    case 4: {
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 6: {
      assert(numElements == 16);
      auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 8: {
      assert(numElements == 8);
      auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
                                                    rocdlResultType, srcPtr);
      rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
      break;
    }
    case 16: {
      assert(numElements == 4);
      rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
      break;
    }
    default:
      return op.emitOpError("Unsupported element size for transpose load");
  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
      return op.emitOpError("pre-gfx9 and post-gfx10 not supported");

    Location loc = op.getLoc();

    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

    Type transferType = op.getTransferType();
    int loadWidth = [&]() -> int {
      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
        return (transferVectorType.getNumElements() *
                transferVectorType.getElementTypeBitWidth()) /

    if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
      return op.emitOpError("chipset unsupported element size");

    if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
      return op.emitOpError("Gather to LDS instructions with 12-byte and "
                            "16-byte load widths are only supported on gfx950");

        (adaptor.getSrcIndices()));
        (adaptor.getDstIndices()));

    rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
        op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
        rewriter.getI32IntegerAttr(0),
struct ExtPackedFp8OpLowering final
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),

  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct ScaledExtPackedMatrixOpLowering final
  ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),

  matchAndRewrite(ScaledExtPackedMatrixOp op,
                  ScaledExtPackedMatrixOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedTrunc2xFp8OpLowering final
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),

  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedStochRoundFp8OpLowering final
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),

  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct ScaledExtPackedOpLowering final
  ScaledExtPackedOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),

  matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedScaledTruncOpLowering final
  PackedScaledTruncOpLowering(const LLVMTypeConverter &converter,
      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),

  matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());

  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);

  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
  if (resultVecType) {
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
      rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
      rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
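  // The source is always widened to four bytes and bitcast to i32 because
  // the ROCDL conversion intrinsics consume a packed 32-bit register,
  // selecting the lane(s) to convert via their index operand.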
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
                    int32_t firstScaleByte) {
  assert(llvm::is_contained({16, 32}, blockSize));
  assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));

  const bool isFp8 = bitWidth == 8;
  const bool isBlock16 = blockSize == 16;

    int32_t bit0 = isBlock16;
    assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
    int32_t bit1 = (firstScaleByte == 2) << 1;
    assert(llvm::is_contained({0, 1}, scaleWaveHalf));
    int32_t bit2 = scaleWaveHalf << 2;
    return bit2 | bit1 | bit0;

    int32_t bit0 = isBlock16;
    assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
    int32_t bits2and1 = firstScaleByte << 1;
    assert(llvm::is_contained({0, 1}, scaleWaveHalf));
    int32_t bit3 = scaleWaveHalf << 3;
    int32_t bits = bit3 | bits2and1 | bit0;

    assert(!llvm::is_contained(
        {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
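// Worked example for the second encoding: blockSize = 32, scaleWaveHalf = 1,
// firstScaleByte = 2 yields bit0 = 0, bits2and1 = 0b100, bit3 = 0b1000,
// i.e. a scaleSel of 0b1100, which is not in the reserved-combination list
// asserted above.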
static std::optional<StringRef>
scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  if (isa<fp4>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
    return std::nullopt;
  }
  if (isa<fp8>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
    return std::nullopt;
  }
  if (isa<bf8>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
    return std::nullopt;
  }
  if (isa<fp6>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
    return std::nullopt;
  }
  if (isa<bf6>(srcElemType)) {
    if (destElemType.isF16())
      return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
    if (destElemType.isBF16())
      return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
    if (destElemType.isF32())
      return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
    return std::nullopt;
  }
  llvm_unreachable("invalid combination of element types for packed conversion "
LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
    ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  using fp4 = Float4E2M1FNType;
  using fp8 = Float8E4M3FNType;
  using bf8 = Float8E5M2Type;
  using fp6 = Float6E2M3FNType;
  using bf6 = Float6E3M2FNType;
  Location loc = op.getLoc();
    return rewriter.notifyMatchFailure(
        "Scaled fp packed conversion instructions are not available on target "
        "architecture and their emulation is not implemented");

  int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
  int32_t firstScaleByte = op.getFirstScaleByte();
  int32_t blockSize = op.getBlockSize();
  auto sourceType = cast<VectorType>(op.getSource().getType());
  auto srcElemType = cast<FloatType>(sourceType.getElementType());
  unsigned bitWidth = srcElemType.getWidth();

  auto targetType = cast<VectorType>(op.getResult().getType());
  auto destElemType = cast<FloatType>(targetType.getElementType());

  IntegerType i32 = rewriter.getI32Type();
  Value source = adaptor.getSource();
  Type llvmResultType = typeConverter->convertType(op.getResult().getType());
  Type packedType = nullptr;
  if (isa<fp4>(srcElemType)) {
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp8, bf8>(srcElemType)) {
    packedType = VectorType::get(2, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else if (isa<fp6, bf6>(srcElemType)) {
    packedType = VectorType::get(3, i32);
    packedType = getTypeConverter()->convertType(packedType);
  } else {
    llvm_unreachable("invalid element type for packed scaled ext");
  }

  if (!packedType || !llvmResultType) {
    return rewriter.notifyMatchFailure(op, "type conversion failed");
  }

  std::optional<StringRef> maybeIntrinsic =
      scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
  if (!maybeIntrinsic.has_value())
    return op.emitOpError(
        "no intrinsic matching packed scaled conversion on the given chipset");

      getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
      LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
  Value castedSource =
      LLVM::BitcastOp::create(rewriter, loc, packedType, source);

  OperationState loweredOp(loc, *maybeIntrinsic);
  loweredOp.addTypes({llvmResultType});
  loweredOp.addOperands({castedSource, castedScale});

  SmallVector<NamedAttribute, 1> attrs;
      NamedAttribute("scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
  loweredOp.addAttributes(attrs);
  Operation *lowered = rewriter.create(loweredOp);
  rewriter.replaceOp(op, lowered);
LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
    ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();

  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();
  VectorType destVecType = cast<VectorType>(op.getResult().getType());
  Type destElemType = destVecType.getElementType();

  VectorType packedVecType;
  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
    VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
  } else if (isa<Float4E2M1FNType>(sourceElemType)) {
    VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
    packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
  } else {
    llvm_unreachable("invalid element type for scaled ext");
  }

  if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
    Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
    if (!sourceVecType) {
      longVec = LLVM::InsertElementOp::create(
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
            LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);

  Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);

  if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF32())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
  else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.isBF16())
    rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
        op, destVecType, i32Source, scale, op.getIndex());
LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
    PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
    return rewriter.notifyMatchFailure(
        loc, "Scaled fp conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v2i16 = getTypeConverter()->convertType(
      VectorType::get(2, rewriter.getI16Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();

  VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
  Type sourceElemType = sourceVecType.getElementType();

  Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;

  Value source = adaptor.getSource();
  Value scale = adaptor.getScale();
  Value existing = adaptor.getExisting();
    existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
    existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);

  if (sourceVecType.getNumElements() < 2) {
    Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    VectorType v2 = VectorType::get(2, sourceElemType);
    source = LLVM::ZeroOp::create(rewriter, loc, v2);
    source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);

  Value sourceA, sourceB;
  if (sourceElemType.isF32()) {
    sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
    sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);

  if (sourceElemType.isF32() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E5M2Type>(resultElemType))
    result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float8E4M3FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isF32() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
                                                  existing, sourceA, sourceB,
                                                  scale, op.getIndex());
  else if (sourceElemType.isF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4F16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());
  else if (sourceElemType.isBF16() && isa<Float4E2M1FNType>(resultElemType))
    result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
        rewriter, loc, intResultType, existing, source, scale, op.getIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
    sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType());
  Value existing = adaptor.getExisting();
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

    result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());
    result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
                                          existing, op.getWordIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
    existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
    existing = LLVM::UndefOp::create(rewriter, loc, i32);

    result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());
    result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
                                          existing, op.getStoreIndex());

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
2039struct AMDGPUDPPLowering :
public ConvertOpToLLVMPattern<DPPOp> {
2040 AMDGPUDPPLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2041 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2045 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2046 ConversionPatternRewriter &rewriter)
const override {
2049 Location loc = DppOp.getLoc();
2050 Value src = adaptor.getSrc();
2051 Value old = adaptor.getOld();
2054 Type llvmType =
nullptr;
2056 llvmType = rewriter.getI32Type();
2057 }
else if (isa<FloatType>(srcType)) {
2059 ? rewriter.getF32Type()
2060 : rewriter.getF64Type();
2061 }
else if (isa<IntegerType>(srcType)) {
2063 ? rewriter.getI32Type()
2064 : rewriter.getI64Type();
2066 auto llvmSrcIntType = typeConverter->convertType(
2070 auto convertOperand = [&](Value operand, Type operandType) {
2071 if (operandType.getIntOrFloatBitWidth() <= 16) {
2072 if (llvm::isa<FloatType>(operandType)) {
2074 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2076 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2077 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2078 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2080 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2082 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2087 src = convertOperand(src, srcType);
2088 old = convertOperand(old, oldType);
2091 enum DppCtrl :
unsigned {
2100 ROW_HALF_MIRROR = 0x141,
2105 auto kind = DppOp.getKind();
2106 auto permArgument = DppOp.getPermArgument();
2107 uint32_t DppCtrl = 0;

    switch (kind) {
    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        uint32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }
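
    // Worked examples of the encoding above (derived from the cases, not new
    // behavior):
    //   quad_perm [1, 0, 3, 2] -> DppCtrl = 1 | (0 << 2) | (3 << 4) | (2 << 6)
    //                                     = 0xB1
    //   row_shl : 4            -> DppCtrl = 4 + ROW_SHL0 = 0x104
    //   row_mirror             -> DppCtrl = ROW_MIRROR = 0x140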

    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // Emit the ROCDL DPP update op with the assembled control word and masks.
    auto dppMovOp =
        ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
                                   rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
      }
    }

    rewriter.replaceOp(DppOp, result);
    return success();
  }
};

struct AMDGPUSwizzleBitModeLowering
    : public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);
    unsigned andMask = op.getAndMask();
    unsigned orMask = op.getOrMask();
    unsigned xorMask = op.getXorMask();

    // BitMode swizzle offset: the AND mask occupies bits [4:0], the OR mask
    // bits [9:5], and the XOR mask bits [14:10]; bit 15 is 0.
    unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
    Value maskValue = createI32Constant(rewriter, loc, mask);
    SmallVector<Value> swizzled;
    for (Value v : decomposed) {
      Value res =
          ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
      swizzled.emplace_back(res);
    }

    Value result = LLVM::composeValue(rewriter, loc, swizzled, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};
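
// Worked example (derived from the mask packing above): per the GCN BitMode
// swizzle semantics, each lane reads from ((lane & andMask) | orMask) ^
// xorMask. With andMask = 0x1F, orMask = 0, xorMask = 1, adjacent lanes are
// swapped, and the encoded offset is 0x1F | (0 << 5) | (1 << 10) = 0x41F.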

struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
  AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset < kGfx950)
      return op->emitOpError("permlane_swap is only supported on gfx950+");

    Location loc = op.getLoc();
    Type i32 = rewriter.getI32Type();
    Value src = adaptor.getSrc();
    unsigned rowLength = op.getRowLength();
    bool fi = op.getFetchInactive();
    bool boundctrl = op.getBoundCtrl();

    SmallVector<Value> decomposed =
        LLVM::decomposeValue(rewriter, loc, src, i32);

    SmallVector<Value> permuted;
    for (Value v : decomposed) {
      Value res;
      // The ROCDL permlane swap ops return a pair of i32 values.
      Type i32pair = LLVM::LLVMStructType::getLiteral(
          rewriter.getContext(), {v.getType(), v.getType()});

      if (rowLength == 16)
        res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else if (rowLength == 32)
        res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
                                              boundctrl);
      else
        llvm_unreachable("unsupported row length");

      Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
      Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});

      // One of the two returned halves matches this lane's original value;
      // keep the one that differs.
      Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
                                           LLVM::ICmpPredicate::eq, vdst0, v);
      Value vdstNew =
          LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
      permuted.emplace_back(vdstNew);
    }

    Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
    rewriter.replaceOp(op, result);
    return success();
  }
};
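
// Note: decomposeValue/composeValue make this pattern width-agnostic. For
// example (schematic), an f64 source is decomposed into two i32 words, each
// word is swapped through rocdl.permlane{16,32}.swap independently, and the
// two results are recomposed into an f64.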

struct AMDGPUMakeDmaBaseLowering
    : public ConvertOpToLLVMPattern<MakeDmaBaseOp> {
  AMDGPUMakeDmaBaseLowering(const LLVMTypeConverter &converter,
                            Chipset chipset)
      : ConvertOpToLLVMPattern<MakeDmaBaseOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(MakeDmaBaseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx1250)
      return op->emitOpError("make_dma_base is only supported on gfx1250");

    Location loc = op.getLoc();

    ValueRange ldsIndices = adaptor.getLdsIndices();
    Value lds = adaptor.getLds();
    auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
    Value ldsPtr = getStridedElementPtr(rewriter, loc, *getTypeConverter(),
                                        ldsMemRefType, lds, ldsIndices);

    ValueRange globalIndices = adaptor.getGlobalIndices();
    Value global = adaptor.getGlobal();
    auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
    Value globalPtr =
        getStridedElementPtr(rewriter, loc, *getTypeConverter(),
                             globalMemRefType, global, globalIndices);

    Type i32 = rewriter.getI32Type();
    Type i64 = rewriter.getI64Type();

    Value castForLdsAddr = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
    Value castForGlobalAddr =
        LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);

    // Split the 64-bit global address into 32-bit halves.
    Value lowHalf =
        LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
    Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
                                       createI64Constant(rewriter, loc, 32));
    Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
    // Mask the high half to its valid bits and tag it with the descriptor
    // type field. (The `mask` and `typeField` constants are elided in this
    // listing.)
    Value validHighHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
    Value highHalfPlusType =
        LLVM::OrOp::create(rewriter, loc, validHighHalf, typeField);

    Value c0 = createI32Constant(rewriter, loc, 0);
    Value c1 = createI32Constant(rewriter, loc, 1);
    Value c2 = createI32Constant(rewriter, loc, 2);
    Value c3 = createI32Constant(rewriter, loc, 3);

    Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed");
    Value result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
    result = LLVM::InsertElementOp::create(rewriter, loc, result, c1, c0);
    result = LLVM::InsertElementOp::create(rewriter, loc, result,
                                           castForLdsAddr, c1);
    result = LLVM::InsertElementOp::create(rewriter, loc, result, lowHalf, c2);
    result = LLVM::InsertElementOp::create(rewriter, loc, result,
                                           highHalfPlusType, c3);

    rewriter.replaceOp(op, result);
    return success();
  }
};
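
// As assembled above, the v4i32 "dgroup0" base holds a constant tag in
// element 0, the 32-bit LDS address in element 1, and the low and high halves
// of the 64-bit global address in elements 2 and 3, the high half masked and
// tagged with the descriptor type field. (Informal summary derived from the
// insertions above.)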

struct AMDGPUMakeDmaDescriptorLowering
    : public ConvertOpToLLVMPattern<MakeDmaDescriptorOp> {
  AMDGPUMakeDmaDescriptorLowering(const LLVMTypeConverter &converter,
                                  Chipset chipset)
      : ConvertOpToLLVMPattern<MakeDmaDescriptorOp>(converter),
        chipset(chipset) {}
  Chipset chipset;

  Value getDGroup0(OpAdaptor adaptor) const { return adaptor.getBase(); }

  // OR `value` into `accumulator` at descriptor bit `shift`. Offsets are
  // given relative to the full descriptor, so they are reduced modulo the
  // 32-bit register width before shifting.
  Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
                         Value accumulator, Value value, int64_t shift) const {
    shift %= 32;
    if (shift != 0) {
      Value shiftAmount = createI32Constant(rewriter, loc, shift);
      value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
    }
    return LLVM::OrOp::create(rewriter, loc, accumulator, value);
  }
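
  // Worked example (derived from the helper above): setValueAtOffset(...,
  // value, 48) targets descriptor bit 48; within the 32-bit register that
  // holds it, the value is shifted left by 48 % 32 = 16 before being OR'd in.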

  Value setWorkgroupMask(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0) const {
    Value mask = op.getWorkgroupMask();
    if (!mask)
      return sgpr0;

    Type i32 = rewriter.getI32Type();
    Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
    return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
  }

  Value setDataSize(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr0, ArrayRef<Value> consts) const {
    unsigned elementTypeWidthInBits = op.getElementTypeWidth();
    assert(
        llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidthInBits) &&
        "expected type width to be 8, 16, 32, or 64.");
    int64_t dataSize = llvm::Log2_32(elementTypeWidthInBits / 8);
    Value size = consts[dataSize];
    return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
  }
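
  // Worked example: 32-bit elements give dataSize = log2(32 / 8) = 2, so the
  // constant 2 is OR'd into sgpr0 at bit offset 16.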

  Value setAtomicBarrier(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
    bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() != nullptr;
    if (!atomic_barrier_enable)
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
  }

  Value setIterateEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter, Location loc,
                         Value sgpr0, ArrayRef<Value> consts) const {
    bool iterate_enable = adaptor.getGlobalIncrement() != nullptr;
    if (!iterate_enable)
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
  }

  Value setPadEnable(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
    bool pad_enable = op.getPadAmount() != nullptr;
    if (!pad_enable)
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
  }

  Value setEarlyTimeout(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter, Location loc,
                        Value sgpr0, ArrayRef<Value> consts) const {
    if (!op.getWorkgroupMask())
      return sgpr0;

    return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
  }

  Value setPadInterval(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter, Location loc,
                       Value sgpr0, ArrayRef<Value> consts) const {
    bool pad_enable = op.getPadAmount() != nullptr;
    if (!pad_enable)
      return sgpr0;

    IntegerType i32 = rewriter.getI32Type();
    Value padInterval = adaptor.getPadInterval();
    // The interval is a power of two; counting trailing zeros yields its
    // log2, and the field stores that value minus one.
    padInterval = LLVM::CountTrailingZerosOp::create(
        rewriter, loc, i32, padInterval, /*is_zero_poison=*/false);
    padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
    return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
  }

  Value setPadAmount(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                     ConversionPatternRewriter &rewriter, Location loc,
                     Value sgpr0, ArrayRef<Value> consts) const {
    bool pad_enable = op.getPadAmount() != nullptr;
    if (!pad_enable)
      return sgpr0;

    // The field is biased by one: it stores padAmount - 1.
    Value padAmount = adaptor.getPadAmount();
    padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
    return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
  }

  Value setAtomicBarrierAddress(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter,
                                Location loc, Value sgpr1,
                                ArrayRef<Value> consts) const {
    bool atomic_barrier_enable = adaptor.getAtomicBarrierAddress() != nullptr;
    if (!atomic_barrier_enable)
      return sgpr1;

    Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
    auto barrierAddressTy =
        cast<MemRefType>(op.getAtomicBarrierAddress().getType());
    ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
    atomicBarrierAddress =
        getStridedElementPtr(rewriter, loc, *getTypeConverter(),
                             barrierAddressTy, atomicBarrierAddress,
                             atomicBarrierIndices);
    IntegerType i32 = rewriter.getI32Type();
    // The address is shifted right by three bits and masked to the field's
    // width. (The `mask` constant is elided in this listing.)
    atomicBarrierAddress =
        LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
    atomicBarrierAddress =
        LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
    atomicBarrierAddress =
        LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
    return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
  }

  std::pair<Value, Value> setTensorDim0(MakeDmaDescriptorOp op,
                                        OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr1, Value sgpr2,
                                        ArrayRef<Value> consts) const {
    SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
    OpFoldResult tensorDim0OpFoldResult = mixedGlobalSizes.back();
    Value tensorDim0;
    if (auto attr = dyn_cast<Attribute>(tensorDim0OpFoldResult))
      tensorDim0 =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    else
      tensorDim0 = cast<Value>(tensorDim0OpFoldResult);

    // Dim 0 straddles sgpr1 and sgpr2; emit the two 16-bit halves.
    Value c16 = createI32Constant(rewriter, loc, 16);
    Value tensorDim0High = LLVM::LShrOp::create(rewriter, loc, tensorDim0, c16);
    sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDim0, 48);
    sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim0High, 48 + 16);
    return {sgpr1, sgpr2};
  }
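
  // Worked example (derived from the offsets above): for a dim-0 size of
  // 0x12345, bits [15:0] (0x2345) land at descriptor bit 48 (sgpr1, shifted
  // left by 16) and bits [31:16] (0x1) land at bit 64 (sgpr2, no shift).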

  std::pair<Value, Value> setTensorDim1(MakeDmaDescriptorOp op,
                                        OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        Location loc, Value sgpr2, Value sgpr3,
                                        ArrayRef<Value> consts) const {
    SmallVector<OpFoldResult> mixedGlobalSizes = op.getMixedGlobalSizes();
    OpFoldResult tensorDim1OpFoldResult = *(mixedGlobalSizes.rbegin() + 1);
    Value tensorDim1;
    if (auto attr = dyn_cast<Attribute>(tensorDim1OpFoldResult))
      tensorDim1 =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    else
      tensorDim1 = cast<Value>(tensorDim1OpFoldResult);

    Value c16 = createI32Constant(rewriter, loc, 16);
    Value tensorDim1High = LLVM::LShrOp::create(rewriter, loc, tensorDim1, c16);
    sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDim1, 80);
    sgpr3 = setValueAtOffset(rewriter, loc, sgpr3, tensorDim1High, 80 + 16);
    return {sgpr2, sgpr3};
  }

  Value setTileDimX(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr, ArrayRef<Value> consts, size_t dimX,
                    int64_t offset) const {
    SmallVector<OpFoldResult> mixedSharedSizes = op.getMixedSharedSizes();
    // Dimensions that are not present leave the register unchanged.
    if (mixedSharedSizes.size() <= dimX)
      return sgpr;

    OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
    Value tileDimX;
    if (auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult))
      tileDimX =
          createI32Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    else
      tileDimX = cast<Value>(tileDimXOpFoldResult);

    return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
  }

  Value setTileDim0(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr3, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
  }

  Value setTileDim1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
  }

  Value setTileDim2(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                    ConversionPatternRewriter &rewriter, Location loc,
                    Value sgpr4, ArrayRef<Value> consts) const {
    return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
  }

  std::pair<Value, Value>
  setTensorDimXStride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgprY, Value sgprZ, ArrayRef<Value> consts,
                      size_t dimX, int64_t offset) const {
    SmallVector<OpFoldResult> mixedGlobalStrides = op.getMixedGlobalStrides();
    if (mixedGlobalStrides.size() <= dimX)
      return {sgprY, sgprZ};

    OpFoldResult tensorDimXStrideOpFoldResult =
        *(mixedGlobalStrides.rbegin() + dimX);
    Value tensorDimXStride;
    if (auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
      tensorDimXStride =
          createI64Constant(rewriter, loc, cast<IntegerAttr>(attr).getInt());
    else
      tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);

    // Strides are 48-bit fields split across two 32-bit registers.
    constexpr int64_t first48bits = (1ll << 48) - 1;
    Value mask = createI64Constant(rewriter, loc, first48bits);
    tensorDimXStride =
        LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
    IntegerType i32 = rewriter.getI32Type();
    Value tensorDimXStrideLow =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);

    int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
    Value shiftVal = createI64Constant(rewriter, loc, shift);
    Value tensorDimXStrideHigh =
        LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
    tensorDimXStrideHigh =
        LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);

    sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
    // The high bits land immediately after the low bits; `offset + shift` is
    // reconstructed from the field layout (the exact expression is elided in
    // this listing).
    sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
                             offset + shift);
    return {sgprY, sgprZ};
  }
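
  // Worked example: when `offset` is 32-bit aligned, the low 32 bits of the
  // stride fill sgprY and bits [47:32] start sgprZ; when it is 16-bit
  // aligned, the low 16 bits top off sgprY and bits [47:16] fill sgprZ.
  // (Derived from the shift/offset arithmetic above.)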

  std::pair<Value, Value>
  setTensorDim0Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
    // The dimX/offset arguments below are reconstructed from the field
    // layout; they are elided in this listing.
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6,
                               consts, /*dimX=*/0, /*offset=*/160);
  }

  std::pair<Value, Value>
  setTensorDim1Stride(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter, Location loc,
                      Value sgpr5, Value sgpr6, ArrayRef<Value> consts) const {
    return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6,
                               consts, /*dimX=*/1, /*offset=*/208);
  }

  Value getDGroup1(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter, Location loc,
                   ArrayRef<Value> consts) const {
    SmallVector<Value, 8> sgprs(8);
    for (int64_t i = 0; i < 8; i++) {
      sgprs[i] = consts[0];
    }

    sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
    sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
    sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);

    sgprs[1] =
        setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
    std::tie(sgprs[1], sgprs[2]) =
        setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
    std::tie(sgprs[2], sgprs[3]) =
        setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);

    sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
    sgprs[4] = setTileDim1(op, adaptor, rewriter, loc, sgprs[4], consts);
    sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
    std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
        op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
    std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
        op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);

    IntegerType i32 = rewriter.getI32Type();
    Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
    assert(v8i32 && "expected type conversion to succeed");
    Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
    for (auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
      dgroup1 =
          LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
    }
    return dgroup1;
  }
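
  // Informal summary of the dgroup1 layout implied by the offsets used above
  // (sgpr N covers descriptor bits [32N+31 : 32N]); field widths are inferred
  // and may be approximate:
  //   [15:0] workgroup mask, [17:16] data size, [18] atomic-barrier enable,
  //   [19] iterate enable, [20] pad enable, [21] early timeout,
  //   [24:22] pad interval, [27:25] pad amount, [63:32] barrier address,
  //   [79:48] tensor dim 0, [111:80] tensor dim 1, [127:112] tile dim 0,
  //   [143:128] tile dim 1, [159:144] tile dim 2, [207:160] dim-0 stride,
  //   [255:208] dim-1 stride.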

  LogicalResult
  matchAndRewrite(MakeDmaDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (chipset != kGfx1250)
      return op->emitOpError(
          "make_dma_descriptor is only supported on gfx1250");

    if (op.getRank() > 2)
      return op->emitOpError("unimplemented");

    Location loc = op.getLoc();

    IntegerType i32 = rewriter.getI32Type();
    [[maybe_unused]] Type v4i32 =
        this->typeConverter->convertType(VectorType::get(4, i32));
    assert(v4i32 && "expected type conversion to succeed");

    SmallVector<Value> consts;
    for (int64_t i = 0; i < 8; i++)
      consts.push_back(createI32Constant(rewriter, loc, i));

    Value dgroup0 = this->getDGroup0(adaptor);
    Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);

    SmallVector<Value> results = {dgroup0, dgroup1};
    rewriter.replaceOpWithMultiple(op, {results});
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    LLVMTypeConverter converter(ctx);
    converter.addConversion([&](TDMBaseType type) -> Type {
      Type i32 = IntegerType::get(type.getContext(), 32);
      return converter.convertType(VectorType::get(4, i32));
    });

    RewritePatternSet patterns(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(ctx);
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

void mlir::populateAMDGPUMemorySpaceAttributeConversions(
    TypeConverter &typeConverter) {
  typeConverter.addTypeAttributeConversion(
      [](BaseMemRefType type, amdgpu::AddressSpaceAttr as)
          -> TypeConverter::AttributeConversionResult {
        MLIRContext *ctx = as.getContext();
        Type i64 = IntegerType::get(ctx, 64);
        switch (as.getValue()) {
        case amdgpu::AddressSpace::FatRawBuffer:
          return IntegerAttr::get(i64, 7);
        case amdgpu::AddressSpace::BufferRsrc:
          return IntegerAttr::get(i64, 8);
        case amdgpu::AddressSpace::FatStructuredBuffer:
          return IntegerAttr::get(i64, 9);
        }
        return TypeConverter::AttributeConversionResult::abort();
      });
}

void mlir::populateAMDGPUToROCDLConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns,
    amdgpu::Chipset chipset) {
  populateAMDGPUMemorySpaceAttributeConversions(converter);
  patterns.add<
      FatRawBufferCastLowering,
      RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
      RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
      RawBufferOpLowering<RawBufferAtomicFaddOp,
                          ROCDL::RawPtrBufferAtomicFaddOp>,
      RawBufferOpLowering<RawBufferAtomicFmaxOp,
                          ROCDL::RawPtrBufferAtomicFmaxOp>,
      RawBufferOpLowering<RawBufferAtomicSmaxOp,
                          ROCDL::RawPtrBufferAtomicSmaxOp>,
      RawBufferOpLowering<RawBufferAtomicUminOp,
                          ROCDL::RawPtrBufferAtomicUminOp>,
      RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                          ROCDL::RawPtrBufferAtomicCmpSwap>,
      AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
      SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
      WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
      ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
      PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
      GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
      AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaDescriptorLowering>(converter,
                                                                  chipset);
  patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
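
// Example usage (illustrative): these patterns are normally exercised through
// the conversion pass, e.g.
//   mlir-opt input.mlir --convert-amdgpu-to-rocdl=chipset=gfx942
// or from C++ by calling populateAMDGPUToROCDLConversionPatterns() with a
// parsed Chipset before running applyPartialConversion().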