26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/Casting.h"
29#include "llvm/Support/ErrorHandling.h"
33#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
34#include "mlir/Conversion/Passes.h.inc"
49 IntegerType i32 = rewriter.getI32Type();
51 auto valTy = cast<IntegerType>(val.
getType());
54 return valTy.getWidth() > 32
55 ?
Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
56 :
Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
61 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
67 IntegerType i64 = rewriter.getI64Type();
69 auto valTy = cast<IntegerType>(val.
getType());
72 return valTy.getWidth() > 64
73 ?
Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
74 :
Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
79 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
84 Type llvmI1 = rewriter.getI1Type();
85 return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
92 IntegerType i32 = rewriter.getI32Type();
94 for (
auto [i, increment, stride] : llvm::enumerate(
indices, strides)) {
97 ShapedType::isDynamic(stride)
99 memRefDescriptor.
stride(rewriter, loc, i))
100 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
101 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
113 MemRefType memrefType,
117 if (memrefType.hasStaticShape() &&
118 !llvm::any_of(strides, ShapedType::isDynamic)) {
119 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
121 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
122 size = std::max(
shape[i] * strides[i], size);
123 size = size * elementByteWidth;
127 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
128 Value size = memrefDescriptor.
size(rewriter, loc, i);
129 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
130 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
132 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
137 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
143 Value cacheSwizzleStride =
nullptr,
144 unsigned addressSpace = 8) {
148 Type i16 = rewriter.getI16Type();
151 Value cacheStrideZext =
152 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
153 Value swizzleBit = LLVM::ConstantOp::create(
154 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
155 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
158 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
159 rewriter.getI16IntegerAttr(0));
176 uint32_t flags = (7 << 12) | (4 << 15);
179 uint32_t oob = boundsCheck ? 3 : 2;
180 flags |= (oob << 28);
184 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
185 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
186 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
191struct FatRawBufferCastLowering
193 FatRawBufferCastLowering(
const LLVMTypeConverter &converter, Chipset chipset)
194 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
200 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
201 ConversionPatternRewriter &rewriter)
const override {
202 Location loc = op.getLoc();
203 Value memRef = adaptor.getSource();
204 Value unconvertedMemref = op.getSource();
205 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
206 MemRefDescriptor descriptor(memRef);
208 DataLayout dataLayout = DataLayout::closest(op);
209 int64_t elementByteWidth =
212 int64_t unusedOffset = 0;
213 SmallVector<int64_t, 5> strideVals;
214 if (
failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
215 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
217 Value numRecords = adaptor.getValidBytes();
219 numRecords =
getNumRecords(rewriter, loc, memrefType, descriptor,
220 strideVals, elementByteWidth);
223 adaptor.getResetOffset()
224 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
226 : descriptor.alignedPtr(rewriter, loc);
228 Value offset = adaptor.getResetOffset()
229 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
230 rewriter.getIndexAttr(0))
231 : descriptor.offset(rewriter, loc);
233 bool hasSizes = memrefType.getRank() > 0;
236 Value sizes = hasSizes
237 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
241 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
246 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
247 chipset, adaptor.getCacheSwizzleStride(), 7);
249 Value
result = MemRefDescriptor::poison(
251 getTypeConverter()->convertType(op.getResult().getType()));
253 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr, pos);
254 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr,
256 result = LLVM::InsertValueOp::create(rewriter, loc,
result, offset,
259 result = LLVM::InsertValueOp::create(rewriter, loc,
result, sizes,
261 result = LLVM::InsertValueOp::create(rewriter, loc,
result, strides,
264 rewriter.replaceOp(op,
result);
270template <
typename GpuOp,
typename Intrinsic>
272 RawBufferOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
273 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
276 static constexpr uint32_t maxVectorOpWidth = 128;
279 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
280 ConversionPatternRewriter &rewriter)
const override {
281 Location loc = gpuOp.getLoc();
282 Value memref = adaptor.getMemref();
283 Value unconvertedMemref = gpuOp.getMemref();
284 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
286 if (chipset.majorVersion < 9)
287 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
289 Value storeData = adaptor.getODSOperands(0)[0];
290 if (storeData == memref)
294 wantedDataType = storeData.
getType();
296 wantedDataType = gpuOp.getODSResults(0)[0].getType();
298 Value atomicCmpData = Value();
301 Value maybeCmpData = adaptor.getODSOperands(1)[0];
302 if (maybeCmpData != memref)
303 atomicCmpData = maybeCmpData;
306 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
308 Type i32 = rewriter.getI32Type();
311 DataLayout dataLayout = DataLayout::closest(gpuOp);
312 int64_t elementByteWidth =
321 Type llvmBufferValType = llvmWantedDataType;
323 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
324 llvmBufferValType = this->getTypeConverter()->convertType(
325 rewriter.getIntegerType(floatType.getWidth()));
327 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
328 uint32_t vecLen = dataVector.getNumElements();
331 uint32_t totalBits = elemBits * vecLen;
333 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
334 if (totalBits > maxVectorOpWidth)
335 return gpuOp.emitOpError(
336 "Total width of loads or stores must be no more than " +
337 Twine(maxVectorOpWidth) +
" bits, but we call for " +
339 " bits. This should've been caught in validation");
340 if (!usePackedFp16 && elemBits < 32) {
341 if (totalBits > 32) {
342 if (totalBits % 32 != 0)
343 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
344 "doesn't fit into words. Can't happen\n");
345 llvmBufferValType = this->typeConverter->convertType(
346 VectorType::get(totalBits / 32, i32));
348 llvmBufferValType = this->typeConverter->convertType(
349 rewriter.getIntegerType(totalBits));
353 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
356 if (vecType.getNumElements() == 1)
357 llvmBufferValType = vecType.getElementType();
360 SmallVector<Value, 6> args;
362 if (llvmBufferValType != llvmWantedDataType) {
363 Value castForStore = LLVM::BitcastOp::create(
364 rewriter, loc, llvmBufferValType, storeData);
365 args.push_back(castForStore);
367 args.push_back(storeData);
372 if (llvmBufferValType != llvmWantedDataType) {
373 Value castForCmp = LLVM::BitcastOp::create(
374 rewriter, loc, llvmBufferValType, atomicCmpData);
375 args.push_back(castForCmp);
377 args.push_back(atomicCmpData);
383 SmallVector<int64_t, 5> strides;
384 if (
failed(memrefType.getStridesAndOffset(strides, offset)))
385 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
390 rewriter, loc, *this->getTypeConverter(), memrefType);
392 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
395 args.push_back(resource);
399 adaptor.getIndices(), strides);
400 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
401 indexOffset && *indexOffset > 0) {
403 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
407 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
408 args.push_back(voffset);
411 Value sgprOffset = adaptor.getSgprOffset();
414 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
415 args.push_back(sgprOffset);
424 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
428 if (llvmBufferValType != llvmWantedDataType) {
429 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
434 rewriter.eraseOp(gpuOp);
451static FailureOr<unsigned> encodeWaitcnt(
Chipset chipset,
unsigned vmcnt,
452 unsigned expcnt,
unsigned lgkmcnt) {
454 vmcnt = std::min(15u, vmcnt);
455 expcnt = std::min(7u, expcnt);
456 lgkmcnt = std::min(15u, lgkmcnt);
457 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
460 vmcnt = std::min(63u, vmcnt);
461 expcnt = std::min(7u, expcnt);
462 lgkmcnt = std::min(15u, lgkmcnt);
463 unsigned lowBits = vmcnt & 0xF;
464 unsigned highBits = (vmcnt >> 4) << 14;
465 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
466 return lowBits | highBits | otherCnts;
469 vmcnt = std::min(63u, vmcnt);
470 expcnt = std::min(7u, expcnt);
471 lgkmcnt = std::min(63u, lgkmcnt);
472 unsigned lowBits = vmcnt & 0xF;
473 unsigned highBits = (vmcnt >> 4) << 14;
474 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
475 return lowBits | highBits | otherCnts;
478 vmcnt = std::min(63u, vmcnt);
479 expcnt = std::min(7u, expcnt);
480 lgkmcnt = std::min(63u, lgkmcnt);
481 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
486struct MemoryCounterWaitOpLowering
496 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
497 ConversionPatternRewriter &rewriter)
const override {
498 if (chipset.majorVersion >= 12) {
499 Location loc = op.getLoc();
500 if (std::optional<int> ds = adaptor.getDs())
501 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
503 if (std::optional<int>
load = adaptor.getLoad())
504 ROCDL::WaitLoadcntOp::create(rewriter, loc, *
load);
506 if (std::optional<int> store = adaptor.getStore())
507 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
509 if (std::optional<int> exp = adaptor.getExp())
510 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
512 rewriter.eraseOp(op);
516 auto getVal = [](Attribute attr) ->
unsigned {
518 return cast<IntegerAttr>(attr).getInt();
523 unsigned ds = getVal(adaptor.getDsAttr());
524 unsigned exp = getVal(adaptor.getExpAttr());
526 unsigned vmcnt = 1024;
527 Attribute
load = adaptor.getLoadAttr();
528 Attribute store = adaptor.getStoreAttr();
530 vmcnt = getVal(
load) + getVal(store);
532 vmcnt = getVal(
load);
534 vmcnt = getVal(store);
537 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
539 return op.emitOpError(
"unsupported chipset");
541 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
547 LDSBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
548 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
553 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
554 ConversionPatternRewriter &rewriter)
const override {
555 Location loc = op.getLoc();
558 bool requiresInlineAsm = chipset <
kGfx90a;
561 rewriter.getAttr<LLVM::MMRATagAttr>(
"amdgpu-synchronize-as",
"local");
570 StringRef scope =
"workgroup";
572 auto relFence = LLVM::FenceOp::create(rewriter, loc,
573 LLVM::AtomicOrdering::release, scope);
574 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
575 if (requiresInlineAsm) {
576 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
577 LLVM::AsmDialect::AD_ATT);
578 const char *asmStr =
";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
579 const char *constraints =
"";
580 LLVM::InlineAsmOp::create(
583 asmStr, constraints,
true,
584 false, LLVM::TailCallKind::None,
587 }
else if (chipset.majorVersion < 12) {
588 ROCDL::SBarrierOp::create(rewriter, loc);
590 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
591 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
594 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
595 LLVM::AtomicOrdering::acquire, scope);
596 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
597 rewriter.replaceOp(op, acqFence);
603 SchedBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
604 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
609 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
610 ConversionPatternRewriter &rewriter)
const override {
611 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
612 (uint32_t)op.getOpts());
636 bool allowBf16 =
true) {
638 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
639 if (vectorType.getElementType().isBF16() && !allowBf16)
640 return LLVM::BitcastOp::create(
641 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
642 if (vectorType.getElementType().isInteger(8) &&
643 vectorType.getNumElements() <= 8)
644 return LLVM::BitcastOp::create(
646 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
647 if (isa<IntegerType>(vectorType.getElementType()) &&
648 vectorType.getElementTypeBitWidth() <= 8) {
649 int64_t numWords = llvm::divideCeil(
650 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
652 return LLVM::BitcastOp::create(
653 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
673 Type outputType = rewriter.getI32Type();
674 if (
auto intType = dyn_cast<IntegerType>(inputType))
675 return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
676 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
690 bool isUnsigned,
Value llvmInput,
694 auto vectorType = dyn_cast<VectorType>(inputType);
696 operands.push_back(llvmInput);
699 Type elemType = vectorType.getElementType();
702 llvmInput = LLVM::BitcastOp::create(
703 rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
705 operands.push_back(llvmInput);
712 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
713 bool isInputInteger = mlirInputType.getElementType().isInteger();
714 if (isInputInteger) {
716 bool localIsUnsigned = isUnsigned;
718 localIsUnsigned =
true;
720 localIsUnsigned =
false;
723 operands.push_back(sign);
728 Type i32 = rewriter.getI32Type();
729 Type intrinsicInType = numBits <= 32
730 ? (
Type)rewriter.getIntegerType(numBits)
731 : (
Type)VectorType::get(numBits / 32, i32);
732 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
733 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
734 loc, llvmIntrinsicInType, llvmInput);
739 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
740 operands.push_back(castInput);
753 Value output, int32_t subwordOffset,
756 auto vectorType = dyn_cast<VectorType>(inputType);
757 Type elemType = vectorType.getElementType();
759 output = LLVM::BitcastOp::create(
760 rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
761 operands.push_back(output);
772 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
773 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
779 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
780 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
788 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
789 b = mfma.getBlocks();
794 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
795 if (m == 32 && n == 32 && k == 4 &&
b == 1)
796 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
797 if (m == 16 && n == 16 && k == 8 &&
b == 1)
798 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
800 if (m == 32 && n == 32 && k == 1 &&
b == 2)
801 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
802 if (m == 16 && n == 16 && k == 1 &&
b == 4)
803 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
804 if (m == 4 && n == 4 && k == 1 &&
b == 16)
805 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
806 if (m == 32 && n == 32 && k == 2 &&
b == 1)
807 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
808 if (m == 16 && n == 16 && k == 4 &&
b == 1)
809 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
814 if (m == 32 && n == 32 && k == 16 &&
b == 1)
815 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
816 if (m == 16 && n == 16 && k == 32 &&
b == 1)
817 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
819 if (m == 32 && n == 32 && k == 4 &&
b == 2)
820 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
821 if (m == 16 && n == 16 && k == 4 &&
b == 4)
822 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
823 if (m == 4 && n == 4 && k == 4 &&
b == 16)
824 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
825 if (m == 32 && n == 32 && k == 8 &&
b == 1)
826 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
827 if (m == 16 && n == 16 && k == 16 &&
b == 1)
828 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
833 if (m == 32 && n == 32 && k == 16 &&
b == 1)
834 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
835 if (m == 16 && n == 16 && k == 32 &&
b == 1)
836 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
839 if (m == 32 && n == 32 && k == 4 &&
b == 2)
840 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
841 if (m == 16 && n == 16 && k == 4 &&
b == 4)
842 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
843 if (m == 4 && n == 4 && k == 4 &&
b == 16)
844 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
845 if (m == 32 && n == 32 && k == 8 &&
b == 1)
846 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
847 if (m == 16 && n == 16 && k == 16 &&
b == 1)
848 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
850 if (m == 32 && n == 32 && k == 2 &&
b == 2)
851 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
852 if (m == 16 && n == 16 && k == 2 &&
b == 4)
853 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
854 if (m == 4 && n == 4 && k == 2 &&
b == 16)
855 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
856 if (m == 32 && n == 32 && k == 4 &&
b == 1)
857 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
858 if (m == 16 && n == 16 && k == 8 &&
b == 1)
859 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
864 if (m == 32 && n == 32 && k == 32 &&
b == 1)
865 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
866 if (m == 16 && n == 16 && k == 64 &&
b == 1)
867 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
869 if (m == 32 && n == 32 && k == 4 &&
b == 2)
870 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
871 if (m == 16 && n == 16 && k == 4 &&
b == 4)
872 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
873 if (m == 4 && n == 4 && k == 4 &&
b == 16)
874 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
875 if (m == 32 && n == 32 && k == 8 &&
b == 1)
876 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
877 if (m == 16 && n == 16 && k == 16 &&
b == 1)
878 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
879 if (m == 32 && n == 32 && k == 16 &&
b == 1 && chipset >=
kGfx942)
880 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
881 if (m == 16 && n == 16 && k == 32 &&
b == 1 && chipset >=
kGfx942)
882 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
886 if (m == 16 && n == 16 && k == 4 &&
b == 1)
887 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
888 if (m == 4 && n == 4 && k == 4 &&
b == 4)
889 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
896 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
897 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
899 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
901 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
903 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
905 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
907 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
913 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
914 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
916 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
918 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
920 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
922 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
924 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
933 .Case([](Float8E4M3FNType) {
return 0u; })
934 .Case([](Float8E5M2Type) {
return 1u; })
935 .Case([](Float6E2M3FNType) {
return 2u; })
936 .Case([](Float6E3M2FNType) {
return 3u; })
937 .Case([](Float4E2M1FNType) {
return 4u; })
938 .Default(std::nullopt);
948static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
950 uint32_t n, uint32_t k, uint32_t
b,
Chipset chipset) {
957 if (!isa<Float32Type>(destType))
962 if (!aTypeCode || !bTypeCode)
965 if (m == 32 && n == 32 && k == 64 &&
b == 1)
966 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
967 *aTypeCode, *bTypeCode};
968 if (m == 16 && n == 16 && k == 128 &&
b == 1)
970 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
976static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
979 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
980 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
981 mfma.getBlocks(), chipset);
984static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
987 smfma.getSourceB().getType(),
988 smfma.getDestC().getType(), smfma.getM(),
989 smfma.getN(), smfma.getK(), 1u, chipset);
994static std::optional<StringRef>
996 Type elemDestType, uint32_t k,
bool isRDNA3) {
997 using fp8 = Float8E4M3FNType;
998 using bf8 = Float8E5M2Type;
1003 if (elemSourceType.
isF16() && elemDestType.
isF32())
1004 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1005 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1006 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1007 if (elemSourceType.
isF16() && elemDestType.
isF16())
1008 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1010 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1012 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1017 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1018 return std::nullopt;
1022 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1023 elemDestType.
isF32())
1024 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1025 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1026 elemDestType.
isF32())
1027 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1028 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1029 elemDestType.
isF32())
1030 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1031 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1032 elemDestType.
isF32())
1033 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1035 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1037 return std::nullopt;
1041 if (k == 32 && !isRDNA3) {
1043 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1046 return std::nullopt;
1052 Type elemBSourceType,
1055 using fp8 = Float8E4M3FNType;
1056 using bf8 = Float8E5M2Type;
1059 if (elemSourceType.
isF32() && elemDestType.
isF32())
1060 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1062 return std::nullopt;
1066 if (elemSourceType.
isF16() && elemDestType.
isF32())
1067 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1068 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1069 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1070 if (elemSourceType.
isF16() && elemDestType.
isF16())
1071 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1073 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1075 return std::nullopt;
1079 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1080 if (elemDestType.
isF32())
1081 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1082 if (elemDestType.
isF16())
1083 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1085 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1086 if (elemDestType.
isF32())
1087 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1088 if (elemDestType.
isF16())
1089 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1091 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1092 if (elemDestType.
isF32())
1093 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1094 if (elemDestType.
isF16())
1095 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1097 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1098 if (elemDestType.
isF32())
1099 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1100 if (elemDestType.
isF16())
1101 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1104 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1106 return std::nullopt;
1110 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1111 if (elemDestType.
isF32())
1112 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1113 if (elemDestType.
isF16())
1114 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1116 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1117 if (elemDestType.
isF32())
1118 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1119 if (elemDestType.
isF16())
1120 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1122 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1123 if (elemDestType.
isF32())
1124 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1125 if (elemDestType.
isF16())
1126 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1128 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1129 if (elemDestType.
isF32())
1130 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1131 if (elemDestType.
isF16())
1132 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1135 return std::nullopt;
1138 return std::nullopt;
1146 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1147 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1148 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1149 Type elemSourceType = sourceVectorType.getElementType();
1150 Type elemBSourceType = sourceBVectorType.getElementType();
1151 Type elemDestType = destVectorType.getElementType();
1153 const uint32_t k = wmma.getK();
1158 if (isRDNA3 || isRDNA4)
1163 if (chipset ==
Chipset{12, 5, 0})
1167 return std::nullopt;
1172 MFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1173 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1178 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1179 ConversionPatternRewriter &rewriter)
const override {
1180 Location loc = op.getLoc();
1181 Type outType = typeConverter->convertType(op.getDestD().getType());
1182 Type intrinsicOutType = outType;
1183 if (
auto outVecType = dyn_cast<VectorType>(outType))
1184 if (outVecType.getElementType().isBF16())
1185 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1187 if (chipset.majorVersion != 9 || chipset <
kGfx908)
1188 return op->emitOpError(
"MFMA only supported on gfx908+");
1189 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
1190 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1192 return op.emitOpError(
"negation unsupported on older than gfx942");
1194 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1197 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1199 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1200 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
1203 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1205 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1206 return op.emitOpError(
1207 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1208 "be scaled as those fields are used for type information");
1211 StringRef intrinsicName =
1212 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1215 bool allowBf16 = [&]() {
1220 return intrinsicName.contains(
"16x16x32.bf16") ||
1221 intrinsicName.contains(
"32x32x16.bf16");
1223 OperationState loweredOp(loc, intrinsicName);
1224 loweredOp.addTypes(intrinsicOutType);
1226 rewriter, loc, adaptor.getSourceA(), allowBf16),
1228 rewriter, loc, adaptor.getSourceB(), allowBf16),
1229 adaptor.getDestC()});
1232 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1242 Value lowered = rewriter.create(loweredOp)->getResult(0);
1243 if (outType != intrinsicOutType)
1244 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1245 rewriter.replaceOp(op, lowered);
1251 ScaledMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1252 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1257 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1258 ConversionPatternRewriter &rewriter)
const override {
1259 Location loc = op.getLoc();
1260 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1262 if (chipset.majorVersion != 9 || chipset <
kGfx950)
1263 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1264 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1266 if (!maybeScaledIntrinsic.has_value())
1267 return op.emitOpError(
1268 "no intrinsic matching scaled MFMA size on given chipset");
1270 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1271 OperationState loweredOp(loc, intrinsicName);
1272 loweredOp.addTypes(intrinsicOutType);
1273 loweredOp.addOperands(
1276 adaptor.getDestC()});
1281 loweredOp.addOperands(
1290 Value lowered = rewriter.create(loweredOp)->getResult(0);
1291 rewriter.replaceOp(op, lowered);
1297 WMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1298 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1303 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1304 ConversionPatternRewriter &rewriter)
const override {
1305 Location loc = op.getLoc();
1307 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1309 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1311 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1312 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1316 VectorType rawOutType = outType;
1317 if (outType.getElementType().
isBF16())
1318 rawOutType = outType.clone(rewriter.getI16Type());
1322 if (!maybeIntrinsic.has_value())
1323 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
1325 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1326 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1328 OperationState loweredOp(loc, *maybeIntrinsic);
1329 loweredOp.addTypes(rawOutType);
1331 SmallVector<Value, 4> operands;
1333 adaptor.getSourceA(), op.getSourceA(), operands);
1335 adaptor.getSourceB(), op.getSourceB(), operands);
1337 op.getSubwordOffset(), op.getClamp(), operands);
1339 loweredOp.addOperands(operands);
1340 Operation *lowered = rewriter.create(loweredOp);
1342 Operation *maybeCastBack = lowered;
1343 if (rawOutType != outType)
1344 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1346 rewriter.replaceOp(op, maybeCastBack->
getResults());
1352struct TransposeLoadOpLowering
1354 TransposeLoadOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1355 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1360 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1361 ConversionPatternRewriter &rewriter)
const override {
1363 return op.emitOpError(
"Non-gfx950 chipset not supported");
1365 Location loc = op.getLoc();
1366 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1370 size_t srcElementSize =
1371 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1372 if (srcElementSize < 8)
1373 return op.emitOpError(
"Expect source memref to have at least 8 bits "
1374 "element size, got ")
1377 auto resultType = cast<VectorType>(op.getResult().getType());
1380 (adaptor.getSrcIndices()));
1382 size_t numElements = resultType.getNumElements();
1383 size_t elementTypeSize =
1384 resultType.getElementType().getIntOrFloatBitWidth();
1388 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1389 rewriter.getIntegerType(32));
1390 Type llvmResultType = typeConverter->convertType(resultType);
1392 switch (elementTypeSize) {
1394 assert(numElements == 16);
1395 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1396 rocdlResultType, srcPtr);
1397 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1401 assert(numElements == 16);
1402 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1403 rocdlResultType, srcPtr);
1404 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1408 assert(numElements == 8);
1409 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1410 rocdlResultType, srcPtr);
1411 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1415 assert(numElements == 4);
1416 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1421 return op.emitOpError(
"Unsupported element size for transpose load");
1428 GatherToLDSOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1429 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1434 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1435 ConversionPatternRewriter &rewriter)
const override {
1436 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1437 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
1439 Location loc = op.getLoc();
1441 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1442 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1447 Type transferType = op.getTransferType();
1448 int loadWidth = [&]() ->
int {
1449 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1450 return (transferVectorType.getNumElements() *
1451 transferVectorType.getElementTypeBitWidth()) /
1458 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1459 return op.emitOpError(
"chipset unsupported element size");
1461 if (chipset !=
kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1462 return op.emitOpError(
"Gather to LDS instructions with 12-byte and "
1463 "16-byte load widths are only supported on gfx950");
1467 (adaptor.getSrcIndices()));
1470 (adaptor.getDstIndices()));
1472 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1473 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1474 rewriter.getI32IntegerAttr(0),
1483struct ExtPackedFp8OpLowering final
1485 ExtPackedFp8OpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1486 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
1491 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1492 ConversionPatternRewriter &rewriter)
const override;
1495struct PackedTrunc2xFp8OpLowering final
1497 PackedTrunc2xFp8OpLowering(
const LLVMTypeConverter &converter,
1499 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
1504 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1505 ConversionPatternRewriter &rewriter)
const override;
1508struct PackedStochRoundFp8OpLowering final
1510 PackedStochRoundFp8OpLowering(
const LLVMTypeConverter &converter,
1512 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
1517 matchAndRewrite(PackedStochRoundFp8Op op,
1518 PackedStochRoundFp8OpAdaptor adaptor,
1519 ConversionPatternRewriter &rewriter)
const override;
1522struct ScaledExtPackedOpLowering final
1524 ScaledExtPackedOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1525 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1530 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1531 ConversionPatternRewriter &rewriter)
const override;
1534struct PackedScaledTruncOpLowering final
1536 PackedScaledTruncOpLowering(
const LLVMTypeConverter &converter,
1538 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1543 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1544 ConversionPatternRewriter &rewriter)
const override;
1549LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1550 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1551 ConversionPatternRewriter &rewriter)
const {
1552 Location loc = op.getLoc();
1554 return rewriter.notifyMatchFailure(
1555 loc,
"Fp8 conversion instructions are not available on target "
1556 "architecture and their emulation is not implemented");
1558 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1559 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1560 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1562 Value source = adaptor.getSource();
1563 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1564 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1567 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1568 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1569 if (!sourceVecType) {
1570 longVec = LLVM::InsertElementOp::create(
1573 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1575 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1577 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1582 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1583 if (resultVecType) {
1585 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1588 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1593 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1596 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1603LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1604 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1605 ConversionPatternRewriter &rewriter)
const {
1606 Location loc = op.getLoc();
1608 return rewriter.notifyMatchFailure(
1609 loc,
"Scaled fp conversion instructions are not available on target "
1610 "architecture and their emulation is not implemented");
1611 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1613 Value source = adaptor.getSource();
1614 Value scale = adaptor.getScale();
1616 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1617 Type sourceElemType = sourceVecType.getElementType();
1618 VectorType destVecType = cast<VectorType>(op.getResult().getType());
1619 Type destElemType = destVecType.getElementType();
1621 VectorType packedVecType;
1622 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1623 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
1624 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1625 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
1626 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
1627 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1629 llvm_unreachable(
"invalid element type for scaled ext");
1633 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1634 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
1635 if (!sourceVecType) {
1636 longVec = LLVM::InsertElementOp::create(
1639 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1641 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1643 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1648 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1650 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
1651 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
1652 op, destVecType, i32Source, scale, op.getIndex());
1653 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
1654 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
1655 op, destVecType, i32Source, scale, op.getIndex());
1656 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
1657 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
1658 op, destVecType, i32Source, scale, op.getIndex());
1659 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
1660 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
1661 op, destVecType, i32Source, scale, op.getIndex());
1662 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
1663 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
1664 op, destVecType, i32Source, scale, op.getIndex());
1665 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
1666 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
1667 op, destVecType, i32Source, scale, op.getIndex());
1668 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
1669 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
1670 op, destVecType, i32Source, scale, op.getIndex());
1671 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
1672 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
1673 op, destVecType, i32Source, scale, op.getIndex());
1674 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
1675 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
1676 op, destVecType, i32Source, scale, op.getIndex());
1683LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1684 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1685 ConversionPatternRewriter &rewriter)
const {
1686 Location loc = op.getLoc();
1688 return rewriter.notifyMatchFailure(
1689 loc,
"Scaled fp conversion instructions are not available on target "
1690 "architecture and their emulation is not implemented");
1691 Type v2i16 = getTypeConverter()->convertType(
1692 VectorType::get(2, rewriter.getI16Type()));
1693 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1695 Type resultType = op.getResult().getType();
1697 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1698 Type sourceElemType = sourceVecType.getElementType();
1700 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1702 Value source = adaptor.getSource();
1703 Value scale = adaptor.getScale();
1704 Value existing = adaptor.getExisting();
1706 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
1708 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
1710 if (sourceVecType.getNumElements() < 2) {
1712 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1713 VectorType v2 = VectorType::get(2, sourceElemType);
1714 source = LLVM::ZeroOp::create(rewriter, loc, v2);
1715 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
1718 Value sourceA, sourceB;
1719 if (sourceElemType.
isF32()) {
1722 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1723 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
1727 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
1728 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
1729 existing, sourceA, sourceB,
1730 scale, op.getIndex());
1731 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
1732 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
1733 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1734 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
1735 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
1736 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1737 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
1738 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
1739 existing, sourceA, sourceB,
1740 scale, op.getIndex());
1741 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
1742 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
1743 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1744 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
1745 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
1746 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1747 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
1748 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
1749 existing, sourceA, sourceB,
1750 scale, op.getIndex());
1751 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
1752 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
1753 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1754 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
1755 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
1756 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1760 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1761 op, getTypeConverter()->convertType(resultType),
result);
1765LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1766 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1767 ConversionPatternRewriter &rewriter)
const {
1768 Location loc = op.getLoc();
1770 return rewriter.notifyMatchFailure(
1771 loc,
"Fp8 conversion instructions are not available on target "
1772 "architecture and their emulation is not implemented");
1773 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1775 Type resultType = op.getResult().getType();
1778 Value sourceA = adaptor.getSourceA();
1779 Value sourceB = adaptor.getSourceB();
1781 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.
getType());
1782 Value existing = adaptor.getExisting();
1784 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1786 existing = LLVM::UndefOp::create(rewriter, loc, i32);
1790 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1791 existing, op.getWordIndex());
1793 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1794 existing, op.getWordIndex());
1796 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1797 op, getTypeConverter()->convertType(resultType),
result);
1801LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1802 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1803 ConversionPatternRewriter &rewriter)
const {
1804 Location loc = op.getLoc();
1806 return rewriter.notifyMatchFailure(
1807 loc,
"Fp8 conversion instructions are not available on target "
1808 "architecture and their emulation is not implemented");
1809 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1811 Type resultType = op.getResult().getType();
1814 Value source = adaptor.getSource();
1815 Value stoch = adaptor.getStochiasticParam();
1816 Value existing = adaptor.getExisting();
1818 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1820 existing = LLVM::UndefOp::create(rewriter, loc, i32);
1824 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
1825 existing, op.getStoreIndex());
1827 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
1828 existing, op.getStoreIndex());
1830 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
1831 op, getTypeConverter()->convertType(resultType),
result);
1837struct AMDGPUDPPLowering :
public ConvertOpToLLVMPattern<DPPOp> {
1838 AMDGPUDPPLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1839 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
1843 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1844 ConversionPatternRewriter &rewriter)
const override {
1847 Location loc = DppOp.getLoc();
1848 Value src = adaptor.getSrc();
1849 Value old = adaptor.getOld();
1852 Type llvmType =
nullptr;
1854 llvmType = rewriter.getI32Type();
1855 }
else if (isa<FloatType>(srcType)) {
1857 ? rewriter.getF32Type()
1858 : rewriter.getF64Type();
1859 }
else if (isa<IntegerType>(srcType)) {
1861 ? rewriter.getI32Type()
1862 : rewriter.getI64Type();
1864 auto llvmSrcIntType = typeConverter->convertType(
1868 auto convertOperand = [&](Value operand, Type operandType) {
1869 if (operandType.getIntOrFloatBitWidth() <= 16) {
1870 if (llvm::isa<FloatType>(operandType)) {
1872 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
1874 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
1875 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1876 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
1878 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
1880 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
1885 src = convertOperand(src, srcType);
1886 old = convertOperand(old, oldType);
1889 enum DppCtrl :
unsigned {
1898 ROW_HALF_MIRROR = 0x141,
1903 auto kind = DppOp.getKind();
1904 auto permArgument = DppOp.getPermArgument();
1905 uint32_t DppCtrl = 0;
1909 case DPPPerm::quad_perm:
1910 if (
auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1912 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1913 uint32_t num = elem.getInt();
1914 DppCtrl |= num << (i * 2);
1919 case DPPPerm::row_shl:
1920 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1921 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1924 case DPPPerm::row_shr:
1925 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1926 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1929 case DPPPerm::row_ror:
1930 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1931 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1934 case DPPPerm::wave_shl:
1935 DppCtrl = DppCtrl::WAVE_SHL1;
1937 case DPPPerm::wave_shr:
1938 DppCtrl = DppCtrl::WAVE_SHR1;
1940 case DPPPerm::wave_rol:
1941 DppCtrl = DppCtrl::WAVE_ROL1;
1943 case DPPPerm::wave_ror:
1944 DppCtrl = DppCtrl::WAVE_ROR1;
1946 case DPPPerm::row_mirror:
1947 DppCtrl = DppCtrl::ROW_MIRROR;
1949 case DPPPerm::row_half_mirror:
1950 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1952 case DPPPerm::row_bcast_15:
1953 DppCtrl = DppCtrl::BCAST15;
1955 case DPPPerm::row_bcast_31:
1956 DppCtrl = DppCtrl::BCAST31;
1962 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
1963 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
1964 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>(
"bound_ctrl").getValue();
1968 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
1969 rowMask, bankMask, boundCtrl);
1971 Value
result = dppMovOp.getRes();
1973 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType,
result);
1974 if (!llvm::isa<IntegerType>(srcType)) {
1975 result = LLVM::BitcastOp::create(rewriter, loc, srcType,
result);
1986struct AMDGPUSwizzleBitModeLowering
1987 :
public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
1991 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1992 ConversionPatternRewriter &rewriter)
const override {
1993 Location loc = op.getLoc();
1994 Type i32 = rewriter.getI32Type();
1995 Value src = adaptor.getSrc();
1996 SmallVector<Value> decomposed =
1998 unsigned andMask = op.getAndMask();
1999 unsigned orMask = op.getOrMask();
2000 unsigned xorMask = op.getXorMask();
2004 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2006 SmallVector<Value> swizzled;
2007 for (Value v : decomposed) {
2009 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2010 swizzled.emplace_back(res);
2014 rewriter.replaceOp(op,
result);
2019struct AMDGPUPermlaneLowering :
public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2022 AMDGPUPermlaneLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2023 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2027 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2028 ConversionPatternRewriter &rewriter)
const override {
2030 return op->emitOpError(
"permlane_swap is only supported on gfx950+");
2032 Location loc = op.getLoc();
2033 Type i32 = rewriter.getI32Type();
2034 Value src = adaptor.getSrc();
2035 unsigned rowLength = op.getRowLength();
2036 bool fi = op.getFetchInactive();
2037 bool boundctrl = op.getBoundCtrl();
2039 SmallVector<Value> decomposed =
2042 SmallVector<Value> permuted;
2043 for (Value v : decomposed) {
2045 Type i32pair = LLVM::LLVMStructType::getLiteral(
2046 rewriter.getContext(), {v.getType(), v.getType()});
2048 if (rowLength == 16)
2049 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2051 else if (rowLength == 32)
2052 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2055 llvm_unreachable(
"unsupported row length");
2057 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2058 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
2060 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2061 LLVM::ICmpPredicate::eq, vdst0, v);
2066 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2067 permuted.emplace_back(vdstNew);
2071 rewriter.replaceOp(op,
result);
2076struct ConvertAMDGPUToROCDLPass
2077 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
2080 void runOnOperation()
override {
2083 if (
failed(maybeChipset)) {
2084 emitError(UnknownLoc::get(ctx),
"Invalid chipset name: " + chipset);
2085 return signalPassFailure();
2089 LLVMTypeConverter converter(ctx);
2092 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
2093 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
2094 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
2095 if (
failed(applyPartialConversion(getOperation(),
target,
2097 signalPassFailure();
2104 typeConverter.addTypeAttributeConversion(
2106 -> TypeConverter::AttributeConversionResult {
2108 Type i64 = IntegerType::get(ctx, 64);
2109 switch (as.getValue()) {
2110 case amdgpu::AddressSpace::FatRawBuffer:
2111 return IntegerAttr::get(i64, 7);
2112 case amdgpu::AddressSpace::BufferRsrc:
2113 return IntegerAttr::get(i64, 8);
2114 case amdgpu::AddressSpace::FatStructuredBuffer:
2115 return IntegerAttr::get(i64, 9);
2117 return TypeConverter::AttributeConversionResult::abort();
2126 .add<FatRawBufferCastLowering,
2127 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2128 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2129 RawBufferOpLowering<RawBufferAtomicFaddOp,
2130 ROCDL::RawPtrBufferAtomicFaddOp>,
2131 RawBufferOpLowering<RawBufferAtomicFmaxOp,
2132 ROCDL::RawPtrBufferAtomicFmaxOp>,
2133 RawBufferOpLowering<RawBufferAtomicSmaxOp,
2134 ROCDL::RawPtrBufferAtomicSmaxOp>,
2135 RawBufferOpLowering<RawBufferAtomicUminOp,
2136 ROCDL::RawPtrBufferAtomicUminOp>,
2137 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2138 ROCDL::RawPtrBufferAtomicCmpSwap>,
2139 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2140 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2141 WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
2142 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
2143 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
2144 TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
2145 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
constexpr Chipset kGfx908
constexpr Chipset kGfx90a
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to ROCDL ...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static std::optional< uint32_t > mfmaTypeSelectCode(Type mlirElemType)
[…] if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in copyPlacementBlock specify the insertion points where the incoming copies and outgoing copies should be inserted. […] the output argument nBegin is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end` […]
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
::mlir::Pass::Option< std::string > chipset
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
bool hasOcpFp8(const Chipset &chipset)
Include the generated interface declarations.
void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.