22 #include "../LLVMCommon/MemRefDescriptor.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/ErrorHandling.h"
31 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
32 #include "mlir/Conversion/Passes.h.inc"
49 auto valTy = cast<IntegerType>(val.
getType());
52 return valTy.getWidth() > 32
53 ?
Value(rewriter.
create<LLVM::TruncOp>(loc, i32, val))
54 :
Value(rewriter.
create<LLVM::ZExtOp>(loc, i32, val));
60 return rewriter.
create<LLVM::ConstantOp>(loc, i32, value);
66 return rewriter.
create<LLVM::ConstantOp>(loc, llvmI1, value);
78 ShapedType::isDynamic(stride)
80 memRefDescriptor.
stride(rewriter, loc, i))
81 : rewriter.
create<LLVM::ConstantOp>(loc, i32, stride);
82 increment = rewriter.
create<LLVM::MulOp>(loc, increment, strideValue);
85 index ? rewriter.
create<LLVM::AddOp>(loc, index, increment) : increment;
94 MemRefType memrefType,
97 uint32_t elementByteWidth) {
98 if (memrefType.hasStaticShape() &&
99 !llvm::any_of(strides, ShapedType::isDynamic)) {
100 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
102 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
103 size =
std::max(shape[i] * strides[i], size);
104 size = size * elementByteWidth;
106 "the memref buffer is too large");
110 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
111 Value size = memrefDescriptor.
size(rewriter, loc, i);
112 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
113 Value maxThisDim = rewriter.
create<LLVM::MulOp>(loc, size, stride);
115 ? rewriter.
create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
120 return rewriter.
create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
125 bool boundsCheck, amdgpu::Chipset chipset,
126 Value cacheSwizzleStride =
nullptr,
127 unsigned addressSpace = 8) {
133 if (chipset.majorVersion == 9 && chipset >=
kGfx942 && cacheSwizzleStride) {
134 Value cacheStrideZext =
135 rewriter.
create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
136 Value swizzleBit = rewriter.
create<LLVM::ConstantOp>(
138 stride = rewriter.
create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
141 stride = rewriter.
create<LLVM::ConstantOp>(loc, i16,
159 uint32_t flags = (7 << 12) | (4 << 15);
160 if (chipset.majorVersion >= 10) {
162 uint32_t oob = boundsCheck ? 3 : 2;
163 flags |= (oob << 28);
169 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
174 struct FatRawBufferCastLowering
183 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
186 Value memRef = adaptor.getSource();
187 Value unconvertedMemref = op.getSource();
188 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
192 int64_t elementByteWidth =
195 int64_t unusedOffset = 0;
197 if (failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
198 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
200 Value numRecords = adaptor.getValidBytes();
202 numRecords =
getNumRecords(rewriter, loc, memrefType, descriptor,
203 strideVals, elementByteWidth);
206 adaptor.getResetOffset()
207 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
209 : descriptor.alignedPtr(rewriter, loc);
211 Value offset = adaptor.getResetOffset()
212 ? rewriter.
create<LLVM::ConstantOp>(
214 : descriptor.offset(rewriter, loc);
216 bool hasSizes = memrefType.getRank() > 0;
219 Value sizes = hasSizes ? rewriter.
create<LLVM::ExtractValueOp>(
222 Value strides = hasSizes
223 ? rewriter.
create<LLVM::ExtractValueOp>(
228 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
229 chipset, adaptor.getCacheSwizzleStride(), 7);
233 getTypeConverter()->convertType(op.getResult().getType()));
234 result = rewriter.
create<LLVM::InsertValueOp>(
236 result = rewriter.
create<LLVM::InsertValueOp>(
238 result = rewriter.
create<LLVM::InsertValueOp>(loc, result, offset,
241 result = rewriter.
create<LLVM::InsertValueOp>(loc, result, sizes,
243 result = rewriter.
create<LLVM::InsertValueOp>(
252 template <
typename GpuOp,
typename Intrinsic>
258 static constexpr uint32_t maxVectorOpWidth = 128;
261 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
264 Value memref = adaptor.getMemref();
265 Value unconvertedMemref = gpuOp.getMemref();
266 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
269 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
271 Value storeData = adaptor.getODSOperands(0)[0];
272 if (storeData == memref)
276 wantedDataType = storeData.
getType();
278 wantedDataType = gpuOp.getODSResults(0)[0].getType();
283 Value maybeCmpData = adaptor.getODSOperands(1)[0];
284 if (maybeCmpData != memref)
285 atomicCmpData = maybeCmpData;
288 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
294 int64_t elementByteWidth =
303 Type llvmBufferValType = llvmWantedDataType;
305 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
306 llvmBufferValType = this->getTypeConverter()->convertType(
309 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
310 uint32_t vecLen = dataVector.getNumElements();
313 uint32_t totalBits = elemBits * vecLen;
315 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
316 if (totalBits > maxVectorOpWidth)
317 return gpuOp.emitOpError(
318 "Total width of loads or stores must be no more than " +
319 Twine(maxVectorOpWidth) +
" bits, but we call for " +
321 " bits. This should've been caught in validation");
322 if (!usePackedFp16 && elemBits < 32) {
323 if (totalBits > 32) {
324 if (totalBits % 32 != 0)
325 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
326 "doesn't fit into words. Can't happen\n");
327 llvmBufferValType = this->typeConverter->convertType(
330 llvmBufferValType = this->typeConverter->convertType(
335 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
338 if (vecType.getNumElements() == 1)
339 llvmBufferValType = vecType.getElementType();
344 if (llvmBufferValType != llvmWantedDataType) {
346 rewriter.
create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
347 args.push_back(castForStore);
349 args.push_back(storeData);
354 if (llvmBufferValType != llvmWantedDataType) {
355 Value castForCmp = rewriter.
create<LLVM::BitcastOp>(
356 loc, llvmBufferValType, atomicCmpData);
357 args.push_back(castForCmp);
359 args.push_back(atomicCmpData);
366 if (failed(memrefType.getStridesAndOffset(strides, offset)))
367 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
371 Value ptr = memrefDescriptor.bufferPtr(
372 rewriter, loc, *this->getTypeConverter(), memrefType);
374 rewriter, loc, memrefType, memrefDescriptor, strides, elementByteWidth);
376 adaptor.getBoundsCheck(), chipset);
377 args.push_back(resource);
381 adaptor.getIndices(), strides);
382 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
383 indexOffset && *indexOffset > 0) {
386 voffset ? rewriter.
create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
389 voffset = rewriter.
create<LLVM::MulOp>(loc, voffset, byteWidthConst);
390 args.push_back(voffset);
393 Value sgprOffset = adaptor.getSgprOffset();
396 sgprOffset = rewriter.
create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
397 args.push_back(sgprOffset);
406 Operation *lowered = rewriter.
create<Intrinsic>(loc, resultTypes, args,
410 if (llvmBufferValType != llvmWantedDataType) {
411 replacement = rewriter.
create<LLVM::BitcastOp>(loc, llvmWantedDataType,
429 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
433 if (requiresInlineAsm) {
435 LLVM::AsmDialect::AD_ATT);
437 ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
438 const char *constraints =
"";
442 asmStr, constraints,
true,
449 constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
450 constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
453 constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
457 ldsOnlyBits = ldsOnlyBitsGfx11;
459 ldsOnlyBits = ldsOnlyBitsGfx10;
461 ldsOnlyBits = ldsOnlyBitsGfx6789;
463 return op.emitOpError(
464 "don't know how to lower this for chipset major version")
468 rewriter.
create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
472 rewriter.
create<ROCDL::WaitDscntOp>(loc, 0);
473 rewriter.
create<ROCDL::BarrierSignalOp>(loc, -1);
488 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
491 (uint32_t)op.getOpts());
515 bool allowBf16 =
true) {
517 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
518 if (vectorType.getElementType().isBF16() && !allowBf16)
519 return rewriter.
create<LLVM::BitcastOp>(
520 loc, vectorType.clone(rewriter.
getI16Type()), input);
521 if (vectorType.getElementType().isInteger(8) &&
522 vectorType.getNumElements() <= 8)
523 return rewriter.
create<LLVM::BitcastOp>(
524 loc, rewriter.
getIntegerType(vectorType.getNumElements() * 8), input);
525 if (isa<IntegerType>(vectorType.getElementType()) &&
526 vectorType.getElementTypeBitWidth() <= 8) {
528 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
530 return rewriter.
create<LLVM::BitcastOp>(
551 if (
auto intType = dyn_cast<IntegerType>(inputType))
552 return rewriter.
create<LLVM::ZExtOp>(loc, outputType, input);
553 return rewriter.
create<LLVM::BitcastOp>(loc, outputType, input);
567 bool isUnsigned,
Value llvmInput,
571 auto vectorType = dyn_cast<VectorType>(inputType);
573 operands.push_back(llvmInput);
576 Type elemType = vectorType.getElementType();
579 llvmInput = rewriter.
create<LLVM::BitcastOp>(
580 loc, vectorType.clone(rewriter.
getI16Type()), llvmInput);
582 operands.push_back(llvmInput);
589 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
590 bool isInputInteger = mlirInputType.getElementType().isInteger();
591 if (isInputInteger) {
593 bool localIsUnsigned = isUnsigned;
595 localIsUnsigned =
true;
597 localIsUnsigned =
false;
600 operands.push_back(sign);
606 Type intrinsicInType = numBits <= 32
609 auto llvmIntrinsicInType = typeConverter->
convertType(intrinsicInType);
611 loc, llvmIntrinsicInType, llvmInput);
616 castInput = rewriter.
create<LLVM::ZExtOp>(loc, i32, castInput);
617 operands.push_back(castInput);
630 Value output, int32_t subwordOffset,
633 auto vectorType = dyn_cast<VectorType>(inputType);
634 Type elemType = vectorType.getElementType();
636 output = rewriter.
create<LLVM::BitcastOp>(
637 loc, vectorType.clone(rewriter.
getI16Type()), output);
638 operands.push_back(output);
649 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
650 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
656 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
657 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
665 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
666 b = mfma.getBlocks();
671 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
672 if (m == 32 && n == 32 && k == 4 && b == 1)
673 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
674 if (m == 16 && n == 16 && k == 8 && b == 1)
675 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
677 if (m == 32 && n == 32 && k == 1 && b == 2)
678 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
679 if (m == 16 && n == 16 && k == 1 && b == 4)
680 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
681 if (m == 4 && n == 4 && k == 1 && b == 16)
682 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
683 if (m == 32 && n == 32 && k == 2 && b == 1)
684 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
685 if (m == 16 && n == 16 && k == 4 && b == 1)
686 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
691 if (m == 32 && n == 32 && k == 16 && b == 1)
692 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
693 if (m == 16 && n == 16 && k == 32 && b == 1)
694 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
696 if (m == 32 && n == 32 && k == 4 && b == 2)
697 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
698 if (m == 16 && n == 16 && k == 4 && b == 4)
699 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
700 if (m == 4 && n == 4 && k == 4 && b == 16)
701 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
702 if (m == 32 && n == 32 && k == 8 && b == 1)
703 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
704 if (m == 16 && n == 16 && k == 16 && b == 1)
705 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
710 if (m == 32 && n == 32 && k == 16 && b == 1)
711 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
712 if (m == 16 && n == 16 && k == 32 && b == 1)
713 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
716 if (m == 32 && n == 32 && k == 4 && b == 2)
717 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
718 if (m == 16 && n == 16 && k == 4 && b == 4)
719 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
720 if (m == 4 && n == 4 && k == 4 && b == 16)
721 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
722 if (m == 32 && n == 32 && k == 8 && b == 1)
723 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
724 if (m == 16 && n == 16 && k == 16 && b == 1)
725 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
727 if (m == 32 && n == 32 && k == 2 && b == 2)
728 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
729 if (m == 16 && n == 16 && k == 2 && b == 4)
730 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
731 if (m == 4 && n == 4 && k == 2 && b == 16)
732 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
733 if (m == 32 && n == 32 && k == 4 && b == 1)
734 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
735 if (m == 16 && n == 16 && k == 8 && b == 1)
736 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
741 if (m == 32 && n == 32 && k == 32 && b == 1)
742 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
743 if (m == 16 && n == 16 && k == 64 && b == 1)
744 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
746 if (m == 32 && n == 32 && k == 4 && b == 2)
747 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
748 if (m == 16 && n == 16 && k == 4 && b == 4)
749 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
750 if (m == 4 && n == 4 && k == 4 && b == 16)
751 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
752 if (m == 32 && n == 32 && k == 8 && b == 1)
753 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
754 if (m == 16 && n == 16 && k == 16 && b == 1)
755 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
756 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >=
kGfx942)
757 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
758 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >=
kGfx942)
759 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
763 if (m == 16 && n == 16 && k == 4 && b == 1)
764 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
765 if (m == 4 && n == 4 && k == 4 && b == 4)
766 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
773 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
774 if (m == 16 && n == 16 && k == 32 && b == 1) {
776 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
778 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
780 if (m == 32 && n == 32 && k == 16 && b == 1) {
782 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
784 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
790 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
791 if (m == 16 && n == 16 && k == 32 && b == 1) {
793 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
795 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
797 if (m == 32 && n == 32 && k == 16 && b == 1) {
799 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
801 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
810 .Case([](Float8E4M3FNType) {
return 0u; })
811 .Case([](Float8E5M2Type) {
return 1u; })
812 .Case([](Float6E2M3FNType) {
return 2u; })
813 .Case([](Float6E3M2FNType) {
return 3u; })
814 .Case([](Float4E2M1FNType) {
return 4u; })
815 .Default([](
Type) {
return std::nullopt; });
825 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
827 uint32_t n, uint32_t k, uint32_t b,
Chipset chipset) {
834 if (!isa<Float32Type>(destType))
839 if (!aTypeCode || !bTypeCode)
842 if (m == 32 && n == 32 && k == 64 && b == 1)
843 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
844 *aTypeCode, *bTypeCode};
845 if (m == 16 && n == 16 && k == 128 && b == 1)
847 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
853 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
856 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
857 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
858 mfma.getBlocks(), chipset);
861 static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
864 smfma.getSourceB().getType(),
865 smfma.getDestC().getType(), smfma.getM(),
866 smfma.getN(), smfma.getK(), 1u, chipset);
874 auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
875 auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
876 auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
877 auto elemSourceType = sourceVectorType.getElementType();
878 auto elemBSourceType = sourceBVectorType.getElementType();
879 auto elemDestType = destVectorType.getElementType();
881 if (elemSourceType.isF16() && elemDestType.isF32())
882 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
883 if (elemSourceType.isBF16() && elemDestType.isF32())
884 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
885 if (elemSourceType.isF16() && elemDestType.isF16())
886 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
887 if (elemSourceType.isBF16() && elemDestType.isBF16())
888 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
889 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
890 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
892 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
893 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
896 if (isa<Float8E4M3FNType>(elemSourceType) &&
897 isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
898 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
899 if (isa<Float8E4M3FNType>(elemSourceType) &&
900 isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
901 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
902 if (isa<Float8E5M2Type>(elemSourceType) &&
903 isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
904 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
905 if (isa<Float8E5M2Type>(elemSourceType) &&
906 isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
907 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
908 if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
909 bool isWave64 = destVectorType.getNumElements() == 4;
912 bool has8Inputs = sourceVectorType.getNumElements() == 8;
913 if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
914 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
915 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
929 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
932 Type outType = typeConverter->convertType(op.getDestD().getType());
933 Type intrinsicOutType = outType;
934 if (
auto outVecType = dyn_cast<VectorType>(outType))
935 if (outVecType.getElementType().isBF16())
936 intrinsicOutType = outVecType.clone(rewriter.
getI16Type());
939 return op->emitOpError(
"MFMA only supported on gfx908+");
940 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
941 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
943 return op.emitOpError(
"negation unsupported on older than gfx942");
945 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
948 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
950 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
951 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
954 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
956 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
957 return op.emitOpError(
958 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
959 "be scaled as those fields are used for type information");
962 StringRef intrinsicName =
963 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
966 bool allowBf16 = [&]() {
971 return intrinsicName.contains(
"16x16x32.bf16") ||
972 intrinsicName.contains(
"32x32x16.bf16");
975 loweredOp.addTypes(intrinsicOutType);
977 rewriter, loc, adaptor.getSourceA(), allowBf16),
979 rewriter, loc, adaptor.getSourceB(), allowBf16),
980 adaptor.getDestC()});
983 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
994 if (outType != intrinsicOutType)
995 lowered = rewriter.
create<LLVM::BitcastOp>(loc, outType, lowered);
1008 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1011 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1014 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1015 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1017 if (!maybeScaledIntrinsic.has_value())
1018 return op.emitOpError(
1019 "no intrinsic matching scaled MFMA size on given chipset");
1021 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1023 loweredOp.addTypes(intrinsicOutType);
1024 loweredOp.addOperands(
1027 adaptor.getDestC()});
1032 loweredOp.addOperands(
1054 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1058 typeConverter->convertType<VectorType>(op.getDestD().getType());
1063 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1067 VectorType rawOutType = outType;
1068 if (outType.getElementType().
isBF16())
1069 rawOutType = outType.clone(rewriter.
getI16Type());
1073 if (!maybeIntrinsic.has_value())
1074 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
1076 if (chipset.
majorVersion >= 12 && op.getSubwordOffset() != 0)
1077 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1080 loweredOp.addTypes(rawOutType);
1084 adaptor.getSourceA(), op.getSourceA(), operands);
1086 adaptor.getSourceB(), op.getSourceB(), operands);
1088 op.getSubwordOffset(), op.getClamp(), operands);
1090 loweredOp.addOperands(operands);
1094 if (rawOutType != outType)
1096 rewriter.
create<LLVM::BitcastOp>(loc, outType, lowered->
getResult(0));
1110 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1113 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
1117 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1118 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1123 Type transferType = op.getTransferType();
1124 size_t loadWidth = [&]() ->
size_t {
1125 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1126 return transferVectorType.getNumElements() *
1127 (transferVectorType.getElementTypeBitWidth() / 8);
1133 if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
1134 return op.emitOpError(
"chipset unsupported element size");
1138 (adaptor.getSrcIndices()));
1141 (adaptor.getDstIndices()));
1154 struct ExtPackedFp8OpLowering final
1162 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1166 struct PackedTrunc2xFp8OpLowering final
1175 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1179 struct PackedStochRoundFp8OpLowering final
1188 matchAndRewrite(PackedStochRoundFp8Op op,
1189 PackedStochRoundFp8OpAdaptor adaptor,
1193 struct ScaledExtPackedOpLowering final
1201 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1205 struct PackedScaledTruncOpLowering final
1214 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1220 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1221 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1226 loc,
"Fp8 conversion instructions are not available on target "
1227 "architecture and their emulation is not implemented");
1230 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1231 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1233 Value source = adaptor.getSource();
1234 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1235 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1238 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1239 Value longVec = rewriter.
create<LLVM::UndefOp>(loc, v4i8);
1240 if (!sourceVecType) {
1241 longVec = rewriter.
create<LLVM::InsertElementOp>(
1244 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1246 Value elem = rewriter.
create<LLVM::ExtractElementOp>(loc, source, idx);
1248 rewriter.
create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
1253 Value i32Source = rewriter.
create<LLVM::BitcastOp>(loc, i32, source);
1254 if (resultVecType) {
1274 LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
1275 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1280 loc,
"Scaled fp conversion instructions are not available on target "
1281 "architecture and their emulation is not implemented");
1282 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1284 Value source = adaptor.getSource();
1285 Value scale = adaptor.getScale();
1287 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1288 Type sourceElemType = sourceVecType.getElementType();
1289 VectorType destVecType = cast<VectorType>(op.getResult().getType());
1290 Type destElemType = destVecType.getElementType();
1292 VectorType packedVecType;
1293 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
1295 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
1296 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
1298 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
1300 llvm_unreachable(
"invalid element type for scaled ext");
1304 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1305 Value longVec = rewriter.
create<LLVM::ZeroOp>(loc, packedVecType);
1306 if (!sourceVecType) {
1307 longVec = rewriter.
create<LLVM::InsertElementOp>(
1310 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1312 Value elem = rewriter.
create<LLVM::ExtractElementOp>(loc, source, idx);
1314 rewriter.
create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
1319 Value i32Source = rewriter.
create<LLVM::BitcastOp>(loc, i32, source);
1321 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
1323 op, destVecType, i32Source, scale, op.getIndex());
1324 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
1326 op, destVecType, i32Source, scale, op.getIndex());
1327 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
1329 op, destVecType, i32Source, scale, op.getIndex());
1330 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
1332 op, destVecType, i32Source, scale, op.getIndex());
1333 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
1335 op, destVecType, i32Source, scale, op.getIndex());
1336 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
1338 op, destVecType, i32Source, scale, op.getIndex());
1339 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
1341 op, destVecType, i32Source, scale, op.getIndex());
1342 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
1344 op, destVecType, i32Source, scale, op.getIndex());
1345 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
1347 op, destVecType, i32Source, scale, op.getIndex());
1354 LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
1355 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1360 loc,
"Scaled fp conversion instructions are not available on target "
1361 "architecture and their emulation is not implemented");
1362 Type v2i16 = getTypeConverter()->convertType(
1364 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1366 Type resultType = op.getResult().getType();
1368 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1369 Type sourceElemType = sourceVecType.getElementType();
1371 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
1373 Value source = adaptor.getSource();
1374 Value scale = adaptor.getScale();
1375 Value existing = adaptor.getExisting();
1377 existing = rewriter.
create<LLVM::BitcastOp>(loc, intResultType, existing);
1379 existing = rewriter.
create<LLVM::ZeroOp>(loc, intResultType);
1381 if (sourceVecType.getNumElements() < 2) {
1383 Value elem0 = rewriter.
create<LLVM::ExtractElementOp>(loc, source, c0);
1385 source = rewriter.
create<LLVM::ZeroOp>(loc, v2);
1386 source = rewriter.
create<LLVM::InsertElementOp>(loc, source, elem0, c0);
1389 Value sourceA, sourceB;
1390 if (sourceElemType.
isF32()) {
1393 sourceA = rewriter.
create<LLVM::ExtractElementOp>(loc, source, c0);
1394 sourceB = rewriter.
create<LLVM::ExtractElementOp>(loc, source, c1);
1398 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
1399 result = rewriter.
create<ROCDL::CvtScaleF32PkBf8F32Op>(
1400 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1401 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
1402 result = rewriter.
create<ROCDL::CvtScaleF32PkBf8F16Op>(
1403 loc, intResultType, existing, source, scale, op.getIndex());
1404 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
1405 result = rewriter.
create<ROCDL::CvtScaleF32PkBf8Bf16Op>(
1406 loc, intResultType, existing, source, scale, op.getIndex());
1407 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
1408 result = rewriter.
create<ROCDL::CvtScaleF32PkFp8F32Op>(
1409 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1410 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
1411 result = rewriter.
create<ROCDL::CvtScaleF32PkFp8F16Op>(
1412 loc, intResultType, existing, source, scale, op.getIndex());
1413 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
1414 result = rewriter.
create<ROCDL::CvtScaleF32PkFp8Bf16Op>(
1415 loc, intResultType, existing, source, scale, op.getIndex());
1416 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
1417 result = rewriter.
create<ROCDL::CvtScaleF32PkFp4F32Op>(
1418 loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex());
1419 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
1420 result = rewriter.
create<ROCDL::CvtScaleF32PkFp4F16Op>(
1421 loc, intResultType, existing, source, scale, op.getIndex());
1422 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
1423 result = rewriter.
create<ROCDL::CvtScaleF32PkFp4Bf16Op>(
1424 loc, intResultType, existing, source, scale, op.getIndex());
1429 op, getTypeConverter()->convertType(resultType), result);
1433 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1434 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1439 loc,
"Fp8 conversion instructions are not available on target "
1440 "architecture and their emulation is not implemented");
1441 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1443 Type resultType = op.getResult().getType();
1446 Value sourceA = adaptor.getSourceA();
1447 Value sourceB = adaptor.getSourceB();
1449 sourceB = rewriter.
create<LLVM::UndefOp>(loc, sourceA.
getType());
1450 Value existing = adaptor.getExisting();
1452 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
1454 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
1458 result = rewriter.
create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
1459 existing, op.getWordIndex());
1461 result = rewriter.
create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
1462 existing, op.getWordIndex());
1465 op, getTypeConverter()->convertType(resultType), result);
1469 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
1470 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1475 loc,
"Fp8 conversion instructions are not available on target "
1476 "architecture and their emulation is not implemented");
1477 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
1479 Type resultType = op.getResult().getType();
1482 Value source = adaptor.getSource();
1483 Value stoch = adaptor.getStochiasticParam();
1484 Value existing = adaptor.getExisting();
1486 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
1488 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
1492 result = rewriter.
create<ROCDL::CvtSrBf8F32Op>(
1493 loc, i32, source, stoch, existing, op.getStoreIndex());
1495 result = rewriter.
create<ROCDL::CvtSrFp8F32Op>(
1496 loc, i32, source, stoch, existing, op.getStoreIndex());
1499 op, getTypeConverter()->convertType(resultType), result);
1511 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
1516 Value src = adaptor.getSrc();
1517 Value old = adaptor.getOld();
1520 Type llvmType =
nullptr;
1523 }
else if (isa<FloatType>(srcType)) {
1527 }
else if (isa<IntegerType>(srcType)) {
1532 auto llvmSrcIntType = typeConverter->convertType(
1536 auto convertOperand = [&](
Value operand,
Type operandType) {
1537 if (operandType.getIntOrFloatBitWidth() <= 16) {
1538 if (llvm::isa<FloatType>(operandType)) {
1540 rewriter.
create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
1543 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
1544 Value undefVec = rewriter.
create<LLVM::UndefOp>(loc, llvmVecType);
1545 operand = rewriter.
create<LLVM::InsertElementOp>(
1547 operand = rewriter.
create<LLVM::BitcastOp>(loc, llvmType, operand);
1552 src = convertOperand(src, srcType);
1553 old = convertOperand(old, oldType);
1556 enum DppCtrl :
unsigned {
1565 ROW_HALF_MIRROR = 0x141,
1570 auto kind = DppOp.getKind();
1571 auto permArgument = DppOp.getPermArgument();
1572 uint32_t DppCtrl = 0;
1576 case DPPPerm::quad_perm:
1577 if (
auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
1579 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
1580 uint32_t num = elem.getInt();
1581 DppCtrl |= num << (i * 2);
1586 case DPPPerm::row_shl:
1587 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1588 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
1591 case DPPPerm::row_shr:
1592 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1593 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
1596 case DPPPerm::row_ror:
1597 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
1598 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
1601 case DPPPerm::wave_shl:
1602 DppCtrl = DppCtrl::WAVE_SHL1;
1604 case DPPPerm::wave_shr:
1605 DppCtrl = DppCtrl::WAVE_SHR1;
1607 case DPPPerm::wave_rol:
1608 DppCtrl = DppCtrl::WAVE_ROL1;
1610 case DPPPerm::wave_ror:
1611 DppCtrl = DppCtrl::WAVE_ROR1;
1613 case DPPPerm::row_mirror:
1614 DppCtrl = DppCtrl::ROW_MIRROR;
1616 case DPPPerm::row_half_mirror:
1617 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
1619 case DPPPerm::row_bcast_15:
1620 DppCtrl = DppCtrl::BCAST15;
1622 case DPPPerm::row_bcast_31:
1623 DppCtrl = DppCtrl::BCAST31;
1629 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
1630 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
1631 bool boundCtrl = DppOp->getAttrOfType<
BoolAttr>(
"bound_ctrl").getValue();
1634 auto dppMovOp = rewriter.
create<ROCDL::DPPUpdateOp>(
1635 loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
1637 Value result = dppMovOp.getRes();
1639 result = rewriter.
create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
1640 if (!llvm::isa<IntegerType>(srcType)) {
1641 result = rewriter.
create<LLVM::BitcastOp>(loc, srcType, result);
1652 struct AMDGPUSwizzleBitModeLowering
1657 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
1661 Value src = adaptor.getSrc();
1664 unsigned andMask = op.getAndMask();
1665 unsigned orMask = op.getOrMask();
1666 unsigned xorMask = op.getXorMask();
1670 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
1673 for (
Value v : decomposed) {
1675 rewriter.
create<ROCDL::DsSwizzleOp>(loc, v.getType(), v, maskValue);
1676 swizzled.emplace_back(res);
1685 struct ConvertAMDGPUToROCDLPass
1686 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
1689 void runOnOperation()
override {
1692 if (failed(maybeChipset)) {
1694 return signalPassFailure();
1701 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1702 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1703 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1706 signalPassFailure();
1718 switch (as.getValue()) {
1719 case amdgpu::AddressSpace::FatRawBuffer:
1721 case amdgpu::AddressSpace::BufferRsrc:
1723 case amdgpu::AddressSpace::FatStructuredBuffer:
1735 .add<FatRawBufferCastLowering,
1736 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1737 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1738 RawBufferOpLowering<RawBufferAtomicFaddOp,
1739 ROCDL::RawPtrBufferAtomicFaddOp>,
1740 RawBufferOpLowering<RawBufferAtomicFmaxOp,
1741 ROCDL::RawPtrBufferAtomicFmaxOp>,
1742 RawBufferOpLowering<RawBufferAtomicSmaxOp,
1743 ROCDL::RawPtrBufferAtomicSmaxOp>,
1744 RawBufferOpLowering<RawBufferAtomicUminOp,
1745 ROCDL::RawPtrBufferAtomicUminOp>,
1746 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1747 ROCDL::RawPtrBufferAtomicCmpSwap>,
1748 AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1749 MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1750 ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
1751 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
1752 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1754 patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
constexpr Chipset kGfx908
constexpr Chipset kGfx90a
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to ROCDL ...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, uint32_t elementByteWidth)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static std::optional< uint32_t > mfmaTypeSelectCode(Type mlirElemType)
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
constexpr Chipset kGfx950
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1205::ArityGroupAndKind::Kind kind
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class provides a shared interface for ranked and unranked memref types.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
static MemRefDescriptor poison(OpBuilder &builder, Location loc, Type descriptorType)
Builds IR creating a poison value of the descriptor type.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
The general result of a type attribute conversion callback, allowing for early termination.
static AttributeConversionResult abort()
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
void addTypeAttributeConversion(FnT &&callback)
Register a conversion function for attributes within types.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
Value composeValue(OpBuilder &builder, Location loc, ValueRange src, Type dstType)
Composes a set of src values into a single value of type dstType through series of bitcasts and vecto...
SmallVector< Value > decomposeValue(OpBuilder &builder, Location loc, Value src, Type dstType)
Decomposes a src value into a set of values of type dstType through series of bitcasts and vector ops...
bool hasOcpFp8(const Chipset &chipset)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Include the generated interface declarations.
void populateAMDGPUMemorySpaceAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces,...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.