22 #include "../LLVMCommon/MemRefDescriptor.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/ErrorHandling.h"
31 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
32 #include "mlir/Conversion/Passes.h.inc"
49 auto valTy = cast<IntegerType>(val.
getType());
52 return valTy.getWidth() > 32
53 ?
Value(rewriter.
create<LLVM::TruncOp>(loc, i32, val))
54 :
Value(rewriter.
create<LLVM::ZExtOp>(loc, i32, val));
60 return rewriter.
create<LLVM::ConstantOp>(loc, i32, value);
66 return rewriter.
create<LLVM::ConstantOp>(loc, llvmI1, value);
78 ShapedType::isDynamic(stride)
80 memRefDescriptor.
stride(rewriter, loc, i))
81 : rewriter.
create<LLVM::ConstantOp>(loc, i32, stride);
82 increment = rewriter.
create<LLVM::MulOp>(loc, increment, strideValue);
85 index ? rewriter.
create<LLVM::AddOp>(loc, index, increment) : increment;
94 MemRefType memrefType,
97 uint32_t elementByteWidth) {
98 if (memrefType.hasStaticShape() &&
99 !llvm::any_of(strides, ShapedType::isDynamic)) {
100 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
102 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
103 size =
std::max(shape[i] * strides[i], size);
104 size = size * elementByteWidth;
106 "the memref buffer is too large");
110 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
111 Value size = memrefDescriptor.
size(rewriter, loc, i);
112 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
113 Value maxThisDim = rewriter.
create<LLVM::MulOp>(loc, size, stride);
115 ? rewriter.
create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
120 return rewriter.
create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
125 bool boundsCheck, amdgpu::Chipset chipset,
126 Value cacheSwizzleStride =
nullptr,
127 unsigned addressSpace = 8) {
133 if (chipset.majorVersion == 9 && chipset >=
kGfx942 && cacheSwizzleStride) {
134 Value cacheStrideZext =
135 rewriter.
create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
136 Value swizzleBit = rewriter.
create<LLVM::ConstantOp>(
138 stride = rewriter.
create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
141 stride = rewriter.
create<LLVM::ConstantOp>(loc, i16,
159 uint32_t flags = (7 << 12) | (4 << 15);
160 if (chipset.majorVersion >= 10) {
162 uint32_t oob = boundsCheck ? 3 : 2;
163 flags |= (oob << 28);
169 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
174 struct FatRawBufferCastLowering
183 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
186 Value memRef = adaptor.getSource();
187 Value unconvertedMemref = op.getSource();
188 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
192 int64_t elementByteWidth =
195 int64_t unusedOffset = 0;
197 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
198 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
200 Value numRecords = adaptor.getValidBytes();
202 numRecords =
getNumRecords(rewriter, loc, memrefType, descriptor,
203 strideVals, elementByteWidth);
206 adaptor.getResetOffset()
207 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
209 : descriptor.alignedPtr(rewriter, loc);
211 Value offset = adaptor.getResetOffset()
212 ? rewriter.
create<LLVM::ConstantOp>(
214 : descriptor.offset(rewriter, loc);
216 bool hasSizes = memrefType.getRank() > 0;
219 Value sizes = hasSizes ? rewriter.
create<LLVM::ExtractValueOp>(
222 Value strides = hasSizes
223 ? rewriter.
create<LLVM::ExtractValueOp>(
228 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
229 chipset, adaptor.getCacheSwizzleStride(), 7);
233 getTypeConverter()->convertType(op.getResult().getType()));
234 result = rewriter.
create<LLVM::InsertValueOp>(
236 result = rewriter.
create<LLVM::InsertValueOp>(
238 result = rewriter.
create<LLVM::InsertValueOp>(loc, result, offset,
241 result = rewriter.
create<LLVM::InsertValueOp>(loc, result, sizes,
243 result = rewriter.
create<LLVM::InsertValueOp>(
252 template <
typename GpuOp,
typename Intrinsic>
258 static constexpr uint32_t maxVectorOpWidth = 128;
261 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
264 Value memref = adaptor.getMemref();
265 Value unconvertedMemref = gpuOp.getMemref();
266 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
269 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
271 Value storeData = adaptor.getODSOperands(0)[0];
272 if (storeData == memref)
276 wantedDataType = storeData.
getType();
278 wantedDataType = gpuOp.getODSResults(0)[0].getType();
283 Value maybeCmpData = adaptor.getODSOperands(1)[0];
284 if (maybeCmpData != memref)
285 atomicCmpData = maybeCmpData;
288 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
294 int64_t elementByteWidth =
303 Type llvmBufferValType = llvmWantedDataType;
305 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
306 llvmBufferValType = this->getTypeConverter()->convertType(
309 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
310 uint32_t vecLen = dataVector.getNumElements();
313 uint32_t totalBits = elemBits * vecLen;
315 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
316 if (totalBits > maxVectorOpWidth)
317 return gpuOp.emitOpError(
318 "Total width of loads or stores must be no more than " +
319 Twine(maxVectorOpWidth) +
" bits, but we call for " +
321 " bits. This should've been caught in validation");
322 if (!usePackedFp16 && elemBits < 32) {
323 if (totalBits > 32) {
324 if (totalBits % 32 != 0)
325 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
326 "doesn't fit into words. Can't happen\n");
327 llvmBufferValType = this->typeConverter->convertType(
330 llvmBufferValType = this->typeConverter->convertType(
335 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
338 if (vecType.getNumElements() == 1)
339 llvmBufferValType = vecType.getElementType();
344 if (llvmBufferValType != llvmWantedDataType) {
346 rewriter.
create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
347 args.push_back(castForStore);
349 args.push_back(storeData);
354 if (llvmBufferValType != llvmWantedDataType) {
355 Value castForCmp = rewriter.
create<LLVM::BitcastOp>(
356 loc, llvmBufferValType, atomicCmpData);
357 args.push_back(castForCmp);
359 args.push_back(atomicCmpData);
366 if (failed(memrefType.getStridesAndOffset(strides, offset)))
367 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
371 Value ptr = memrefDescriptor.bufferPtr(
372 rewriter, loc, *this->getTypeConverter(), memrefType);
374 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
376 adaptor.getBoundsCheck(), chipset);
377 args.push_back(resource);
381 adaptor.getIndices(), strides);
382 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
383 indexOffset && *indexOffset > 0) {
386 voffset ? rewriter.
create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
389 voffset = rewriter.
create<LLVM::MulOp>(loc, voffset, byteWidthConst);
390 args.push_back(voffset);
393 Value sgprOffset = adaptor.getSgprOffset();
396 sgprOffset = rewriter.
create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
397 args.push_back(sgprOffset);
406 Operation *lowered = rewriter.
create<Intrinsic>(loc, resultTypes, args,
410 if (llvmBufferValType != llvmWantedDataType) {
411 replacement = rewriter.
create<LLVM::BitcastOp>(loc, llvmWantedDataType,
429 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
433 if (requiresInlineAsm) {
435 LLVM::AsmDialect::AD_ATT);
437 ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
438 const char *constraints =
"";
442 asmStr, constraints,
true,
449 constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
450 constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
453 constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
457 ldsOnlyBits = ldsOnlyBitsGfx11;
459 ldsOnlyBits = ldsOnlyBitsGfx10;
461 ldsOnlyBits = ldsOnlyBitsGfx6789;
463 return op.emitOpError(
464 "don't know how to lower this for chipset major version")
468 rewriter.
create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
472 rewriter.
create<ROCDL::WaitDscntOp>(loc, 0);
473 rewriter.
create<ROCDL::BarrierSignalOp>(loc, -1);
488 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
491 (uint32_t)op.getOpts());
514 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
515 if (vectorType.getElementType().isBF16())
516 return rewriter.
create<LLVM::BitcastOp>(
517 loc, vectorType.clone(rewriter.
getI16Type()), input);
518 if (vectorType.getElementType().isInteger(8) &&
519 vectorType.getNumElements() <= 8)
520 return rewriter.
create<LLVM::BitcastOp>(
521 loc, rewriter.
getIntegerType(vectorType.getNumElements() * 8), input);
522 if (isa<IntegerType>(vectorType.getElementType()) &&
523 vectorType.getElementTypeBitWidth() <= 8) {
525 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
527 return rewriter.
create<LLVM::BitcastOp>(
548 if (
auto intType = dyn_cast<IntegerType>(inputType))
549 return rewriter.
create<LLVM::ZExtOp>(loc, outputType, input);
550 return rewriter.
create<LLVM::BitcastOp>(loc, outputType, input);
564 bool isUnsigned,
Value llvmInput,
568 auto vectorType = dyn_cast<VectorType>(inputType);
570 operands.push_back(llvmInput);
573 Type elemType = vectorType.getElementType();
576 llvmInput = rewriter.
create<LLVM::BitcastOp>(
577 loc, vectorType.clone(rewriter.
getI16Type()), llvmInput);
579 operands.push_back(llvmInput);
586 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
587 bool isInputInteger = mlirInputType.getElementType().isInteger();
588 if (isInputInteger) {
590 bool localIsUnsigned = isUnsigned;
592 localIsUnsigned =
true;
594 localIsUnsigned =
false;
597 operands.push_back(sign);
603 Type intrinsicInType = numBits <= 32
606 auto llvmIntrinsicInType = typeConverter->
convertType(intrinsicInType);
608 loc, llvmIntrinsicInType, llvmInput);
613 castInput = rewriter.
create<LLVM::ZExtOp>(loc, i32, castInput);
614 operands.push_back(castInput);
627 Value output, int32_t subwordOffset,
630 auto vectorType = dyn_cast<VectorType>(inputType);
631 Type elemType = vectorType.getElementType();
633 output = rewriter.
create<LLVM::BitcastOp>(
634 loc, vectorType.clone(rewriter.
getI16Type()), output);
635 operands.push_back(output);
646 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
647 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
653 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
654 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
662 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
663 b = mfma.getBlocks();
668 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
669 if (m == 32 && n == 32 && k == 4 && b == 1)
670 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
671 if (m == 16 && n == 16 && k == 8 && b == 1)
672 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
674 if (m == 32 && n == 32 && k == 1 && b == 2)
675 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
676 if (m == 16 && n == 16 && k == 1 && b == 4)
677 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
678 if (m == 4 && n == 4 && k == 1 && b == 16)
679 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
680 if (m == 32 && n == 32 && k == 2 && b == 1)
681 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
682 if (m == 16 && n == 16 && k == 4 && b == 1)
683 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
688 if (m == 32 && n == 32 && k == 16 && b == 1)
689 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
690 if (m == 16 && n == 16 && k == 32 && b == 1)
691 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
693 if (m == 32 && n == 32 && k == 4 && b == 2)
694 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
695 if (m == 16 && n == 16 && k == 4 && b == 4)
696 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
697 if (m == 4 && n == 4 && k == 4 && b == 16)
698 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
699 if (m == 32 && n == 32 && k == 8 && b == 1)
700 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
701 if (m == 16 && n == 16 && k == 16 && b == 1)
702 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
707 if (m == 32 && n == 32 && k == 16 && b == 1)
708 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
709 if (m == 16 && n == 16 && k == 32 && b == 1)
710 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
713 if (m == 32 && n == 32 && k == 4 && b == 2)
714 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
715 if (m == 16 && n == 16 && k == 4 && b == 4)
716 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
717 if (m == 4 && n == 4 && k == 4 && b == 16)
718 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
719 if (m == 32 && n == 32 && k == 8 && b == 1)
720 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
721 if (m == 16 && n == 16 && k == 16 && b == 1)
722 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
724 if (m == 32 && n == 32 && k == 2 && b == 2)
725 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
726 if (m == 16 && n == 16 && k == 2 && b == 4)
727 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
728 if (m == 4 && n == 4 && k == 2 && b == 16)
729 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
730 if (m == 32 && n == 32 && k == 4 && b == 1)
731 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
732 if (m == 16 && n == 16 && k == 8 && b == 1)
733 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
738 if (m == 32 && n == 32 && k == 32 && b == 1)
739 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
740 if (m == 16 && n == 16 && k == 64 && b == 1)
741 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
743 if (m == 32 && n == 32 && k == 4 && b == 2)
744 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
745 if (m == 16 && n == 16 && k == 4 && b == 4)
746 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
747 if (m == 4 && n == 4 && k == 4 && b == 16)
748 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
749 if (m == 32 && n == 32 && k == 8 && b == 1)
750 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
751 if (m == 16 && n == 16 && k == 16 && b == 1)
752 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
753 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >=
kGfx942)
754 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
755 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >=
kGfx942)
756 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
760 if (m == 16 && n == 16 && k == 4 && b == 1)
761 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
762 if (m == 4 && n == 4 && k == 4 && b == 4)
763 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
770 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
771 if (m == 16 && n == 16 && k == 32 && b == 1) {
773 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
775 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
777 if (m == 32 && n == 32 && k == 16 && b == 1) {
779 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
781 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
787 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
788 if (m == 16 && n == 16 && k == 32 && b == 1) {
790 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
792 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
794 if (m == 32 && n == 32 && k == 16 && b == 1) {
796 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
798 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
807 .Case([](Float8E4M3FNType) {
return 0u; })
808 .Case([](Float8E5M2Type) {
return 1u; })
809 .Case([](Float6E2M3FNType) {
return 2u; })
810 .Case([](Float6E3M2FNType) {
return 3u; })
811 .Case([](Float4E2M1FNType) {
return 4u; })
812 .Default([](
Type) {
return std::nullopt; });
822 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
824 uint32_t n, uint32_t k, uint32_t b,
Chipset chipset) {
831 if (!isa<Float32Type>(destType))
836 if (!aTypeCode || !bTypeCode)
839 if (m == 32 && n == 32 && k == 64 && b == 1)
840 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
841 *aTypeCode, *bTypeCode};
842 if (m == 16 && n == 16 && k == 128 && b == 1)
844 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
850 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
853 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
854 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
855 mfma.getBlocks(), chipset);
858 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
861 smfma.getSourceB().getType(),
862 smfma.getDestC().getType(), smfma.getM(),
863 smfma.getN(), smfma.getK(), 1u, chipset);
871 auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
872 auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
873 auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
874 auto elemSourceType = sourceVectorType.getElementType();
875 auto elemBSourceType = sourceBVectorType.getElementType();
876 auto elemDestType = destVectorType.getElementType();
878 if (elemSourceType.isF16() && elemDestType.isF32())
879 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
880 if (elemSourceType.isBF16() && elemDestType.isF32())
881 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
882 if (elemSourceType.isF16() && elemDestType.isF16())
883 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
884 if (elemSourceType.isBF16() && elemDestType.isBF16())
885 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
886 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
887 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
889 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
890 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
893 if (isa<Float8E4M3FNType>(elemSourceType) &&
894 isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
895 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
896 if (isa<Float8E4M3FNType>(elemSourceType) &&
897 isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
898 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
899 if (isa<Float8E5M2Type>(elemSourceType) &&
900 isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
901 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
902 if (isa<Float8E5M2Type>(elemSourceType) &&
903 isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
904 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
905 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
906 bool isWave64 = destVectorType.getNumElements() == 4;
909 bool has8Inputs = sourceVectorType.getNumElements() == 8;
910 if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
911 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
912 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
926 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
929 Type outType = typeConverter->convertType(op.getDestD().getType());
930 Type intrinsicOutType = outType;
931 if (
auto outVecType = dyn_cast<VectorType>(outType))
932 if (outVecType.getElementType().isBF16())
933 intrinsicOutType = outVecType.clone(rewriter.
getI16Type());
936 return op->emitOpError(
"MFMA only supported on gfx908+");
937 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
938 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
940 return op.emitOpError(
"negation unsupported on older than gfx942");
942 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
945 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
947 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
948 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
951 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
953 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
954 return op.emitOpError(
955 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
956 "be scaled as those fields are used for type information");
959 StringRef intrinsicName =
960 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
962 loweredOp.addTypes(intrinsicOutType);
963 loweredOp.addOperands(
966 adaptor.getDestC()});
969 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
980 if (outType != intrinsicOutType)
981 lowered = rewriter.
create<LLVM::BitcastOp>(loc, outType, lowered);
994 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
997 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1000 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1001 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1003 if (!maybeScaledIntrinsic.has_value())
1004 return op.emitOpError(
1005 "no intrinsic matching scaled MFMA size on given chipset");
1007 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1009 loweredOp.addTypes(intrinsicOutType);
1010 loweredOp.addOperands(
1013 adaptor.getDestC()});
1018 loweredOp.addOperands(
1040 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1044 typeConverter->convertType<VectorType>(op.getDestD().getType());
1049 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1053 VectorType rawOutType = outType;
1054 if (outType.getElementType().
isBF16())
1055 rawOutType = outType.clone(rewriter.
getI16Type());
1059 if (!maybeIntrinsic.has_value())
1060 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
1062 if (chipset.
majorVersion >= 12 && op.getSubwordOffset() != 0)
1063 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1066 loweredOp.addTypes(rawOutType);
1070 adaptor.getSourceA(), op.getSourceA(), operands);
1072 adaptor.getSourceB(), op.getSourceB(), operands);
1074 op.getSubwordOffset(), op.getClamp(), operands);
1076 loweredOp.addOperands(operands);
1080 if (rawOutType != outType)
1082 rewriter.
create<LLVM::BitcastOp>(loc, outType, lowered->
getResult(0));
1096 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1099 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
1103 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1104 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1109 Type transferType = op.getTransferType();
1110 size_t loadWidth = [&]() ->
size_t {
1111 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1112 return transferVectorType.getNumElements() *
1113 (transferVectorType.getElementTypeBitWidth() / 8);
1119 if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
1120 return op.emitOpError(
"chipset unsupported element size");
1124 (adaptor.getSrcIndices()));
1127 (adaptor.getDstIndices()));
1140 struct ExtPackedFp8OpLowering final
1148 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1152 struct PackedTrunc2xFp8OpLowering final
1161 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1165 struct PackedStochRoundFp8OpLowering final
1174 matchAndRewrite(PackedStochRoundFp8Op op,
1175 PackedStochRoundFp8OpAdaptor adaptor,
1179 struct ScaledExtPackedOpLowering final
1187 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1191 struct PackedScaledTruncOpLowering final
1200 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1206 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1207 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1212 loc,
"Fp8 conversion instructions are not available on target "
1213 "architecture and their emulation is not implemented");
1216 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1217 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1219 Value source = adaptor.getSource();
1220 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1221 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1224 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1225 Value longVec = rewriter.
create<LLVM::UndefOp>(loc, v4i8);
1226 if (!sourceVecType) {
1227 longVec = rewriter.
create<LLVM::InsertElementOp>(
1230 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1232 Value elem = rewriter.
create<LLVM::ExtractElementOp>(loc, source, idx);
1234 rewriter.
create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
1239 Value i32Source = rewriter.
create<LLVM::BitcastOp>(loc, i32, source);
1240 if (resultVecType) {
1260 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1261 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1266 loc,
"Scaled fp conversion instructions are not available on target "
1267 "architecture and their emulation is not implemented");
1268 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1270 Value source = adaptor.getSource();
1271 Value scale = adaptor.getScale();
1273 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1274 Type sourceElemType = sourceVecType.getElementType();
1275 VectorType destVecType = cast<VectorType>(op.getResult().getType());
1276 Type destElemType = destVecType.getElementType();
1278 VectorType packedVecType;
1279 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1281 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1282 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
1284 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1286 llvm_unreachable(
"invalid element type for scaled ext");
1290 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1291 Value longVec = rewriter.
create<LLVM::ZeroOp>(loc, packedVecType);
1292 if (!sourceVecType) {
1293 longVec = rewriter.
create<LLVM::InsertElementOp>(
1296 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1298 Value elem = rewriter.
create<LLVM::ExtractElementOp>(loc, source, idx);
1300 rewriter.
create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
1305 Value i32Source = rewriter.
create<LLVM::BitcastOp>(loc, i32, source);
1307 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
1309 op, destVecType, i32Source, scale, op.getIndex());
1310 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
1312 op, destVecType, i32Source, scale, op.getIndex());
1313 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
1315 op, destVecType, i32Source, scale, op.getIndex());
1316 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
1318 op, destVecType, i32Source, scale, op.getIndex());
1319 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
1321 op, destVecType, i32Source, scale, op.getIndex());
1322 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
1324 op, destVecType, i32Source, scale, op.getIndex());
1325 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
1327 op, destVecType, i32Source, scale, op.getIndex());
1328 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
1330 op, destVecType, i32Source, scale, op.getIndex());
1331 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
1333 op, destVecType, i32Source, scale, op.getIndex());
1340 LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1341 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1346 loc,
"Scaled fp conversion instructions are not available on target "
1347 "architecture and their emulation is not implemented");
1348 Type v2i16 = getTypeConverter()->convertType(
1350 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1352 Type resultType = op.getResult().getType();
1354 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1355 Type sourceElemType = sourceVecType.getElementType();
1357 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1359 Value source = adaptor.getSource();
1360 Value scale = adaptor.getScale();
1361 Value existing = adaptor.getExisting();
1363 existing = rewriter.
create<LLVM::BitcastOp>(loc, intResultType, existing);
1365 existing = rewriter.
create<LLVM::ZeroOp>(loc, intResultType);
1367 if (sourceVecType.getNumElements() < 2) {
1369 Value elem0 = rewriter.
create<LLVM::ExtractElementOp>(loc, source, c0);
1371 source = rewriter.
create<LLVM::ZeroOp>(loc, v2);
1372 source = rewriter.
create<LLVM::InsertElementOp>(loc, source, elem0, c0);
1375 Value sourceA, sourceB;
1376 if (sourceElemType.
isF32()) {
1379 sourceA = rewriter.
create<LLVM::ExtractElementOp>(loc, source, c0);
1380 sourceB = rewriter.
create<LLVM::ExtractElementOp>(loc, source, c1);
1384 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
1385 result = rewriter.
create<ROCDL::CvtScaleF32PkBf8F32Op>(
1386 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1387 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
1388 result = rewriter.
create<ROCDL::CvtScaleF32PkBf8F16Op>(
1389 loc, intResultType, existing, source, scale, op.getIndex());
1390 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
1391 result = rewriter.
create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
1392 loc, intResultType, existing, source, scale, op.getIndex());
1393 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
1394 result = rewriter.
create<ROCDL::CvtScaleF32PkFp8F32Op>(
1395 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1396 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
1397 result = rewriter.
create<ROCDL::CvtScaleF32PkFp8F16Op>(
1398 loc, intResultType, existing, source, scale, op.getIndex());
1399 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
1400 result = rewriter.
create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
1401 loc, intResultType, existing, source, scale, op.getIndex());
1402 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
1403 result = rewriter.
create<ROCDL::CvtScaleF32PkFp4F32Op>(
1404 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1405 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
1406 result = rewriter.
create<ROCDL::CvtScaleF32PkFp4F16Op>(
1407 loc, intResultType, existing, source, scale, op.getIndex());
1408 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
1409 result = rewriter.
create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
1410 loc, intResultType, existing, source, scale, op.getIndex());
1415 op, getTypeConverter()->convertType(resultType), result);
1419 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1420 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1425 loc,
"Fp8 conversion instructions are not available on target "
1426 "architecture and their emulation is not implemented");
1427 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1429 Type resultType = op.getResult().getType();
1432 Value sourceA = adaptor.getSourceA();
1433 Value sourceB = adaptor.getSourceB();
1435 sourceB = rewriter.
create<LLVM::UndefOp>(loc, sourceA.
getType());
1436 Value existing = adaptor.getExisting();
1438 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
1440 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
1444 result = rewriter.
create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
1445 existing, op.getWordIndex());
1447 result = rewriter.
create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
1448 existing, op.getWordIndex());
1451 op, getTypeConverter()->convertType(resultType), result);
1455 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1456 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1461 loc,
"Fp8 conversion instructions are not available on target "
1462 "architecture and their emulation is not implemented");
1463 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1465 Type resultType = op.getResult().getType();
1468 Value source = adaptor.getSource();
1469 Value stoch = adaptor.getStochiasticParam();
1470 Value existing = adaptor.getExisting();
1472 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
1474 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
1478 result = rewriter.
create<ROCDL::CvtSrBf8F32Op>(
1479 loc, i32, source, stoch, existing, op.getStoreIndex());
1481 result = rewriter.
create<ROCDL::CvtSrFp8F32Op>(
1482 loc, i32, source, stoch, existing, op.getStoreIndex());
1485 op, getTypeConverter()->convertType(resultType), result);
1497 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1502 Value src = adaptor.getSrc();
1503 Value old = adaptor.getOld();
1506 Type llvmType =
nullptr;
1509 }
else if (isa<FloatType>(srcType)) {
1513 }
else if (isa<IntegerType>(srcType)) {
1518 auto llvmSrcIntType = typeConverter->convertType(
1522 auto convertOperand = [&](
Value operand,
Type operandType) {
1523 if (operandType.getIntOrFloatBitWidth() <= 16) {
1524 if (llvm::isa<FloatType>(operandType)) {
1526 rewriter.
create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
1529 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1530 Value undefVec = rewriter.
create<LLVM::UndefOp>(loc, llvmVecType);
1531 operand = rewriter.
create<LLVM::InsertElementOp>(
1533 operand = rewriter.
create<LLVM::BitcastOp>(loc, llvmType, operand);
1538 src = convertOperand(src, srcType);
1539 old = convertOperand(old, oldType);
1542 enum DppCtrl :
unsigned {
1551 ROW_HALF_MIRROR = 0x141,
1556 auto kind = DppOp.getKind();
1557 auto permArgument = DppOp.getPermArgument();
1558 uint32_t DppCtrl = 0;
1562 case DPPPerm::quad_perm:
1563 if (
auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1565 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1566 uint32_t num = elem.getInt();
1567 DppCtrl |= num << (i * 2);
1572 case DPPPerm::row_shl:
1573 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1574 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1577 case DPPPerm::row_shr:
1578 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1579 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1582 case DPPPerm::row_ror:
1583 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1584 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1587 case DPPPerm::wave_shl:
1588 DppCtrl = DppCtrl::WAVE_SHL1;
1590 case DPPPerm::wave_shr:
1591 DppCtrl = DppCtrl::WAVE_SHR1;
1593 case DPPPerm::wave_rol:
1594 DppCtrl = DppCtrl::WAVE_ROL1;
1596 case DPPPerm::wave_ror:
1597 DppCtrl = DppCtrl::WAVE_ROR1;
1599 case DPPPerm::row_mirror:
1600 DppCtrl = DppCtrl::ROW_MIRROR;
1602 case DPPPerm::row_half_mirror:
1603 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1605 case DPPPerm::row_bcast_15:
1606 DppCtrl = DppCtrl::BCAST15;
1608 case DPPPerm::row_bcast_31:
1609 DppCtrl = DppCtrl::BCAST31;
1615 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
1616 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
1617 bool boundCtrl = DppOp->getAttrOfType<
BoolAttr>(
"bound_ctrl").getValue();
1620 auto dppMovOp = rewriter.
create<ROCDL::DPPUpdateOp>(
1621 loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
1623 Value result = dppMovOp.getRes();
1625 result = rewriter.
create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
1626 if (!llvm::isa<IntegerType>(srcType)) {
1627 result = rewriter.
create<LLVM::BitcastOp>(loc, srcType, result);
1638 struct AMDGPUSwizzleBitModeLowering
1643 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1647 Value src = adaptor.getSrc();
1650 unsigned andMask = op.getAndMask();
1651 unsigned orMask = op.getOrMask();
1652 unsigned xorMask = op.getXorMask();
1656 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
1659 for (
Value v : decomposed) {
1661 rewriter.
create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
1662 swizzled.emplace_back(res);
1671 struct ConvertAMDGPUToROCDLPass
1672 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1675 void runOnOperation()
override {
1678 if (failed(maybeChipset)) {
1680 return signalPassFailure();
1687 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1688 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1689 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1692 signalPassFailure();
1704 switch (as.getValue()) {
1705 case amdgpu::AddressSpace::FatRawBuffer:
1707 case amdgpu::AddressSpace::BufferRsrc:
1709 case amdgpu::AddressSpace::FatStructuredBuffer:
1721 .add<FatRawBufferCastLowering,
1722 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1723 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1724 RawBufferOpLowering<RawBufferAtomicFaddOp,
1725 ROCDL::RawPtrBufferAtomicFaddOp>,
1726 RawBufferOpLowering<RawBufferAtomicFmaxOp,
1727 ROCDL::RawPtrBufferAtomicFmaxOp>,
1728 RawBufferOpLowering<RawBufferAtomicSmaxOp,
1729 ROCDL::RawPtrBufferAtomicSmaxOp>,
1730 RawBufferOpLowering<RawBufferAtomicUminOp,
1731 ROCDL::RawPtrBufferAtomicUminOp>,
1732 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1733 ROCDL::RawPtrBufferAtomicCmpSwap>,
1734 AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1735 MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1736 ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1737 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1738 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1740 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
constexpr Chipset kGfx908
constexpr Chipset kGfx90a
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to ROCDL ...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, uint32_t elementByteWidth)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static std::optional< uint32_t > mfmaTypeSelectCode(Type mlirElemType)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
constexpr Chipset kGfx950
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1204::ArityGroupAndKind::Kind kind
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class provides a shared interface for ranked and unranked memref types.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th stride from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult abort()
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
bool hasOcpFp8(const Chipset &chipset)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.