23 #include "../LLVMCommon/MemRefDescriptor.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/ErrorHandling.h"
32 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
33 #include "mlir/Conversion/Passes.h.inc"
50 auto valTy = cast<IntegerType>(val.
getType());
53 return valTy.getWidth() > 32
54 ?
Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
55 :
Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
60 return LLVM::ConstantOp::create(rewriter, loc, rewriter.
getI32Type(), value);
68 auto valTy = cast<IntegerType>(val.
getType());
71 return valTy.getWidth() > 64
72 ?
Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
73 :
Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
78 return LLVM::ConstantOp::create(rewriter, loc, rewriter.
getI64Type(), value);
84 return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
96 ShapedType::isDynamic(stride)
98 memRefDescriptor.
stride(rewriter, loc, i))
99 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
100 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
102 index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
112 MemRefType memrefType,
115 int64_t elementByteWidth) {
116 if (memrefType.hasStaticShape() &&
117 !llvm::any_of(strides, ShapedType::isDynamic)) {
118 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
120 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
121 size =
std::max(shape[i] * strides[i], size);
122 size = size * elementByteWidth;
126 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
127 Value size = memrefDescriptor.
size(rewriter, loc, i);
128 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
129 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
131 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
136 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
141 bool boundsCheck, amdgpu::Chipset chipset,
142 Value cacheSwizzleStride =
nullptr,
143 unsigned addressSpace = 8) {
149 if (chipset.majorVersion == 9 && chipset >=
kGfx942 && cacheSwizzleStride) {
150 Value cacheStrideZext =
151 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
152 Value swizzleBit = LLVM::ConstantOp::create(
154 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
157 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
175 uint32_t flags = (7 << 12) | (4 << 15);
176 if (chipset.majorVersion >= 10) {
178 uint32_t oob = boundsCheck ? 3 : 2;
179 flags |= (oob << 28);
185 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
// Conversion pattern that lowers FatRawBufferCastOp. Only part of the
// pattern is visible in this excerpt; comments below describe the visible
// statements only.
190 struct FatRawBufferCastLowering
199 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
// `memRef` is the adaptor's (converted) source operand, while
// `unconvertedMemref` is the op's original source; the original MemRefType is
// what gets queried for strides and rank below.
202 Value memRef = adaptor.getSource();
203 Value unconvertedMemref = op.getSource();
204 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
208 int64_t elementByteWidth =
// The offset returned by getStridesAndOffset is not used afterwards in the
// visible code (hence the name).
211 int64_t unusedOffset = 0;
// Memrefs whose layout cannot be expressed as strides + offset are rejected.
213 if (
failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
214 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
// Start from the op's valid-bytes operand; the assignment below replaces it
// with a value computed from the memref descriptor.
// NOTE(review): the guard on that assignment is not visible here -
// presumably it only runs when no valid-bytes operand was supplied; confirm.
216 Value numRecords = adaptor.getValidBytes();
218 numRecords =
getNumRecords(rewriter, loc, memrefType, descriptor,
219 strideVals, elementByteWidth);
// With resetOffset set, the pointer comes from descriptor.bufferPtr(...);
// otherwise the descriptor's aligned pointer is used.
222 adaptor.getResetOffset()
223 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
225 : descriptor.alignedPtr(rewriter, loc);
// Likewise the offset is a constant of index type when resetOffset is set
// (constant value not visible here) and the descriptor's offset otherwise.
227 Value offset = adaptor.getResetOffset()
228 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
230 : descriptor.offset(rewriter, loc);
// Sizes (and, on the following visible line, a second field) are only
// extracted from the descriptor when the memref has rank > 0.
232 bool hasSizes = memrefType.getRank() > 0;
235 Value sizes = hasSizes
236 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
240 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
// Arguments of the call that builds the buffer value (callee name not visible
// in this excerpt). The trailing literal 7 is passed as the last argument.
// NOTE(review): presumably this is the address space - confirm against the
// callee's signature.
245 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
246 chipset, adaptor.getCacheSwizzleStride(), 7);
// The result aggregate has the converted type of the op's result and is
// filled in by successive insertvalue ops: `fatPtr` is inserted twice (at two
// positions), followed by offset, sizes and strides.
250 getTypeConverter()->convertType(op.getResult().getType()));
252 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos);
253 result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
255 result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
258 result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
260 result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
// Generic pattern, parameterised on the source op (`GpuOp`) and on the op that
// is created to replace it (`Intrinsic`). Only part of the body is visible in
// this excerpt; comments describe the visible statements only.
269 template <
typename GpuOp,
typename Intrinsic>
// Upper bound, in bits, on the total width of one access; exceeding it
// produces the "Total width of loads or stores..." error below.
275 static constexpr uint32_t maxVectorOpWidth = 128;
278 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
281 Value memref = adaptor.getMemref();
282 Value unconvertedMemref = gpuOp.getMemref();
283 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
286 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
// ODS operand group 0 is compared against the memref operand to tell whether
// the op carries data to store.
// NOTE(review): the branch bodies are only partly visible - the wanted data
// type is taken either from the store data or from the op's first result.
288 Value storeData = adaptor.getODSOperands(0)[0];
289 if (storeData == memref)
293 wantedDataType = storeData.
getType();
295 wantedDataType = gpuOp.getODSResults(0)[0].getType();
// ODS operand group 1 is treated as atomic compare data only when it is not
// the memref operand itself.
300 Value maybeCmpData = adaptor.getODSOperands(1)[0];
301 if (maybeCmpData != memref)
302 atomicCmpData = maybeCmpData;
305 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
311 int64_t elementByteWidth =
// `llvmBufferValType` is the type actually passed to / returned from the
// created op; it starts out equal to the converted wanted type and may be
// replaced below, in which case bitcasts are inserted around the op.
320 Type llvmBufferValType = llvmWantedDataType;
322 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
323 llvmBufferValType = this->getTypeConverter()->convertType(
326 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
327 uint32_t vecLen = dataVector.getNumElements();
330 uint32_t totalBits = elemBits * vecLen;
// Two-element vectors on RawBufferAtomicFaddOp are exempt from the
// sub-32-bit repacking below (see the `!usePackedFp16` test).
332 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
333 if (totalBits > maxVectorOpWidth)
334 return gpuOp.emitOpError(
335 "Total width of loads or stores must be no more than " +
336 Twine(maxVectorOpWidth) +
" bits, but we call for " +
338 " bits. This should've been caught in validation");
// Elements narrower than 32 bits: accesses wider than 32 bits must be a
// whole number of 32-bit words, otherwise the op is rejected. The replacement
// types chosen in each branch are not visible in this excerpt.
339 if (!usePackedFp16 && elemBits < 32) {
340 if (totalBits > 32) {
341 if (totalBits % 32 != 0)
342 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
343 "doesn't fit into words. Can't happen\n");
344 llvmBufferValType = this->typeConverter->convertType(
347 llvmBufferValType = this->typeConverter->convertType(
// A single-element vector type is replaced by its element type.
352 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
355 if (vecType.getNumElements() == 1)
356 llvmBufferValType = vecType.getElementType();
// Store data is bitcast to the buffer value type when that differs from the
// converted wanted type; otherwise it is passed through unchanged.
361 if (llvmBufferValType != llvmWantedDataType) {
362 Value castForStore = LLVM::BitcastOp::create(
363 rewriter, loc, llvmBufferValType, storeData);
364 args.push_back(castForStore);
366 args.push_back(storeData);
// Same treatment for the atomic compare data.
371 if (llvmBufferValType != llvmWantedDataType) {
372 Value castForCmp = LLVM::BitcastOp::create(
373 rewriter, loc, llvmBufferValType, atomicCmpData);
374 args.push_back(castForCmp);
376 args.push_back(atomicCmpData);
383 if (
failed(memrefType.getStridesAndOffset(strides, offset)))
384 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
388 Value ptr = memrefDescriptor.bufferPtr(
389 rewriter, loc, *this->getTypeConverter(), memrefType);
391 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
393 adaptor.getBoundsCheck(), chipset);
394 args.push_back(resource);
398 adaptor.getIndices(), strides);
// The optional index offset is only taken into account when it is present and
// strictly positive.
399 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
400 indexOffset && *indexOffset > 0) {
402 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
// Both offsets are multiplied by `byteWidthConst` before being appended to
// the argument list.
406 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
407 args.push_back(voffset);
410 Value sgprOffset = adaptor.getSgprOffset();
413 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
414 args.push_back(sgprOffset);
423 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
// If the buffer value type was substituted above, the replacement value is
// bitcast back to the converted wanted type.
427 if (llvmBufferValType != llvmWantedDataType) {
428 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
// Packs the three counters `vmcnt`, `expcnt` and `lgkmcnt` into a single
// unsigned immediate. Four different bit layouts are visible below; the
// chipset conditions that select between them, and the failure return, are not
// visible in this excerpt.
// NOTE(review): none of the visible expressions mask the inputs to their field
// width, so any range clamping has to happen in lines not shown - confirm.
450 static FailureOr<unsigned> encodeWaitcnt(
Chipset chipset,
unsigned vmcnt,
451 unsigned expcnt,
unsigned lgkmcnt) {
// Layout A: vmcnt from bit 0, expcnt from bit 4, lgkmcnt from bit 8.
456 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
// Layout B: vmcnt is split - its low 4 bits stay at bits 0..3 and the
// remaining high bits are placed from bit 14; expcnt and lgkmcnt as in A.
462 unsigned lowBits = vmcnt & 0xF;
463 unsigned highBits = (vmcnt >> 4) << 14;
464 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
465 return lowBits | highBits | otherCnts;
// Layout C: the visible statements are identical to layout B.
// NOTE(review): if the hidden lines of the two branches also agree, they
// could share one code path.
471 unsigned lowBits = vmcnt & 0xF;
472 unsigned highBits = (vmcnt >> 4) << 14;
473 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
474 return lowBits | highBits | otherCnts;
// Layout D: expcnt from bit 0, lgkmcnt from bit 4, vmcnt from bit 10.
480 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
// Conversion pattern for MemoryCounterWaitOp. Two strategies are visible:
// emitting one dedicated ROCDL wait op per counter, and folding the counters
// into a single encoded value via encodeWaitcnt. The condition choosing
// between them is not visible in this excerpt.
485 struct MemoryCounterWaitOpLowering
495 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
// Per-counter path: each counter that is present on the op gets its own wait
// op; absent counters emit nothing.
499 if (std::optional<int> ds = adaptor.getDs())
500 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
502 if (std::optional<int> load = adaptor.getLoad())
503 ROCDL::WaitLoadcntOp::create(rewriter, loc, *load);
505 if (std::optional<int> store = adaptor.getStore())
506 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
508 if (std::optional<int> exp = adaptor.getExp())
509 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
// Helper that reads an integer attribute as `unsigned`.
// NOTE(review): handling of a missing attribute, if any, is on a line not
// visible here.
515 auto getVal = [](
Attribute attr) ->
unsigned {
517 return cast<IntegerAttr>(attr).getInt();
// Combined path: ds is passed as encodeWaitcnt's lgkmcnt argument and exp as
// its expcnt argument (see the call below).
522 unsigned ds = getVal(adaptor.getDsAttr());
523 unsigned exp = getVal(adaptor.getExpAttr());
// vmcnt starts at 1024 and is then overwritten from the load and/or store
// counters: their sum, or just one of them. The conditions guarding the three
// assignments are not visible in this excerpt.
525 unsigned vmcnt = 1024;
527 Attribute store = adaptor.getStoreAttr();
529 vmcnt = getVal(load) + getVal(store);
531 vmcnt = getVal(load);
533 vmcnt = getVal(store);
536 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
538 return op.emitOpError(
"unsupported chipset");
// Lowers LDSBarrierOp to the visible sequence: release fence, a barrier
// (inline asm or ROCDL barrier ops), acquire fence. Parts of the body are not
// visible in this excerpt.
552 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
// Chipsets ordered before kGfx90a take the inline-assembly path below.
557 bool requiresInlineAsm = chipset <
kGfx90a;
// MMRA tag ("amdgpu-synchronize-as", "local"); `mmra` is attached to both
// fences below.
560 rewriter.
getAttr<LLVM::MMRATagAttr>(
"amdgpu-synchronize-as",
"local");
569 StringRef scope =
"workgroup";
// Release fence at `scope`, tagged with the MMRA attribute.
571 auto relFence = LLVM::FenceOp::create(rewriter, loc,
572 LLVM::AtomicOrdering::release, scope);
573 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
// Inline-asm path: emits `s_barrier` as AT&T-dialect inline assembly with no
// constraints.
574 if (requiresInlineAsm) {
576 LLVM::AsmDialect::AD_ATT);
577 const char *asmStr =
";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
578 const char *constraints =
"";
579 LLVM::InlineAsmOp::create(
582 asmStr, constraints,
true,
// Otherwise either a single SBarrierOp or a BarrierSignalOp/BarrierWaitOp pair
// (both created with -1) is emitted; the condition choosing between these two
// forms is not visible in this excerpt.
587 ROCDL::SBarrierOp::create(rewriter, loc);
589 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
590 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
// Acquire fence at the same scope, with the same MMRA tag.
593 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
594 LLVM::AtomicOrdering::acquire, scope);
595 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
608 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
611 (uint32_t)op.getOpts());
635 bool allowBf16 =
true) {
637 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
638 if (vectorType.getElementType().isBF16() && !allowBf16)
639 return LLVM::BitcastOp::create(
640 rewriter, loc, vectorType.clone(rewriter.
getI16Type()), input);
641 if (vectorType.getElementType().isInteger(8) &&
642 vectorType.getNumElements() <= 8)
643 return LLVM::BitcastOp::create(
646 if (isa<IntegerType>(vectorType.getElementType()) &&
647 vectorType.getElementTypeBitWidth() <= 8) {
649 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
651 return LLVM::BitcastOp::create(
673 if (
auto intType = dyn_cast<IntegerType>(inputType))
674 return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
675 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
689 bool isUnsigned,
Value llvmInput,
693 auto vectorType = dyn_cast<VectorType>(inputType);
695 operands.push_back(llvmInput);
698 Type elemType = vectorType.getElementType();
701 llvmInput = LLVM::BitcastOp::create(
702 rewriter, loc, vectorType.clone(rewriter.
getI16Type()), llvmInput);
704 operands.push_back(llvmInput);
711 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
712 bool isInputInteger = mlirInputType.getElementType().isInteger();
713 if (isInputInteger) {
715 bool localIsUnsigned = isUnsigned;
717 localIsUnsigned =
true;
719 localIsUnsigned =
false;
722 operands.push_back(sign);
728 Type intrinsicInType = numBits <= 32
731 auto llvmIntrinsicInType = typeConverter->
convertType(intrinsicInType);
733 loc, llvmIntrinsicInType, llvmInput);
738 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
739 operands.push_back(castInput);
752 Value output, int32_t subwordOffset,
754 Type inputType = output.getType();
755 auto vectorType = dyn_cast<VectorType>(inputType);
756 Type elemType = vectorType.getElementType();
758 output = LLVM::BitcastOp::create(
759 rewriter, loc, vectorType.clone(rewriter.
getI16Type()), output);
760 operands.push_back(output);
771 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
772 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
778 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
779 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
787 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
788 b = mfma.getBlocks();
793 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
794 if (m == 32 && n == 32 && k == 4 && b == 1)
795 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
796 if (m == 16 && n == 16 && k == 8 && b == 1)
797 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
799 if (m == 32 && n == 32 && k == 1 && b == 2)
800 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
801 if (m == 16 && n == 16 && k == 1 && b == 4)
802 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
803 if (m == 4 && n == 4 && k == 1 && b == 16)
804 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
805 if (m == 32 && n == 32 && k == 2 && b == 1)
806 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
807 if (m == 16 && n == 16 && k == 4 && b == 1)
808 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
813 if (m == 32 && n == 32 && k == 16 && b == 1)
814 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
815 if (m == 16 && n == 16 && k == 32 && b == 1)
816 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
818 if (m == 32 && n == 32 && k == 4 && b == 2)
819 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
820 if (m == 16 && n == 16 && k == 4 && b == 4)
821 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
822 if (m == 4 && n == 4 && k == 4 && b == 16)
823 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
824 if (m == 32 && n == 32 && k == 8 && b == 1)
825 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
826 if (m == 16 && n == 16 && k == 16 && b == 1)
827 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
832 if (m == 32 && n == 32 && k == 16 && b == 1)
833 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
834 if (m == 16 && n == 16 && k == 32 && b == 1)
835 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
838 if (m == 32 && n == 32 && k == 4 && b == 2)
839 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
840 if (m == 16 && n == 16 && k == 4 && b == 4)
841 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
842 if (m == 4 && n == 4 && k == 4 && b == 16)
843 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
844 if (m == 32 && n == 32 && k == 8 && b == 1)
845 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
846 if (m == 16 && n == 16 && k == 16 && b == 1)
847 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
849 if (m == 32 && n == 32 && k == 2 && b == 2)
850 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
851 if (m == 16 && n == 16 && k == 2 && b == 4)
852 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
853 if (m == 4 && n == 4 && k == 2 && b == 16)
854 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
855 if (m == 32 && n == 32 && k == 4 && b == 1)
856 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
857 if (m == 16 && n == 16 && k == 8 && b == 1)
858 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
863 if (m == 32 && n == 32 && k == 32 && b == 1)
864 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
865 if (m == 16 && n == 16 && k == 64 && b == 1)
866 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
868 if (m == 32 && n == 32 && k == 4 && b == 2)
869 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
870 if (m == 16 && n == 16 && k == 4 && b == 4)
871 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
872 if (m == 4 && n == 4 && k == 4 && b == 16)
873 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
874 if (m == 32 && n == 32 && k == 8 && b == 1)
875 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
876 if (m == 16 && n == 16 && k == 16 && b == 1)
877 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
878 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >=
kGfx942)
879 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
880 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >=
kGfx942)
881 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
885 if (m == 16 && n == 16 && k == 4 && b == 1)
886 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
887 if (m == 4 && n == 4 && k == 4 && b == 4)
888 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
895 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
896 if (m == 16 && n == 16 && k == 32 && b == 1) {
898 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
900 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
902 if (m == 32 && n == 32 && k == 16 && b == 1) {
904 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
906 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
912 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
913 if (m == 16 && n == 16 && k == 32 && b == 1) {
915 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
917 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
919 if (m == 32 && n == 32 && k == 16 && b == 1) {
921 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
923 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
932 .Case([](Float8E4M3FNType) {
return 0u; })
933 .Case([](Float8E5M2Type) {
return 1u; })
934 .Case([](Float6E2M3FNType) {
return 2u; })
935 .Case([](Float6E3M2FNType) {
return 3u; })
936 .Case([](Float4E2M1FNType) {
return 4u; })
937 .Default([](
Type) {
return std::nullopt; });
947 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
949 uint32_t n, uint32_t k, uint32_t b,
Chipset chipset) {
956 if (!isa<Float32Type>(destType))
961 if (!aTypeCode || !bTypeCode)
964 if (m == 32 && n == 32 && k == 64 && b == 1)
965 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
966 *aTypeCode, *bTypeCode};
967 if (m == 16 && n == 16 && k == 128 && b == 1)
969 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
975 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
978 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
979 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
980 mfma.getBlocks(), chipset);
983 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
986 smfma.getSourceB().getType(),
987 smfma.getDestC().getType(), smfma.getM(),
988 smfma.getN(), smfma.getK(), 1u, chipset);
996 auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
997 auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
998 auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
999 auto elemSourceType = sourceVectorType.getElementType();
1000 auto elemBSourceType = sourceBVectorType.getElementType();
1001 auto elemDestType = destVectorType.getElementType();
1003 if (elemSourceType.isF16() && elemDestType.isF32())
1004 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1005 if (elemSourceType.isBF16() && elemDestType.isF32())
1006 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1007 if (elemSourceType.isF16() && elemDestType.isF16())
1008 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1009 if (elemSourceType.isBF16() && elemDestType.isBF16())
1010 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1011 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
1012 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1014 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1015 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1018 if (isa<Float8E4M3FNType>(elemSourceType) &&
1019 isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
1020 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1021 if (isa<Float8E4M3FNType>(elemSourceType) &&
1022 isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
1023 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1024 if (isa<Float8E5M2Type>(elemSourceType) &&
1025 isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
1026 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1027 if (isa<Float8E5M2Type>(elemSourceType) &&
1028 isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
1029 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1030 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
1031 bool isWave64 = destVectorType.getNumElements() == 4;
1034 bool has8Inputs = sourceVectorType.getNumElements() == 8;
1035 if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
1036 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1037 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1040 return std::nullopt;
// Lowers MFMAOp to a named intrinsic op. Only part of the body is visible in
// this excerpt; comments describe the visible statements only.
1051 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
// When the converted result is a vector of bf16, the intrinsic is created with
// an i16 vector of the same shape and the result is bitcast back at the end.
1054 Type outType = typeConverter->convertType(op.getDestD().getType());
1055 Type intrinsicOutType = outType;
1056 if (
auto outVecType = dyn_cast<VectorType>(outType))
1057 if (outVecType.getElementType().isBF16())
1058 intrinsicOutType = outVecType.clone(rewriter.
getI16Type());
1061 return op->emitOpError(
"MFMA only supported on gfx908+");
// Despite the `get` prefix this is a plain local holding the op's blgp value.
1062 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
// Negation flags are packed as bit 0 = A, bit 1 = B, bit 2 = C. Where the
// packed value is stored is not visible in this excerpt.
1063 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1065 return op.emitOpError(
"negation unsupported on older than gfx942");
1067 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1070 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
// Fail only if neither the plain nor the scaled intrinsic lookup matched.
1072 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1073 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
// The scaled form is used only as a fallback, i.e. when there is no plain
// intrinsic but a scaled one exists.
1076 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1078 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1079 return op.emitOpError(
1080 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1081 "be scaled as those fields are used for type information");
1084 StringRef intrinsicName =
1085 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
// bf16 operands are passed through as bf16 only for intrinsics whose name
// contains one of the two substrings below (other conditions of this lambda,
// if any, are not visible here).
1088 bool allowBf16 = [&]() {
1093 return intrinsicName.contains(
"16x16x32.bf16") ||
1094 intrinsicName.contains(
"32x32x16.bf16");
1097 loweredOp.addTypes(intrinsicOutType);
1099 rewriter, loc, adaptor.getSourceA(), allowBf16),
1101 rewriter, loc, adaptor.getSourceB(), allowBf16),
1102 adaptor.getDestC()});
1105 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
// Undo the bf16 -> i16 substitution made at the top, if any.
1116 if (outType != intrinsicOutType)
1117 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
// Lowers ScaledMFMAOp: looks up a scaled intrinsic (name plus type codes for
// the A and B operands) and builds the replacement op from it. Only part of
// the body is visible in this excerpt.
1130 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1133 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1136 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1137 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
// No matching intrinsic is a hard error for this op.
1139 if (!maybeScaledIntrinsic.has_value())
1140 return op.emitOpError(
1141 "no intrinsic matching scaled MFMA size on given chipset");
1143 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1145 loweredOp.addTypes(intrinsicOutType);
// Operands are appended in two groups; the first visible group ends with the
// accumulator (destC). The contents of the second group are not visible here.
1146 loweredOp.addOperands(
1149 adaptor.getDestC()});
1154 loweredOp.addOperands(
// Lowers WMMAOp to a named intrinsic op. Only part of the body is visible in
// this excerpt; comments describe the visible statements only.
1176 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1180 typeConverter->convertType<VectorType>(op.getDestD().getType());
1185 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
// A bf16 result vector is produced as an i16 vector of the same shape and
// bitcast back to `outType` at the end.
1189 VectorType rawOutType = outType;
1190 if (outType.getElementType().
isBF16())
1191 rawOutType = outType.clone(rewriter.
getI16Type());
1195 if (!maybeIntrinsic.has_value())
1196 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
// A non-zero subword offset is rejected when the chipset major version is 12
// or higher.
1198 if (chipset.
majorVersion >= 12 && op.getSubwordOffset() != 0)
1199 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1202 loweredOp.addTypes(rawOutType);
// Operands are collected into `operands` by three helper calls (A input,
// B input, then the call that takes the subword offset and clamp) before
// being added to the op. The helper names are not visible in this excerpt.
1206 adaptor.getSourceA(), op.getSourceA(), operands);
1208 adaptor.getSourceB(), op.getSourceB(), operands);
1210 op.getSubwordOffset(), op.getClamp(), operands);
1212 loweredOp.addOperands(operands);
1216 if (rawOutType != outType)
1217 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
// Conversion pattern for TransposeLoadOp: picks a ROCDL ds_read_tr* op based
// on the bit width of the result element type. Only part of the pattern is
// visible in this excerpt.
1225 struct TransposeLoadOpLowering
1233 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1236 return op.emitOpError(
"Non-gfx950 chipset not supported");
1239 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
// Source memrefs with elements narrower than 8 bits are rejected.
1243 size_t srcElementSize =
1244 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1245 if (srcElementSize < 8)
1246 return op.emitOpError(
"Expect source memref to have at least 8 bits "
1247 "element size, got ")
1250 auto resultType = cast<VectorType>(op.getResult().getType());
1253 (adaptor.getSrcIndices()));
1255 size_t numElements = resultType.getNumElements();
1256 size_t elementTypeSize =
1257 resultType.getElementType().getIntOrFloatBitWidth();
1263 Type llvmResultType = typeConverter->convertType(resultType);
// Dispatch on the result element bit width. The case labels are not visible
// in this excerpt; each visible case asserts the expected element count
// (16, 16, 8, 4) before creating its op.
// NOTE(review): these are asserts only, so the element-count requirement is
// not enforced when assertions are compiled out - presumably the op verifier
// guarantees it; confirm.
1265 switch (elementTypeSize) {
1267 assert(numElements == 16);
1268 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1269 rocdlResultType, srcPtr);
1274 assert(numElements == 16);
1275 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1276 rocdlResultType, srcPtr);
1281 assert(numElements == 8);
1282 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1283 rocdlResultType, srcPtr);
1288 assert(numElements == 4);
// Any other element width is reported as an op error.
1294 return op.emitOpError(
"Unsupported element size for transpose load");
// Lowers GatherToLDSOp after validating the chipset and the width of one
// transfer. Only part of the body is visible in this excerpt.
1307 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1310 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
1314 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1315 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
// `loadWidth` is derived from the transfer type; for vectors it is
// (element count * element bit width) divided by a value not visible here.
// NOTE(review): the error text below refers to 12-byte and 16-byte widths, so
// the unit is presumably bytes - confirm against the hidden divisor.
1320 Type transferType = op.getTransferType();
1321 int loadWidth = [&]() ->
int {
1322 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1323 return (transferVectorType.getNumElements() *
1324 transferVectorType.getElementTypeBitWidth()) /
// Only widths 1, 2, 4, 12 and 16 are accepted at all.
1331 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1332 return op.emitOpError(
"chipset unsupported element size");
// Widths 12 and 16 are additionally restricted to kGfx950 exactly (the
// comparison is `!=`, not an ordering test).
1334 if (chipset !=
kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1335 return op.emitOpError(
"Gather to LDS instructions with 12-byte and "
1336 "16-byte load widths are only supported on gfx950");
1340 (adaptor.getSrcIndices()));
1343 (adaptor.getDstIndices()));
1356 struct ExtPackedFp8OpLowering final
1364 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1368 struct PackedTrunc2xFp8OpLowering final
1377 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1381 struct PackedStochRoundFp8OpLowering final
1390 matchAndRewrite(PackedStochRoundFp8Op op,
1391 PackedStochRoundFp8OpAdaptor adaptor,
1395 struct ScaledExtPackedOpLowering final
1403 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1407 struct PackedScaledTruncOpLowering final
1416 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1422 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1423 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1428 loc,
"Fp8 conversion instructions are not available on target "
1429 "architecture and their emulation is not implemented");
1432 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1433 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1435 Value source = adaptor.getSource();
1436 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1437 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1440 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1441 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1442 if (!sourceVecType) {
1443 longVec = LLVM::InsertElementOp::create(
1446 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1448 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1450 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1455 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1456 if (resultVecType) {
1476 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1477 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1482 loc,
"Scaled fp conversion instructions are not available on target "
1483 "architecture and their emulation is not implemented");
1484 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1486 Value source = adaptor.getSource();
1487 Value scale = adaptor.getScale();
1489 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1490 Type sourceElemType = sourceVecType.getElementType();
1491 VectorType destVecType = cast<VectorType>(op.getResult().getType());
1492 Type destElemType = destVecType.getElementType();
1494 VectorType packedVecType;
1495 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1497 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1498 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
1500 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1502 llvm_unreachable(
"invalid element type for scaled ext");
1506 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1507 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
1508 if (!sourceVecType) {
1509 longVec = LLVM::InsertElementOp::create(
1512 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1514 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1516 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1521 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1523 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
1525 op, destVecType, i32Source, scale, op.getIndex());
1526 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
1528 op, destVecType, i32Source, scale, op.getIndex());
1529 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
1531 op, destVecType, i32Source, scale, op.getIndex());
1532 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
1534 op, destVecType, i32Source, scale, op.getIndex());
1535 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
1537 op, destVecType, i32Source, scale, op.getIndex());
1538 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
1540 op, destVecType, i32Source, scale, op.getIndex());
1541 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
1543 op, destVecType, i32Source, scale, op.getIndex());
1544 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
1546 op, destVecType, i32Source, scale, op.getIndex());
1547 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
1549 op, destVecType, i32Source, scale, op.getIndex());
1556 LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1557 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1562 loc,
"Scaled fp conversion instructions are not available on target "
1563 "architecture and their emulation is not implemented");
1564 Type v2i16 = getTypeConverter()->convertType(
1566 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1568 Type resultType = op.getResult().getType();
1570 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1571 Type sourceElemType = sourceVecType.getElementType();
1573 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1575 Value source = adaptor.getSource();
1576 Value scale = adaptor.getScale();
1577 Value existing = adaptor.getExisting();
1579 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
1581 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
1583 if (sourceVecType.getNumElements() < 2) {
1585 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1587 source = LLVM::ZeroOp::create(rewriter, loc, v2);
1588 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
1591 Value sourceA, sourceB;
1592 if (sourceElemType.
isF32()) {
1595 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
1596 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
1600 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
1601 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
1602 existing, sourceA, sourceB,
1603 scale, op.getIndex());
1604 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
1605 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
1606 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1607 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
1608 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
1609 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1610 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
1611 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
1612 existing, sourceA, sourceB,
1613 scale, op.getIndex());
1614 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
1615 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
1616 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1617 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
1618 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
1619 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1620 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
1621 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
1622 existing, sourceA, sourceB,
1623 scale, op.getIndex());
1624 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
1625 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
1626 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1627 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
1628 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
1629 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
1634 op, getTypeConverter()->convertType(resultType), result);
1638 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1639 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1644 loc,
"Fp8 conversion instructions are not available on target "
1645 "architecture and their emulation is not implemented");
1646 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1648 Type resultType = op.getResult().getType();
1651 Value sourceA = adaptor.getSourceA();
1652 Value sourceB = adaptor.getSourceB();
1654 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.
getType());
1655 Value existing = adaptor.getExisting();
1657 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1659 existing = LLVM::UndefOp::create(rewriter, loc, i32);
1663 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1664 existing, op.getWordIndex());
1666 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
1667 existing, op.getWordIndex());
1670 op, getTypeConverter()->convertType(resultType), result);
1674 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1675 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1680 loc,
"Fp8 conversion instructions are not available on target "
1681 "architecture and their emulation is not implemented");
1682 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1684 Type resultType = op.getResult().getType();
1687 Value source = adaptor.getSource();
1688 Value stoch = adaptor.getStochiasticParam();
1689 Value existing = adaptor.getExisting();
1691 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
1693 existing = LLVM::UndefOp::create(rewriter, loc, i32);
1697 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
1698 existing, op.getStoreIndex());
1700 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
1701 existing, op.getStoreIndex());
1704 op, getTypeConverter()->convertType(resultType), result);
1716 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1721 Value src = adaptor.getSrc();
1722 Value old = adaptor.getOld();
1725 Type llvmType =
nullptr;
1728 }
else if (isa<FloatType>(srcType)) {
1732 }
else if (isa<IntegerType>(srcType)) {
1737 auto llvmSrcIntType = typeConverter->convertType(
1741 auto convertOperand = [&](
Value operand,
Type operandType) {
1742 if (operandType.getIntOrFloatBitWidth() <= 16) {
1743 if (llvm::isa<FloatType>(operandType)) {
1745 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
1748 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1749 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
1751 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
1753 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
1758 src = convertOperand(src, srcType);
1759 old = convertOperand(old, oldType);
1762 enum DppCtrl :
unsigned {
1771 ROW_HALF_MIRROR = 0x141,
1776 auto kind = DppOp.getKind();
1777 auto permArgument = DppOp.getPermArgument();
1778 uint32_t DppCtrl = 0;
1782 case DPPPerm::quad_perm:
1783 if (
auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1785 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1786 uint32_t num = elem.getInt();
1787 DppCtrl |= num << (i * 2);
1792 case DPPPerm::row_shl:
1793 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1794 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1797 case DPPPerm::row_shr:
1798 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1799 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1802 case DPPPerm::row_ror:
1803 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1804 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1807 case DPPPerm::wave_shl:
1808 DppCtrl = DppCtrl::WAVE_SHL1;
1810 case DPPPerm::wave_shr:
1811 DppCtrl = DppCtrl::WAVE_SHR1;
1813 case DPPPerm::wave_rol:
1814 DppCtrl = DppCtrl::WAVE_ROL1;
1816 case DPPPerm::wave_ror:
1817 DppCtrl = DppCtrl::WAVE_ROR1;
1819 case DPPPerm::row_mirror:
1820 DppCtrl = DppCtrl::ROW_MIRROR;
1822 case DPPPerm::row_half_mirror:
1823 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1825 case DPPPerm::row_bcast_15:
1826 DppCtrl = DppCtrl::BCAST15;
1828 case DPPPerm::row_bcast_31:
1829 DppCtrl = DppCtrl::BCAST31;
1835 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
1836 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
1837 bool boundCtrl = DppOp->getAttrOfType<
BoolAttr>(
"bound_ctrl").getValue();
1841 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
1842 rowMask, bankMask, boundCtrl);
1844 Value result = dppMovOp.getRes();
1846 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result);
1847 if (!llvm::isa<IntegerType>(srcType)) {
1848 result = LLVM::BitcastOp::create(rewriter, loc, srcType, result);
1859 struct AMDGPUSwizzleBitModeLowering
1864 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1868 Value src = adaptor.getSrc();
1871 unsigned andMask = op.getAndMask();
1872 unsigned orMask = op.getOrMask();
1873 unsigned xorMask = op.getXorMask();
1877 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
1880 for (
Value v : decomposed) {
1882 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
1883 swizzled.emplace_back(res);
1900 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
1903 return op->emitOpError(
"permlane_swap is only supported on gfx950+");
1907 Value src = adaptor.getSrc();
1908 unsigned rowLength = op.getRowLength();
1909 bool fi = op.getFetchInactive();
1910 bool boundctrl = op.getBoundCtrl();
1916 for (
Value v : decomposed) {
1918 Type i32pair = LLVM::LLVMStructType::getLiteral(
1919 rewriter.
getContext(), {v.getType(), v.getType()});
1921 if (rowLength == 16)
1922 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
1924 else if (rowLength == 32)
1925 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
1928 llvm_unreachable(
"unsupported row length");
1930 const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
1931 const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
1933 const Value isEqual =
1934 rewriter.
create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, vdst0, v);
1939 rewriter.
create<LLVM::SelectOp>(loc, isEqual, vdst1, vdst0);
1940 permuted.emplace_back(vdstNew);
1949 struct ConvertAMDGPUToROCDLPass
1950 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1953 void runOnOperation()
override {
1956 if (
failed(maybeChipset)) {
1958 return signalPassFailure();
1965 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1966 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1967 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1970 signalPassFailure();
1982 switch (as.getValue()) {
1983 case amdgpu::AddressSpace::FatRawBuffer:
1985 case amdgpu::AddressSpace::BufferRsrc:
1987 case amdgpu::AddressSpace::FatStructuredBuffer:
1999 .add<FatRawBufferCastLowering,
2000 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2001 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2002 RawBufferOpLowering<RawBufferAtomicFaddOp,
2003 ROCDL::RawPtrBufferAtomicFaddOp>,
2004 RawBufferOpLowering<RawBufferAtomicFmaxOp,
2005 ROCDL::RawPtrBufferAtomicFmaxOp>,
2006 RawBufferOpLowering<RawBufferAtomicSmaxOp,
2007 ROCDL::RawPtrBufferAtomicSmaxOp>,
2008 RawBufferOpLowering<RawBufferAtomicUminOp,
2009 ROCDL::RawPtrBufferAtomicUminOp>,
2010 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2011 ROCDL::RawPtrBufferAtomicCmpSwap>,
2012 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2013 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2014 WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
2015 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
2016 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
2017 TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
2018 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
constexpr Chipset kGfx908
constexpr Chipset kGfx90a
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to ROCDL ...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
static std::optional< uint32_t > mfmaTypeSelectCode(Type mlirElemType)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
constexpr Chipset kGfx950
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1245::ArityGroupAndKind::Kind kind
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult abort()
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
bool hasOcpFp8(const Chipset &chipset)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.