30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32#include "llvm/Support/AMDGPUAddrSpace.h"
33#include "llvm/Support/Casting.h"
34#include "llvm/Support/ErrorHandling.h"
39#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
40#include "mlir/Conversion/Passes.h.inc"
58 return chipset >=
Chipset(9, 0, 6);
100 if (chipset ==
Chipset(9, 5, 0))
110 IntegerType i32 = rewriter.getI32Type();
112 auto valTy = cast<IntegerType>(val.
getType());
115 return valTy.getWidth() > 32
116 ?
Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
117 :
Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
122 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
128 IntegerType i64 = rewriter.getI64Type();
130 auto valTy = cast<IntegerType>(val.
getType());
133 return valTy.getWidth() > 64
134 ?
Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
135 :
Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
140 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
147 IntegerType i32 = rewriter.getI32Type();
149 for (
auto [i, increment, stride] : llvm::enumerate(
indices, strides)) {
152 ShapedType::isDynamic(stride)
154 memRefDescriptor.
stride(rewriter, loc, i))
155 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
156 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
168 MemRefType memrefType,
172 if (chipset >=
kGfx1250 && !boundsCheck) {
173 constexpr int64_t first45bits = (1ll << 45) - 1;
176 if (memrefType.hasStaticShape() &&
177 !llvm::any_of(strides, ShapedType::isDynamic)) {
178 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
180 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
181 size = std::max(
shape[i] * strides[i], size);
182 size = size * elementByteWidth;
186 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
187 Value size = memrefDescriptor.
size(rewriter, loc, i);
188 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
189 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
191 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
196 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
202 Value cacheSwizzleStride =
nullptr,
203 unsigned addressSpace = 8) {
207 Type i16 = rewriter.getI16Type();
210 Value cacheStrideZext =
211 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
212 Value swizzleBit = LLVM::ConstantOp::create(
213 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
214 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
217 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
218 rewriter.getI16IntegerAttr(0));
247 flags |= (7 << 12) | (4 << 15);
250 uint32_t oob = boundsCheck ? 3 : 2;
251 flags |= (oob << 28);
256 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
257 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
258 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
263struct FatRawBufferCastLowering
265 FatRawBufferCastLowering(
const LLVMTypeConverter &converter, Chipset chipset)
266 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
272 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
273 ConversionPatternRewriter &rewriter)
const override {
274 Location loc = op.getLoc();
275 Value memRef = adaptor.getSource();
276 Value unconvertedMemref = op.getSource();
277 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
278 MemRefDescriptor descriptor(memRef);
280 DataLayout dataLayout = DataLayout::closest(op);
281 int64_t elementByteWidth =
284 int64_t unusedOffset = 0;
285 SmallVector<int64_t, 5> strideVals;
286 if (
failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
287 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
289 Value numRecords = adaptor.getValidBytes();
292 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
293 elementByteWidth, chipset, adaptor.getBoundsCheck());
296 adaptor.getResetOffset()
297 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
299 : descriptor.alignedPtr(rewriter, loc);
301 Value offset = adaptor.getResetOffset()
302 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
303 rewriter.getIndexAttr(0))
304 : descriptor.offset(rewriter, loc);
306 bool hasSizes = memrefType.getRank() > 0;
309 Value sizes = hasSizes
310 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
314 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
319 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
320 chipset, adaptor.getCacheSwizzleStride(), 7);
322 Value
result = MemRefDescriptor::poison(
324 getTypeConverter()->convertType(op.getResult().getType()));
326 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr, pos);
327 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr,
329 result = LLVM::InsertValueOp::create(rewriter, loc,
result, offset,
332 result = LLVM::InsertValueOp::create(rewriter, loc,
result, sizes,
334 result = LLVM::InsertValueOp::create(rewriter, loc,
result, strides,
337 rewriter.replaceOp(op,
result);
343template <
typename GpuOp,
typename Intrinsic>
345 RawBufferOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
346 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
349 static constexpr uint32_t maxVectorOpWidth = 128;
352 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
353 ConversionPatternRewriter &rewriter)
const override {
354 Location loc = gpuOp.getLoc();
355 Value memref = adaptor.getMemref();
356 Value unconvertedMemref = gpuOp.getMemref();
357 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
359 if (chipset.majorVersion < 9)
360 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
362 Value storeData = adaptor.getODSOperands(0)[0];
363 if (storeData == memref)
367 wantedDataType = storeData.
getType();
369 wantedDataType = gpuOp.getODSResults(0)[0].getType();
371 Value atomicCmpData = Value();
374 Value maybeCmpData = adaptor.getODSOperands(1)[0];
375 if (maybeCmpData != memref)
376 atomicCmpData = maybeCmpData;
379 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
381 Type i32 = rewriter.getI32Type();
384 DataLayout dataLayout = DataLayout::closest(gpuOp);
385 int64_t elementByteWidth =
394 Type llvmBufferValType = llvmWantedDataType;
396 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
397 llvmBufferValType = this->getTypeConverter()->convertType(
398 rewriter.getIntegerType(floatType.getWidth()));
400 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
401 uint32_t vecLen = dataVector.getNumElements();
404 uint32_t totalBits = elemBits * vecLen;
406 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
407 if (totalBits > maxVectorOpWidth)
408 return gpuOp.emitOpError(
409 "Total width of loads or stores must be no more than " +
410 Twine(maxVectorOpWidth) +
" bits, but we call for " +
412 " bits. This should've been caught in validation");
413 if (!usePackedFp16 && elemBits < 32) {
414 if (totalBits > 32) {
415 if (totalBits % 32 != 0)
416 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
417 "doesn't fit into words. Can't happen\n");
418 llvmBufferValType = this->typeConverter->convertType(
419 VectorType::get(totalBits / 32, i32));
421 llvmBufferValType = this->typeConverter->convertType(
422 rewriter.getIntegerType(totalBits));
426 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
429 if (vecType.getNumElements() == 1)
430 llvmBufferValType = vecType.getElementType();
433 SmallVector<Value, 6> args;
435 if (llvmBufferValType != llvmWantedDataType) {
436 Value castForStore = LLVM::BitcastOp::create(
437 rewriter, loc, llvmBufferValType, storeData);
438 args.push_back(castForStore);
440 args.push_back(storeData);
445 if (llvmBufferValType != llvmWantedDataType) {
446 Value castForCmp = LLVM::BitcastOp::create(
447 rewriter, loc, llvmBufferValType, atomicCmpData);
448 args.push_back(castForCmp);
450 args.push_back(atomicCmpData);
456 SmallVector<int64_t, 5> strides;
457 if (
failed(memrefType.getStridesAndOffset(strides, offset)))
458 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
460 MemRefDescriptor memrefDescriptor(memref);
462 Value ptr = memrefDescriptor.bufferPtr(
463 rewriter, loc, *this->getTypeConverter(), memrefType);
465 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
466 elementByteWidth, chipset, adaptor.getBoundsCheck());
468 adaptor.getBoundsCheck(), chipset);
469 args.push_back(resource);
473 adaptor.getIndices(), strides);
474 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
475 indexOffset && *indexOffset > 0) {
477 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
481 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
482 args.push_back(voffset);
485 Value sgprOffset = adaptor.getSgprOffset();
488 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
489 args.push_back(sgprOffset);
491 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
493 typename Intrinsic::Properties properties;
494 properties.aux = rewriter.getI32IntegerAttr(0);
496 Intrinsic::create(rewriter, loc, resultTypes, args, properties);
499 if (llvmBufferValType != llvmWantedDataType) {
500 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
505 rewriter.eraseOp(gpuOp);
/// Packs the vmcnt / expcnt / lgkmcnt wait counters for `chipset` into a single
/// s_waitcnt immediate (the caller feeds the result to ROCDL::SWaitcntOp).
/// Every visible layout first clamps each counter to the width of its field
/// with std::min, so oversized requests saturate instead of spilling into
/// neighbouring bit fields.
/// NOTE(review): the conditions that choose between the four layouts below,
/// and the function's failure path, are not visible in this view; the comments
/// describe only the bit packing each layout performs.
522static FailureOr<unsigned> encodeWaitcnt(
Chipset chipset,
unsigned vmcnt,
523 unsigned expcnt,
unsigned lgkmcnt) {
// Layout 1: vmcnt[3:0] | expcnt[6:4] | lgkmcnt[11:8] (4/3/4-bit fields).
525 vmcnt = std::min(15u, vmcnt);
526 expcnt = std::min(7u, expcnt);
527 lgkmcnt = std::min(15u, lgkmcnt);
528 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
// Layout 2: 6-bit vmcnt split into bits [3:0] (low nibble) and [15:14]
// (upper two bits); expcnt[6:4]; 4-bit lgkmcnt[11:8].
531 vmcnt = std::min(63u, vmcnt);
532 expcnt = std::min(7u, expcnt);
533 lgkmcnt = std::min(15u, lgkmcnt);
534 unsigned lowBits = vmcnt & 0xF;
535 unsigned highBits = (vmcnt >> 4) << 14;
536 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
537 return lowBits | highBits | otherCnts;
// Layout 3: same split vmcnt as layout 2, but lgkmcnt is widened to 6 bits
// and occupies [13:8].
540 vmcnt = std::min(63u, vmcnt);
541 expcnt = std::min(7u, expcnt);
542 lgkmcnt = std::min(63u, lgkmcnt);
543 unsigned lowBits = vmcnt & 0xF;
544 unsigned highBits = (vmcnt >> 4) << 14;
545 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
546 return lowBits | highBits | otherCnts;
// Layout 4: expcnt[2:0] | lgkmcnt[9:4] | vmcnt[15:10] (3/6/6-bit fields).
549 vmcnt = std::min(63u, vmcnt);
550 expcnt = std::min(7u, expcnt);
551 lgkmcnt = std::min(63u, lgkmcnt);
552 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
557struct MemoryCounterWaitOpLowering
559 MemoryCounterWaitOpLowering(
const LLVMTypeConverter &converter,
561 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
567 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
568 ConversionPatternRewriter &rewriter)
const override {
569 if (chipset.majorVersion >= 12) {
570 Location loc = op.getLoc();
571 if (std::optional<int> ds = adaptor.getDs())
572 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
574 if (std::optional<int>
load = adaptor.getLoad())
575 ROCDL::WaitLoadcntOp::create(rewriter, loc, *
load);
577 if (std::optional<int> store = adaptor.getStore())
578 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
580 if (std::optional<int> exp = adaptor.getExp())
581 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
583 if (std::optional<int> tensor = adaptor.getTensor())
584 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
586 rewriter.eraseOp(op);
590 if (adaptor.getTensor())
591 return op.emitOpError(
"unsupported chipset");
593 auto getVal = [](Attribute attr) ->
unsigned {
595 return cast<IntegerAttr>(attr).getInt();
600 unsigned ds = getVal(adaptor.getDsAttr());
601 unsigned exp = getVal(adaptor.getExpAttr());
603 unsigned vmcnt = 1024;
604 Attribute
load = adaptor.getLoadAttr();
605 Attribute store = adaptor.getStoreAttr();
607 vmcnt = getVal(
load) + getVal(store);
609 vmcnt = getVal(
load);
611 vmcnt = getVal(store);
614 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
616 return op.emitOpError(
"unsupported chipset");
618 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
624 LDSBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
625 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
630 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
631 ConversionPatternRewriter &rewriter)
const override {
632 Location loc = op.getLoc();
635 bool requiresInlineAsm = chipset <
kGfx90a;
638 rewriter.getAttr<LLVM::MMRATagAttr>(
"amdgpu-synchronize-as",
"local");
647 StringRef scope =
"workgroup";
649 auto relFence = LLVM::FenceOp::create(rewriter, loc,
650 LLVM::AtomicOrdering::release, scope);
651 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
652 if (requiresInlineAsm) {
653 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
654 LLVM::AsmDialect::AD_ATT);
655 const char *asmStr =
";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
656 const char *constraints =
"";
657 LLVM::InlineAsmOp::create(
660 asmStr, constraints,
true,
661 false, LLVM::TailCallKind::None,
664 }
else if (chipset.majorVersion < 12) {
665 ROCDL::SBarrierOp::create(rewriter, loc);
667 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
668 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
671 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
672 LLVM::AtomicOrdering::acquire, scope);
673 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
674 rewriter.replaceOp(op, acqFence);
680 SchedBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
681 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
686 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
687 ConversionPatternRewriter &rewriter)
const override {
688 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op, op.getOptsAttr());
712 bool allowBf16 =
true) {
714 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
715 if (vectorType.getElementType().isBF16() && !allowBf16)
716 return LLVM::BitcastOp::create(
717 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
718 if (vectorType.getElementType().isInteger(8) &&
719 vectorType.getNumElements() <= 8)
720 return LLVM::BitcastOp::create(
722 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
723 if (isa<IntegerType>(vectorType.getElementType()) &&
724 vectorType.getElementTypeBitWidth() <= 8) {
725 int64_t numWords = llvm::divideCeil(
726 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
728 return LLVM::BitcastOp::create(
729 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
739 bool allowBf16 =
true) {
741 auto vectorType = cast<VectorType>(inputType);
743 if (vectorType.getElementType().isBF16() && !allowBf16)
744 return LLVM::BitcastOp::create(
745 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
747 if (isa<IntegerType>(vectorType.getElementType()) &&
748 vectorType.getElementTypeBitWidth() <= 8) {
749 int64_t numWords = llvm::divideCeil(
750 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
751 Type castType = (numWords > 1)
752 ?
Type{VectorType::get(numWords, rewriter.getI32Type())}
753 : rewriter.getI32Type();
754 return LLVM::BitcastOp::create(rewriter, loc, castType, input);
772 .Case([&](IntegerType) {
774 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
777 .Case([&](VectorType vectorType) {
779 int64_t numElements = vectorType.getNumElements();
780 assert((numElements == 4 || numElements == 8) &&
781 "scale operand must be a vector of length 4 or 8");
782 IntegerType outputType =
783 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
784 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
786 .DefaultUnreachable(
"unexpected input type for scale operand");
/// Maps a WMMA scale element type to its ROCDL::WMMAMatrixScaleFormat:
///   f8E8M0FNU -> e8, f8E4M3FN -> e4m3, any other type -> std::nullopt.
/// NOTE(review): the function name/parameter lines are not visible in this
/// view; the mapping above is read off the visible .Case/.Default chain only.
790static std::optional<ROCDL::WMMAMatrixScaleFormat>
793 .Case([](Float8E8M0FNUType) {
return ROCDL::WMMAMatrixScaleFormat::e8; })
794 .Case([](Float8E4M3FNType) {
return ROCDL::WMMAMatrixScaleFormat::e4m3; })
795 .Default(std::nullopt);
800static std::optional<StringRef>
802 if (m == 16 && n == 16 && k == 128)
804 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
805 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
807 if (m == 32 && n == 16 && k == 128)
808 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
809 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
823 ConversionPatternRewriter &rewriter,
Location loc,
828 auto vectorType = dyn_cast<VectorType>(inputType);
830 operands.push_back(llvmInput);
833 Type elemType = vectorType.getElementType();
835 operands.push_back(llvmInput);
842 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
843 bool isInputInteger = mlirInputType.getElementType().isInteger();
844 if (isInputInteger) {
846 bool localIsUnsigned = isUnsigned;
848 localIsUnsigned =
true;
850 localIsUnsigned =
false;
853 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
858 Type i32 = rewriter.getI32Type();
859 Type intrinsicInType = numBits <= 32
860 ? (
Type)rewriter.getIntegerType(numBits)
861 : (
Type)VectorType::get(numBits / 32, i32);
862 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
863 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
864 loc, llvmIntrinsicInType, llvmInput);
869 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
870 operands.push_back(castInput);
883 Value output, int32_t subwordOffset,
887 auto vectorType = dyn_cast<VectorType>(inputType);
888 Type elemType = vectorType.getElementType();
889 operands.push_back(output);
901 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
902 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
908 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
909 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
917 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
918 b = mfma.getBlocks();
923 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
924 if (m == 32 && n == 32 && k == 4 &&
b == 1)
925 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
926 if (m == 16 && n == 16 && k == 8 &&
b == 1)
927 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
929 if (m == 32 && n == 32 && k == 1 &&
b == 2)
930 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
931 if (m == 16 && n == 16 && k == 1 &&
b == 4)
932 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
933 if (m == 4 && n == 4 && k == 1 &&
b == 16)
934 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
935 if (m == 32 && n == 32 && k == 2 &&
b == 1)
936 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
937 if (m == 16 && n == 16 && k == 4 &&
b == 1)
938 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
943 if (m == 32 && n == 32 && k == 16 &&
b == 1)
944 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
945 if (m == 16 && n == 16 && k == 32 &&
b == 1)
946 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
948 if (m == 32 && n == 32 && k == 4 &&
b == 2)
949 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
950 if (m == 16 && n == 16 && k == 4 &&
b == 4)
951 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
952 if (m == 4 && n == 4 && k == 4 &&
b == 16)
953 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
954 if (m == 32 && n == 32 && k == 8 &&
b == 1)
955 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
956 if (m == 16 && n == 16 && k == 16 &&
b == 1)
957 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
962 if (m == 32 && n == 32 && k == 16 &&
b == 1)
963 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
964 if (m == 16 && n == 16 && k == 32 &&
b == 1)
965 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
968 if (m == 32 && n == 32 && k == 4 &&
b == 2)
969 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
970 if (m == 16 && n == 16 && k == 4 &&
b == 4)
971 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
972 if (m == 4 && n == 4 && k == 4 &&
b == 16)
973 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
974 if (m == 32 && n == 32 && k == 8 &&
b == 1)
975 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
976 if (m == 16 && n == 16 && k == 16 &&
b == 1)
977 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
979 if (m == 32 && n == 32 && k == 2 &&
b == 2)
980 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
981 if (m == 16 && n == 16 && k == 2 &&
b == 4)
982 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
983 if (m == 4 && n == 4 && k == 2 &&
b == 16)
984 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
985 if (m == 32 && n == 32 && k == 4 &&
b == 1)
986 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
987 if (m == 16 && n == 16 && k == 8 &&
b == 1)
988 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
993 if (m == 32 && n == 32 && k == 32 &&
b == 1)
994 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
995 if (m == 16 && n == 16 && k == 64 &&
b == 1)
996 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
998 if (m == 32 && n == 32 && k == 4 &&
b == 2)
999 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
1000 if (m == 16 && n == 16 && k == 4 &&
b == 4)
1001 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
1002 if (m == 4 && n == 4 && k == 4 &&
b == 16)
1003 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
1004 if (m == 32 && n == 32 && k == 8 &&
b == 1)
1005 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
1006 if (m == 16 && n == 16 && k == 16 &&
b == 1)
1007 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
1008 if (m == 32 && n == 32 && k == 16 &&
b == 1 && chipset >=
kGfx942)
1009 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
1010 if (m == 16 && n == 16 && k == 32 &&
b == 1 && chipset >=
kGfx942)
1011 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
1015 if (m == 16 && n == 16 && k == 4 &&
b == 1)
1016 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
1017 if (m == 4 && n == 4 && k == 4 &&
b == 4)
1018 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
1025 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
1026 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
1028 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
1030 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
1032 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
1034 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
1036 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
1042 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
1043 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
1045 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
1047 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
1049 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
1051 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
1053 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
1057 return std::nullopt;
/// Maps a small-float element type to its ROCDL::MatrixFormat code.
/// Mappings visible in this view:
///   f8E4M3FN -> fp8_e4m3, f8E5M2 -> fp8_e5m2,
///   f6E2M3FN -> fp6_e2m3, f6E3M2FN -> fp6_e3m2,
///   f4E2M1FN -> fp4_e2m1; anything else -> std::nullopt (.Default).
/// NOTE(review): a few lines between the signature and the first visible
/// .Case are not shown here; confirm no further cases precede these.
1060static std::optional<ROCDL::MatrixFormat>
1064 .Case([](Float8E4M3FNType) {
return ROCDL::MatrixFormat::fp8_e4m3; })
1065 .Case([](Float8E5M2Type) {
return ROCDL::MatrixFormat::fp8_e5m2; })
1066 .Case([](Float6E2M3FNType) {
return ROCDL::MatrixFormat::fp6_e2m3; })
1067 .Case([](Float6E3M2FNType) {
return ROCDL::MatrixFormat::fp6_e3m2; })
1068 .Case([](Float4E2M1FNType) {
return ROCDL::MatrixFormat::fp4_e2m1; })
1069 .Default(std::nullopt);
1080 std::tuple<StringRef, ROCDL::MatrixFormat, ROCDL::MatrixFormat>;
1082static std::optional<ScaledMFMAIntrinsic>
1084 uint32_t n, uint32_t k, uint32_t
b,
Chipset chipset) {
1090 return std::nullopt;
1091 if (!isa<Float32Type>(destType))
1092 return std::nullopt;
1094 std::optional<ROCDL::MatrixFormat> aTypeCode =
1096 std::optional<ROCDL::MatrixFormat> bTypeCode =
1098 if (!aTypeCode || !bTypeCode)
1099 return std::nullopt;
1101 if (m == 32 && n == 32 && k == 64 &&
b == 1)
1102 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1103 *aTypeCode, *bTypeCode};
1104 if (m == 16 && n == 16 && k == 128 &&
b == 1)
1106 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1109 return std::nullopt;
1112static std::optional<ScaledMFMAIntrinsic>
1115 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1116 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1117 mfma.getBlocks(), chipset);
1120static std::optional<ScaledMFMAIntrinsic>
1123 smfma.getSourceB().getType(),
1124 smfma.getDestC().getType(), smfma.getM(),
1125 smfma.getN(), smfma.getK(), 1u, chipset);
1130static std::optional<StringRef>
1132 Type elemDestType, uint32_t k,
bool isRDNA3) {
1133 using fp8 = Float8E4M3FNType;
1134 using bf8 = Float8E5M2Type;
1139 if (elemSourceType.
isF16() && elemDestType.
isF32())
1140 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1141 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1142 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1143 if (elemSourceType.
isF16() && elemDestType.
isF16())
1144 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1146 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1148 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1153 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1154 return std::nullopt;
1158 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1159 elemDestType.
isF32())
1160 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1161 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1162 elemDestType.
isF32())
1163 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1164 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1165 elemDestType.
isF32())
1166 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1167 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1168 elemDestType.
isF32())
1169 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1171 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1173 return std::nullopt;
1177 if (k == 32 && !isRDNA3) {
1179 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1182 return std::nullopt;
1188 Type elemBSourceType,
1191 using fp8 = Float8E4M3FNType;
1192 using bf8 = Float8E5M2Type;
1195 if (elemSourceType.
isF32() && elemDestType.
isF32())
1196 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1198 return std::nullopt;
1202 if (elemSourceType.
isF16() && elemDestType.
isF32())
1203 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1204 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1205 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1206 if (elemSourceType.
isF16() && elemDestType.
isF16())
1207 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1209 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1211 return std::nullopt;
1215 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1216 if (elemDestType.
isF32())
1217 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1218 if (elemDestType.
isF16())
1219 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1221 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1222 if (elemDestType.
isF32())
1223 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1224 if (elemDestType.
isF16())
1225 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1227 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1228 if (elemDestType.
isF32())
1229 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1230 if (elemDestType.
isF16())
1231 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1233 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1234 if (elemDestType.
isF32())
1235 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1236 if (elemDestType.
isF16())
1237 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1240 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1242 return std::nullopt;
1246 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1247 if (elemDestType.
isF32())
1248 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1249 if (elemDestType.
isF16())
1250 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1252 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1253 if (elemDestType.
isF32())
1254 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1255 if (elemDestType.
isF16())
1256 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1258 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1259 if (elemDestType.
isF32())
1260 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1261 if (elemDestType.
isF16())
1262 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1264 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1265 if (elemDestType.
isF32())
1266 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1267 if (elemDestType.
isF16())
1268 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1271 return std::nullopt;
1274 return std::nullopt;
1282 bool isGfx950 = chipset >=
kGfx950;
1286 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1291 if (m == 16 && n == 16 && k == 32) {
1293 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1295 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1298 if (m == 16 && n == 16 && k == 64) {
1301 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1303 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1307 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1308 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1309 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1310 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1311 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1312 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1313 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1314 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1315 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1318 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1321 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1322 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1323 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1324 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1325 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1326 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1327 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1328 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1329 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1332 if (m == 32 && n == 32 && k == 16) {
1334 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1336 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1339 if (m == 32 && n == 32 && k == 32) {
1342 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1344 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1348 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1349 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1350 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1351 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1352 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1353 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1354 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1355 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1356 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1359 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1362 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1363 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1364 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1365 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1366 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1367 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1368 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1369 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1370 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1373 return std::nullopt;
1381 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1382 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1383 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1384 Type elemSourceType = sourceVectorType.getElementType();
1385 Type elemBSourceType = sourceBVectorType.getElementType();
1386 Type elemDestType = destVectorType.getElementType();
1388 const uint32_t k = wmma.getK();
1393 if (isRDNA3 || isRDNA4)
1402 return std::nullopt;
/// Selects the sparse WMMA (swmmac) intrinsic for `swmmac`, returning
/// std::nullopt when no case matches. Only m == 16 && n == 16 is accepted.
/// NOTE(review): each entry pairs an intrinsic name with boolean flags whose
/// meaning is defined by SparseWMMAOpInfo (not visible here), and the
/// element-type conditions guarding each entry are not shown in this view.
1415static std::optional<SparseWMMAOpInfo>
1421 uint32_t m = swmmac.getM(), n = swmmac.getN(), k = swmmac.getK();
// Only 16x16 output tiles have a sparse WMMA lowering.
1423 if ((m != 16) || (n != 16))
1424 return std::nullopt;
1431 ROCDL::swmmac_f32_16x16x32_f16::getOperationName(),
false,
false,
1435 ROCDL::swmmac_f32_16x16x32_bf16::getOperationName(),
false,
false,
1439 ROCDL::swmmac_f16_16x16x32_f16::getOperationName(),
false,
false,
1443 ROCDL::swmmac_bf16_16x16x32_bf16::getOperationName(),
false,
false,
1448 ROCDL::swmmac_i32_16x16x32_iu8::getOperationName(),
true,
false,
1453 ROCDL::swmmac_i32_16x16x32_iu4::getOperationName(),
true,
false,
1458 ROCDL::swmmac_f32_16x16x32_fp8_fp8::getOperationName(),
false,
1463 ROCDL::swmmac_f32_16x16x32_fp8_bf8::getOperationName(),
false,
1468 ROCDL::swmmac_f32_16x16x32_bf8_fp8::getOperationName(),
false,
1472 ROCDL::swmmac_f32_16x16x32_bf8_bf8::getOperationName(),
false,
1479 ROCDL::swmmac_i32_16x16x64_iu4::getOperationName(),
true,
false,
// The entries after this check appear to sit inside a branch taken only on
// exactly gfx1250 with a wave32 op (getWave64() false); the branch's closing
// brace is not visible in this view.
1484 const bool isGFX1250 = chipset ==
kGfx1250;
1485 const bool isWavesize64 = swmmac.getWave64();
1486 if (isGFX1250 && !isWavesize64) {
1490 ROCDL::swmmac_f32_16x16x64_f16::getOperationName(),
true,
true,
1494 ROCDL::swmmac_f32_16x16x64_bf16::getOperationName(),
true,
true,
1498 ROCDL::swmmac_f16_16x16x64_f16::getOperationName(),
true,
true,
1502 ROCDL::swmmac_bf16_16x16x64_bf16::getOperationName(),
true,
true,
1509 ROCDL::swmmac_f32_16x16x128_fp8_fp8::getOperationName(),
false,
1514 ROCDL::swmmac_f32_16x16x128_fp8_bf8::getOperationName(),
false,
1519 ROCDL::swmmac_f32_16x16x128_bf8_fp8::getOperationName(),
false,
1523 ROCDL::swmmac_f32_16x16x128_bf8_bf8::getOperationName(),
false,
1528 ROCDL::swmmac_f16_16x16x128_fp8_fp8::getOperationName(),
false,
1533 ROCDL::swmmac_f16_16x16x128_fp8_bf8::getOperationName(),
false,
1538 ROCDL::swmmac_f16_16x16x128_bf8_fp8::getOperationName(),
false,
1542 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(),
false,
// NOTE(review): this entry returns the same intrinsic as the previous one
// (swmmac_f16_16x16x128_bf8_bf8). Its guarding element-type check is not
// visible here -- confirm the duplicate is intentional and not a copy-paste
// slip that drops a distinct variant.
1547 ROCDL::swmmac_f16_16x16x128_bf8_bf8::getOperationName(),
false,
1552 ROCDL::swmmac_i32_16x16x128_iu8::getOperationName(),
true,
true,
// No visible case matched: there is no sparse WMMA intrinsic for this op.
1557 return std::nullopt;
1562 MFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1563 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1568 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1569 ConversionPatternRewriter &rewriter)
const override {
1570 Location loc = op.getLoc();
1572 Type outType = typeConverter->convertType(op.getDestD().getType());
1573 Type intrinsicOutType = outType;
1574 if (
auto outVecType = dyn_cast<VectorType>(outType))
1575 if (outVecType.getElementType().isBF16())
1576 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1578 if (chipset.majorVersion != 9 || chipset <
kGfx908)
1579 return op->emitOpError(
"MFMA only supported on gfx908+");
1580 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
1581 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1583 return op.emitOpError(
"negation unsupported on older than gfx942");
1585 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1588 std::optional<ScaledMFMAIntrinsic> maybeScaledIntrinsic =
1590 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1591 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
1594 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1596 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1597 return op.emitOpError(
1598 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1599 "be scaled as those fields are used for type information");
1602 StringRef intrinsicName =
1603 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1606 bool allowBf16 = [&]() {
1611 return intrinsicName.contains(
"16x16x32.bf16") ||
1612 intrinsicName.contains(
"32x32x16.bf16");
1614 OperationState loweredOp(loc, intrinsicName);
1615 loweredOp.addTypes(intrinsicOutType);
1617 rewriter, loc, adaptor.getSourceA(), allowBf16),
1619 rewriter, loc, adaptor.getSourceB(), allowBf16),
1620 adaptor.getDestC()});
1623 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1624 loweredOp.addOperands({zero, zero});
1625 loweredOp.addAttributes(
1627 ROCDL::MatrixFormatAttr::get(rewriter.getContext(), aTypeCode)},
1629 ROCDL::MatrixFormatAttr::get(rewriter.getContext(), bTypeCode)},
1630 {
"opselA", rewriter.getI32IntegerAttr(0)},
1631 {
"opselB", rewriter.getI32IntegerAttr(0)}});
1633 Attribute blgpAttr =
1635 ? Attribute(ROCDL::MFMANegModifierAttr::get(
1636 rewriter.getContext(),
1637 static_cast<ROCDL::MFMANegModifier
>(getBlgpField)))
1638 : Attribute(ROCDL::MFMAPermBAttr::
get(
1640 static_cast<ROCDL::MFMAPermB>(getBlgpField)));
1641 loweredOp.addAttributes(
1642 {{
"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1643 {
"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1644 {
"blgp", blgpAttr}});
1646 Value lowered = rewriter.create(loweredOp)->getResult(0);
1647 if (outType != intrinsicOutType)
1648 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1649 rewriter.replaceOp(op, lowered);
1655 ScaledMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1656 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1661 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1662 ConversionPatternRewriter &rewriter)
const override {
1663 Location loc = op.getLoc();
1664 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1666 if (chipset.majorVersion != 9 || chipset <
kGfx950)
1667 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1668 std::optional<ScaledMFMAIntrinsic> maybeScaledIntrinsic =
1670 if (!maybeScaledIntrinsic.has_value())
1671 return op.emitOpError(
1672 "no intrinsic matching scaled MFMA size on given chipset");
1674 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1675 OperationState loweredOp(loc, intrinsicName);
1676 loweredOp.addTypes(intrinsicOutType);
1677 loweredOp.addOperands(
1680 adaptor.getDestC()});
1681 loweredOp.addOperands(
1686 loweredOp.addAttributes(
1688 ROCDL::MatrixFormatAttr::get(rewriter.getContext(), aTypeCode)},
1690 ROCDL::MatrixFormatAttr::get(rewriter.getContext(), bTypeCode)},
1691 {
"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1692 {
"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1694 Value lowered = rewriter.create(loweredOp)->getResult(0);
1695 rewriter.replaceOp(op, lowered);
1701 SparseMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1702 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1707 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1708 ConversionPatternRewriter &rewriter)
const override {
1709 Location loc = op.getLoc();
1711 typeConverter->convertType<VectorType>(op.getDestC().
getType());
1713 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1716 if (chipset.majorVersion != 9 || chipset <
kGfx942)
1717 return op->emitOpError(
"sparse MFMA (smfmac) only supported on gfx942+");
1720 if (!maybeIntrinsic.has_value())
1721 return op.emitOpError(
1722 "no intrinsic matching sparse MFMA on the given chipset");
1725 ROCDL::smfmac_f32_16x16x32_bf16::getOperationName() ||
1727 ROCDL::smfmac_f32_32x32x16_bf16::getOperationName());
1728 bool isGfx950 = (chipset >=
kGfx950) && !isGfx942BF16;
1734 Value c = adaptor.getDestC();
1738 Value sparseIdx = adaptor.getSparseIdx();
1739 Type i32Type = rewriter.getI32Type();
1740 if (sparseIdx.
getType() != i32Type)
1741 sparseIdx = LLVM::BitcastOp::create(rewriter, loc, i32Type, sparseIdx);
1743 OperationState loweredOp(loc, maybeIntrinsic.value());
1744 loweredOp.addTypes(outType);
1745 loweredOp.addOperands({a,
b, c, sparseIdx});
1746 loweredOp.addAttributes(
1747 {{
"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1748 {
"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1749 Value lowered = rewriter.create(loweredOp)->getResult(0);
1750 rewriter.replaceOp(op, lowered);
1756 WMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1757 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1762 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1763 ConversionPatternRewriter &rewriter)
const override {
1764 Location loc = op.getLoc();
1766 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1768 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1770 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1771 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1773 bool isGFX1250 = chipset >=
kGfx1250;
1778 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1779 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1780 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1781 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1782 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1783 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1784 bool castOutToI16 = outType.getElementType().
isBF16() && !isGFX1250;
1785 VectorType rawOutType = outType;
1787 rawOutType = outType.clone(rewriter.getI16Type());
1788 Value a = adaptor.getSourceA();
1790 a = LLVM::BitcastOp::create(rewriter, loc,
1791 aType.clone(rewriter.getI16Type()), a);
1792 Value
b = adaptor.getSourceB();
1794 b = LLVM::BitcastOp::create(rewriter, loc,
1795 bType.clone(rewriter.getI16Type()),
b);
1796 Value destC = adaptor.getDestC();
1798 destC = LLVM::BitcastOp::create(
1799 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1803 if (!maybeIntrinsic.has_value())
1804 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
1806 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1807 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1809 SmallVector<Value, 4> operands;
1810 SmallVector<NamedAttribute, 4> attrs;
1812 op.getSourceA(), operands, attrs,
"signA");
1814 op.getSourceB(), operands, attrs,
"signB");
1816 op.getSubwordOffset(), op.getClamp(), operands,
1819 OperationState loweredOp(loc, *maybeIntrinsic);
1820 loweredOp.addTypes(rawOutType);
1821 loweredOp.addOperands(operands);
1822 loweredOp.addAttributes(attrs);
1823 Operation *lowered = rewriter.create(loweredOp);
1825 Operation *maybeCastBack = lowered;
1826 if (rawOutType != outType)
1827 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1829 rewriter.replaceOp(op, maybeCastBack->
getResults());
1835enum class DotFamily {
1844static std::optional<std::pair<StringRef, DotFamily>>
1845dotOpToIntrinsic(DotOp op,
Chipset chipset) {
1846 Type aElem = cast<VectorType>(op.getSourceA().getType()).getElementType();
1847 Type bElem = cast<VectorType>(op.getSourceB().getType()).getElementType();
1848 Type dest = op.getDestC().getType();
1849 bool uA = op.getUnsignedA();
1850 bool uB = op.getUnsignedB();
1855 return {{ROCDL::fdot2::getOperationName(), DotFamily::Clamp}};
1857 return {{ROCDL::fdot2_f16_f16::getOperationName(), DotFamily::NoClamp}};
1858 return std::nullopt;
1864 return {{ROCDL::fdot2_f32_bf16::getOperationName(), DotFamily::Clamp}};
1866 return {{ROCDL::fdot2_bf16_bf16::getOperationName(), DotFamily::NoClamp}};
1867 return std::nullopt;
1871 if (isa<IntegerType>(aElem) && isa<IntegerType>(bElem) &&
1873 bool mixedSign = (uA != uB);
1878 return std::nullopt;
1880 switch (elemWidth) {
1882 name = ROCDL::sudot4::getOperationName();
1885 name = ROCDL::sudot8::getOperationName();
1888 return std::nullopt;
1890 return {{name, DotFamily::Sudot}};
1894 bool supported =
false;
1895 switch (elemWidth) {
1898 name = uA ? ROCDL::udot2::getOperationName()
1899 :
ROCDL::sdot2::getOperationName();
1904 name = uA ? ROCDL::udot4::getOperationName()
1905 :
ROCDL::sdot4::getOperationName();
1910 name = uA ? ROCDL::udot8::getOperationName()
1911 :
ROCDL::sdot8::getOperationName();
1914 return std::nullopt;
1917 return std::nullopt;
1918 return {{name, DotFamily::Clamp}};
1922 bool aIsFp8 = isa<Float8E4M3FNType>(aElem);
1923 bool aIsBf8 = isa<Float8E5M2Type>(aElem);
1924 bool bIsFp8 = isa<Float8E4M3FNType>(bElem);
1925 bool bIsBf8 = isa<Float8E5M2Type>(bElem);
1926 if ((aIsFp8 || aIsBf8) && (bIsFp8 || bIsBf8) && dest.
isF32()) {
1928 return std::nullopt;
1930 if (aIsFp8 && bIsFp8)
1931 name = ROCDL::dot4_f32_fp8_fp8::getOperationName();
1932 else if (aIsFp8 && bIsBf8)
1933 name = ROCDL::dot4_f32_fp8_bf8::getOperationName();
1934 else if (aIsBf8 && bIsFp8)
1935 name = ROCDL::dot4_f32_bf8_fp8::getOperationName();
1937 name = ROCDL::dot4_f32_bf8_bf8::getOperationName();
1938 return {{name, DotFamily::NoClamp}};
1941 return std::nullopt;
1945 DotOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1946 : ConvertOpToLLVMPattern<DotOp>(converter), chipset(chipset) {}
1951 matchAndRewrite(DotOp op, DotOpAdaptor adaptor,
1952 ConversionPatternRewriter &rewriter)
const override {
1953 Location loc = op.getLoc();
1955 std::optional<std::pair<StringRef, DotFamily>> maybeIntrinsic =
1956 dotOpToIntrinsic(op, chipset);
1957 if (!maybeIntrinsic)
1958 return op.emitOpError(
"no intrinsic matching dot on the given chipset: ")
1959 << op.getSourceA().getType() <<
" * " << op.getSourceB().getType()
1960 <<
" + " << op.getDestC().getType();
1962 auto [intrinsicName, family] = maybeIntrinsic.value();
1966 Value c = adaptor.getDestC();
1968 SmallVector<NamedAttribute, 3> attrs;
1969 if (family == DotFamily::Sudot) {
1970 attrs.push_back(rewriter.getNamedAttr(
1971 "signA", rewriter.getBoolAttr(!op.getUnsignedA())));
1972 attrs.push_back(rewriter.getNamedAttr(
1973 "signB", rewriter.getBoolAttr(!op.getUnsignedB())));
1976 if (family != DotFamily::NoClamp && op.getClamp())
1978 rewriter.getNamedAttr(
"clamp", rewriter.getBoolAttr(
true)));
1980 Type resultType = typeConverter->convertType(op.getDestD().getType());
1982 OperationState loweredOp(loc, intrinsicName);
1983 loweredOp.addTypes(resultType);
1984 loweredOp.addOperands({a,
b, c});
1985 loweredOp.addAttributes(attrs);
1986 Operation *lowered = rewriter.create(loweredOp);
1987 rewriter.replaceOp(op, lowered->
getResults());
1993 SparseWMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1994 : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
1999 matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
2000 ConversionPatternRewriter &rewriter)
const override {
2001 Location loc = op.getLoc();
2003 typeConverter->convertType<VectorType>(op.getDestD().
getType());
2005 return rewriter.notifyMatchFailure(op,
"type conversion failed");
2007 std::optional<SparseWMMAOpInfo> maybeIntrinsic =
2010 if (!maybeIntrinsic.has_value())
2011 return op.emitOpError(
2012 "no intrinsic matching Sparse WMMA on the given chipset");
2013 SparseWMMAOpInfo intrinsic = maybeIntrinsic.value();
2015 SmallVector<NamedAttribute> attrs;
2017 if ((op.getUnsignedA() || op.getUnsignedB()) && !intrinsic.
useSign)
2018 return op->emitOpError(
"intrinsic doesn't support unsign");
2020 if (
auto attr = op.getUnsignedAAttr())
2021 attrs.push_back({
"signA", attr});
2022 if (
auto attr = op.getUnsignedBAttr())
2023 attrs.push_back({
"signB", attr});
2026 if ((op.getReuseA() || op.getReuseB()) && !intrinsic.
useReuse)
2027 return op->emitOpError(
"intrinsic doesn't support reuse");
2029 if (
auto attr = op.getReuseAAttr())
2030 attrs.push_back({
"reuseA", attr});
2031 if (
auto attr = op.getReuseBAttr())
2032 attrs.push_back({
"reuseB", attr});
2035 if (op.getClamp() && !intrinsic.
useClamp)
2036 return op->emitOpError(
"intrinsic doesn't support clamp");
2037 if (intrinsic.
useClamp && op.getClampAttr())
2038 attrs.push_back({
"clamp", op.getClampAttr()});
2040 const bool isGFX1250orHigher =
2041 chipset.majorVersion == 12 && chipset.minorVersion >= 5;
2046 Value c = adaptor.getDestC();
2047 VectorType rawOutType = outType;
2048 if (!isGFX1250orHigher) {
2050 rawOutType = cast<VectorType>(c.
getType());
2054 Value sparseIdx = LLVM::BitcastOp::create(
2055 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
2057 OperationState loweredOp(loc, intrinsic.
name);
2058 loweredOp.addTypes(rawOutType);
2059 loweredOp.addOperands({a,
b, c, sparseIdx});
2060 loweredOp.addAttributes(attrs);
2061 Operation *lowered = rewriter.create(loweredOp);
2063 Operation *maybeCastBack = lowered;
2064 if (rawOutType != outType)
2065 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
2067 rewriter.replaceOp(op, maybeCastBack->
getResults());
2074 ScaledWMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2075 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
2080 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
2081 ConversionPatternRewriter &rewriter)
const override {
2082 Location loc = op.getLoc();
2084 typeConverter->convertType<VectorType>(op.getDestD().
getType());
2086 return rewriter.notifyMatchFailure(op,
"type conversion failed");
2089 return op->emitOpError(
"WMMA scale only supported on gfx1250+");
2091 int64_t m = op.getM();
2092 int64_t n = op.getN();
2093 int64_t k = op.getK();
2098 std::optional<ROCDL::MatrixFormat> aFmtCode =
2100 std::optional<ROCDL::MatrixFormat> bFmtCode =
2103 if (!aFmtCode || !bFmtCode)
2104 return op.emitOpError(
"unsupported element types for scaled_wmma");
2107 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
2108 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
2110 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
2111 return op.emitOpError(
"scaleA and scaleB must have equal vector length");
2114 Type scaleAElemType = scaleAVecType.getElementType();
2115 Type scaleBElemType = scaleBVecType.getElementType();
2117 std::optional<ROCDL::WMMAMatrixScaleFormat> scaleAFmt =
2119 std::optional<ROCDL::WMMAMatrixScaleFormat> scaleBFmt =
2122 if (!scaleAFmt || !scaleBFmt)
2123 return op.emitOpError(
"unsupported scale element types");
2126 bool isScale16 = (scaleAVecType.getNumElements() == 8);
2127 std::optional<StringRef> intrinsicName =
2130 return op.emitOpError(
"unsupported scaled_wmma dimensions: ")
2131 << m <<
"x" << n <<
"x" << k;
2133 SmallVector<NamedAttribute, 8> attrs;
2136 bool is32x16 = (m == 32 && n == 16 && k == 128);
2138 attrs.emplace_back(
"fmtA", ROCDL::MatrixFormatAttr::get(
2139 rewriter.getContext(), *aFmtCode));
2140 attrs.emplace_back(
"fmtB", ROCDL::MatrixFormatAttr::get(
2141 rewriter.getContext(), *bFmtCode));
2146 "modC", ROCDL::WMMACModifierAttr::get(rewriter.getContext(),
2147 ROCDL::WMMACModifier::none));
2151 attrs.emplace_back(
"scaleAType", ROCDL::WMMAMatrixScaleAttr::get(
2152 rewriter.getContext(),
2153 static_cast<ROCDL::WMMAMatrixScale
>(
2154 op.getAFirstScaleLane() / 16)));
2155 attrs.emplace_back(
"fmtScaleA", ROCDL::WMMAMatrixScaleFormatAttr::get(
2156 rewriter.getContext(), *scaleAFmt));
2157 attrs.emplace_back(
"scaleBType", ROCDL::WMMAMatrixScaleAttr::get(
2158 rewriter.getContext(),
2159 static_cast<ROCDL::WMMAMatrixScale
>(
2160 op.getBFirstScaleLane() / 16)));
2161 attrs.emplace_back(
"fmtScaleB", ROCDL::WMMAMatrixScaleFormatAttr::get(
2162 rewriter.getContext(), *scaleBFmt));
2165 attrs.emplace_back(
"reuseA", rewriter.getBoolAttr(
false));
2166 attrs.emplace_back(
"reuseB", rewriter.getBoolAttr(
false));
2179 OperationState loweredOp(loc, *intrinsicName);
2180 loweredOp.addTypes(outType);
2181 loweredOp.addOperands(
2182 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
2183 loweredOp.addAttributes(attrs);
2185 Operation *lowered = rewriter.create(loweredOp);
2186 rewriter.replaceOp(op, lowered->
getResults());
2192struct TransposeLoadOpLowering
2194 TransposeLoadOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2195 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
2200 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
2201 ConversionPatternRewriter &rewriter)
const override {
2203 return op.emitOpError(
2204 "transpose_load is only supported on gfx950 and gfx1250+");
2206 Location loc = op.getLoc();
2207 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2211 size_t srcElementSize =
2212 srcMemRefType.getElementType().getIntOrFloatBitWidth();
2213 if (srcElementSize < 8)
2214 return op.emitOpError(
"Expect source memref to have at least 8 bits "
2215 "element size, got ")
2218 auto resultType = cast<VectorType>(op.getResult().getType());
2221 (adaptor.getSrcIndices()));
2223 size_t numElements = resultType.getNumElements();
2224 size_t elementTypeSize =
2227 Type llvmResultType = typeConverter->convertType(resultType);
2230 Type rocdlResultType =
2231 elementTypeSize < 16
2232 ? VectorType::get((numElements * elementTypeSize) / 32,
2233 rewriter.getIntegerType(32))
2236 auto emitNumElementsError = [&](
size_t expected, StringRef chipsetName) {
2237 return op.emitOpError()
2238 << elementTypeSize <<
"-bit transpose_load requires " << expected
2239 <<
" elements on " << chipsetName;
2244 switch (elementTypeSize) {
2246 if (numElements != 16)
2247 return emitNumElementsError(16,
"gfx1250+");
2249 ROCDL::DsLoadTr4_B64::create(rewriter, loc, rocdlResultType, srcPtr)
2254 if (numElements != 16)
2255 return emitNumElementsError(16,
"gfx1250+");
2257 ROCDL::DsLoadTr6_B96::create(rewriter, loc, rocdlResultType, srcPtr)
2262 if (numElements != 8)
2263 return emitNumElementsError(8,
"gfx1250+");
2265 ROCDL::DsLoadTr8_B64::create(rewriter, loc, rocdlResultType, srcPtr)
2270 if (numElements != 8)
2271 return emitNumElementsError(8,
"gfx1250+");
2272 intrinsic = ROCDL::DsLoadTr16_B128::create(rewriter, loc,
2273 rocdlResultType, srcPtr)
2278 return op.emitOpError(
"Unsupported element size for transpose load");
2281 switch (elementTypeSize) {
2283 if (numElements != 16)
2284 return emitNumElementsError(16,
"gfx950");
2285 intrinsic = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
2286 rocdlResultType, srcPtr)
2291 if (numElements != 16)
2292 return emitNumElementsError(16,
"gfx950");
2293 intrinsic = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
2294 rocdlResultType, srcPtr)
2299 if (numElements != 8)
2300 return emitNumElementsError(8,
"gfx950");
2301 intrinsic = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
2302 rocdlResultType, srcPtr)
2307 if (numElements != 4)
2308 return emitNumElementsError(4,
"gfx950");
2309 intrinsic = ROCDL::ds_read_tr16_b64::create(rewriter, loc,
2310 rocdlResultType, srcPtr)
2315 return op.emitOpError(
"Unsupported element size for transpose load");
2319 assert(intrinsic &&
"expected ROCDL transpose load intrinsic");
2320 if (intrinsic.
getType() == llvmResultType) {
2321 rewriter.replaceOp(op, intrinsic);
2324 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, intrinsic);
2329struct GlobalTransposeLoadOpLowering
2331 GlobalTransposeLoadOpLowering(
const LLVMTypeConverter &converter,
2333 : ConvertOpToLLVMPattern<GlobalTransposeLoadOp>(converter),
2339 matchAndRewrite(GlobalTransposeLoadOp op,
2340 GlobalTransposeLoadOpAdaptor adaptor,
2341 ConversionPatternRewriter &rewriter)
const override {
2343 return op.emitOpError(
2344 "global_transpose_load is only supported on gfx1200+");
2346 Location loc = op.getLoc();
2347 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2348 auto resultType = cast<VectorType>(op.getResult().getType());
2351 rewriter, loc, srcMemRefType, adaptor.getSrc(), adaptor.getSrcIndices(),
2352 LLVM::GEPNoWrapFlags::inbounds | LLVM::GEPNoWrapFlags::nuw);
2354 size_t numElements = resultType.getNumElements();
2355 size_t elementTypeSize =
2360 Type rocdlResultType =
2361 elementTypeSize < 16
2362 ? VectorType::get((numElements * elementTypeSize) / 32,
2363 rewriter.getIntegerType(32))
2364 : typeConverter->convertType(resultType);
2365 Type llvmResultType = typeConverter->convertType(resultType);
2367 switch (elementTypeSize) {
2369 assert(numElements == 16);
2371 return op.emitOpError(
"4-bit global_transpose_load requires gfx1250+");
2372 auto rocdlOp = ROCDL::GlobalLoadTr4_B64::create(rewriter, loc,
2373 rocdlResultType, srcPtr);
2374 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2378 assert(numElements == 16);
2380 return op.emitOpError(
"6-bit global_transpose_load requires gfx1250+");
2381 auto rocdlOp = ROCDL::GlobalLoadTr6_B96::create(rewriter, loc,
2382 rocdlResultType, srcPtr);
2383 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2387 assert(numElements == 8);
2388 auto rocdlOp = ROCDL::GlobalLoadTr8_B64::create(rewriter, loc,
2389 rocdlResultType, srcPtr);
2390 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
2394 assert(numElements == 8);
2395 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadTr8_B128>(op, llvmResultType,
2400 return op.emitOpError(
2401 "unsupported element size for global transpose load");
2408 GatherToLDSOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2409 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
2414 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
2415 ConversionPatternRewriter &rewriter)
const override {
2416 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
2417 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
2419 Location loc = op.getLoc();
2421 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2422 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2427 Type transferType = op.getTransferType();
2428 int loadWidth = [&]() ->
int {
2429 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
2430 return (transferVectorType.getNumElements() *
2431 transferVectorType.getElementTypeBitWidth()) /
2438 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
2439 return op.emitOpError(
"chipset unsupported element size");
2441 if (chipset !=
kGfx950 && llvm::is_contained({12, 16}, loadWidth))
2442 return op.emitOpError(
"Gather to LDS instructions with 12-byte and "
2443 "16-byte load widths are only supported on gfx950");
2447 (adaptor.getSrcIndices()));
2450 (adaptor.getDstIndices()));
2452 if (op.getAsync()) {
2453 rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
2454 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2455 rewriter.getI32IntegerAttr(0),
2459 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
2460 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
2461 rewriter.getI32IntegerAttr(0),
2470struct GlobalLoadAsyncToLDSOpLowering
2472 GlobalLoadAsyncToLDSOpLowering(
const LLVMTypeConverter &converter,
2474 : ConvertOpToLLVMPattern<GlobalLoadAsyncToLDSOp>(converter),
2480 matchAndRewrite(GlobalLoadAsyncToLDSOp op,
2481 GlobalLoadAsyncToLDSOpAdaptor adaptor,
2482 ConversionPatternRewriter &rewriter)
const override {
2484 return op.emitOpError(
2485 "global_load_async_to_lds is only supported on gfx1250+");
2487 Location loc = op.getLoc();
2488 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
2489 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
2491 Type transferType = op.getTransferType();
2493 isa<VectorType>(transferType)
2494 ? cast<VectorType>(transferType).getNumElements() *
2495 cast<VectorType>(transferType).getElementTypeBitWidth()
2500 adaptor.getSrcIndices());
2503 adaptor.getDstIndices());
2506 Value mask = adaptor.getMask();
2507 int64_t nullptrVal =
2508 llvm::AMDGPU::getNullPointerValue(llvm::AMDGPUAS::LOCAL_ADDRESS);
2512 LLVM::IntToPtrOp::create(rewriter, loc, dstPtr.
getType(), nullInt);
2513 dstPtr = LLVM::SelectOp::create(rewriter, loc, mask, dstPtr, nullPtr);
2516 auto offset = rewriter.getI32IntegerAttr(0);
2517 Attribute aux = rewriter.getI32IntegerAttr(0);
2519 switch (transferBits) {
2521 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB8Op>(
2526 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB32Op>(
2531 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB64Op>(
2536 rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadAsyncToLDSB128Op>(
2541 return op.emitOpError(
"unsupported transfer width");
2548struct ExtPackedFp8OpLowering final
2550 ExtPackedFp8OpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2551 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
2556 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2557 ConversionPatternRewriter &rewriter)
const override;
2560struct ScaledExtPackedMatrixOpLowering final
2562 ScaledExtPackedMatrixOpLowering(
const LLVMTypeConverter &converter,
2564 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
2569 matchAndRewrite(ScaledExtPackedMatrixOp op,
2570 ScaledExtPackedMatrixOpAdaptor adaptor,
2571 ConversionPatternRewriter &rewriter)
const override;
2574struct PackedTrunc2xFp8OpLowering final
2576 PackedTrunc2xFp8OpLowering(
const LLVMTypeConverter &converter,
2578 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
2583 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2584 ConversionPatternRewriter &rewriter)
const override;
2587struct PackedStochRoundFp8OpLowering final
2589 PackedStochRoundFp8OpLowering(
const LLVMTypeConverter &converter,
2591 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
2596 matchAndRewrite(PackedStochRoundFp8Op op,
2597 PackedStochRoundFp8OpAdaptor adaptor,
2598 ConversionPatternRewriter &rewriter)
const override;
2601struct ScaledExtPackedOpLowering final
2603 ScaledExtPackedOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2604 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
2609 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2610 ConversionPatternRewriter &rewriter)
const override;
2613struct PackedScaledTruncOpLowering final
2615 PackedScaledTruncOpLowering(
const LLVMTypeConverter &converter,
2617 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
2622 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2623 ConversionPatternRewriter &rewriter)
const override;
2628LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
2629 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
2630 ConversionPatternRewriter &rewriter)
const {
2631 Location loc = op.getLoc();
2633 return rewriter.notifyMatchFailure(
2634 loc,
"Fp8 conversion instructions are not available on target "
2635 "architecture and their emulation is not implemented");
2637 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
2638 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2639 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
2641 Value source = adaptor.getSource();
2642 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
2643 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
2646 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
2647 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
2648 if (!sourceVecType) {
2649 longVec = LLVM::InsertElementOp::create(
2652 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2654 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2656 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2661 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2662 if (resultVecType) {
2664 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
2667 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
2672 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
2675 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
2682int32_t getScaleSel(int32_t blockSize,
unsigned bitWidth, int32_t scaleWaveHalf,
2683 int32_t firstScaleByte) {
2689 assert(llvm::is_contained({16, 32}, blockSize));
2690 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
2692 const bool isFp8 = bitWidth == 8;
2693 const bool isBlock16 = blockSize == 16;
2696 int32_t bit0 = isBlock16;
2697 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
2698 int32_t bit1 = (firstScaleByte == 2) << 1;
2699 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2700 int32_t bit2 = scaleWaveHalf << 2;
2701 return bit2 | bit1 | bit0;
2704 int32_t bit0 = isBlock16;
2706 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
2707 int32_t bits2and1 = firstScaleByte << 1;
2708 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2709 int32_t bit3 = scaleWaveHalf << 3;
2710 int32_t bits = bit3 | bits2and1 | bit0;
2712 assert(!llvm::is_contained(
2713 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
2717static std::optional<StringRef>
2718scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2719 using fp4 = Float4E2M1FNType;
2720 using fp8 = Float8E4M3FNType;
2721 using bf8 = Float8E5M2Type;
2722 using fp6 = Float6E2M3FNType;
2723 using bf6 = Float6E3M2FNType;
2724 if (isa<fp4>(srcElemType)) {
2725 if (destElemType.
isF16())
2726 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2727 if (destElemType.
isBF16())
2728 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2729 if (destElemType.
isF32())
2730 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2731 return std::nullopt;
2733 if (isa<fp8>(srcElemType)) {
2734 if (destElemType.
isF16())
2735 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2736 if (destElemType.
isBF16())
2737 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2738 if (destElemType.
isF32())
2739 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2740 return std::nullopt;
2742 if (isa<bf8>(srcElemType)) {
2743 if (destElemType.
isF16())
2744 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2745 if (destElemType.
isBF16())
2746 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2747 if (destElemType.
isF32())
2748 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2749 return std::nullopt;
2751 if (isa<fp6>(srcElemType)) {
2752 if (destElemType.
isF16())
2753 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2754 if (destElemType.
isBF16())
2755 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2756 if (destElemType.
isF32())
2757 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2758 return std::nullopt;
2760 if (isa<bf6>(srcElemType)) {
2761 if (destElemType.
isF16())
2762 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2763 if (destElemType.
isBF16())
2764 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2765 if (destElemType.
isF32())
2766 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2767 return std::nullopt;
2769 llvm_unreachable(
"invalid combination of element types for packed conversion "
// Lowers amdgpu ScaledExtPackedMatrixOp to a ROCDL packed scaled-conversion
// intrinsic. The intrinsic name is chosen by scaledExtPacked816ToIntrinsic
// from (source element type, destination element type), and the op is built
// generically through an OperationState so a single code path covers every
// pairing.
// NOTE(review): this listing omits several source lines (e.g. the guard in
// front of the first notifyMatchFailure and the fp4 packed-type line); the
// comments below describe only what is visible.
2773LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2774 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2775 ConversionPatternRewriter &rewriter)
const {
2776 using fp4 = Float4E2M1FNType;
2777 using fp8 = Float8E4M3FNType;
2778 using bf8 = Float8E5M2Type;
2779 using fp6 = Float6E2M3FNType;
2780 using bf6 = Float6E3M2FNType;
2781 Location loc = op.getLoc();
// Bail out when the target lacks these instructions (the guarding condition
// is not visible in this listing).
2783 return rewriter.notifyMatchFailure(
2785 "Scaled fp packed conversion instructions are not available on target "
2786 "architecture and their emulation is not implemented");
// Index of the 16-lane group that contains the first scale lane.
2790 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2791 int32_t firstScaleByte = op.getFirstScaleByte();
2792 int32_t blockSize = op.getBlockSize();
2793 auto sourceType = cast<VectorType>(op.getSource().getType());
2794 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2795 unsigned bitWidth = srcElemType.getWidth();
2797 auto targetType = cast<VectorType>(op.getResult().getType());
2798 auto destElemType = cast<FloatType>(targetType.getElementType());
2800 IntegerType i32 = rewriter.getI32Type();
2801 Value source = adaptor.getSource();
2802 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
// Integer container that the packed source is bitcast to before the call.
// It is vector<2xi32> for fp8/bf8 and vector<3xi32> for fp6/bf6 (the fp4
// container is set on a line not shown here). Any other element type is
// unreachable.
2803 Type packedType =
nullptr;
2804 if (isa<fp4>(srcElemType)) {
2806 packedType = getTypeConverter()->convertType(packedType);
2807 }
else if (isa<fp8, bf8>(srcElemType)) {
2808 packedType = VectorType::get(2, i32);
2809 packedType = getTypeConverter()->convertType(packedType);
2810 }
else if (isa<fp6, bf6>(srcElemType)) {
2811 packedType = VectorType::get(3, i32);
2812 packedType = getTypeConverter()->convertType(packedType);
2814 llvm_unreachable(
"invalid element type for packed scaled ext");
// Either type conversion yielding null is reported as a match failure.
2817 if (!packedType || !llvmResultType) {
2818 return rewriter.notifyMatchFailure(op,
"type conversion failed");
// scaledExtPacked816ToIntrinsic returns std::nullopt when the destination
// element type is not f16/bf16/f32. That case is emitted as a hard op
// error rather than a match failure.
2821 std::optional<StringRef> maybeIntrinsic =
2822 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2823 if (!maybeIntrinsic.has_value())
2824 return op.emitOpError(
2825 "no intrinsic matching packed scaled conversion on the given chipset");
// scaleSel is derived from block size, source bit width, lane group and
// first scale byte via getScaleSel. It is attached below as the "scaleSel"
// attribute.
2828 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
// Scale and source are reinterpreted (bitcast, no value change) as i32 and
// as the packed integer container, respectively.
2830 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2831 Value castedSource =
2832 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
// Build the intrinsic op by name: one result, operands (source, scale),
// plus the i32 "scaleSel" attribute.
2834 OperationState loweredOp(loc, *maybeIntrinsic);
2835 loweredOp.addTypes({llvmResultType});
2836 loweredOp.addOperands({castedSource, castedScale});
2838 SmallVector<NamedAttribute, 1> attrs;
2840 NamedAttribute(
"scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2842 loweredOp.addAttributes(attrs);
2843 Operation *lowered = rewriter.create(loweredOp);
2844 rewriter.replaceOp(op, lowered);
// Lowers amdgpu ScaledExtPackedOp to one of the ROCDL CvtScaleF32Pk* ops,
// selected by the (source element type, destination element type) pair.
// The source is widened to a full 32-bit packed vector if needed,
// reinterpreted as a single i32, and passed with the scale and op.getIndex().
// NOTE(review): this listing omits several source lines (the chipset guard,
// index constants, and the tail of the dispatch chain).
2849LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2850 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2851 ConversionPatternRewriter &rewriter)
const {
2852 Location loc = op.getLoc();
// Guard condition for this early exit is not visible in this listing.
2854 return rewriter.notifyMatchFailure(
2855 loc,
"Scaled fp conversion instructions are not available on target "
2856 "architecture and their emulation is not implemented");
2857 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2859 Value source = adaptor.getSource();
2860 Value scale = adaptor.getScale();
2862 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2863 Type sourceElemType = sourceVecType.getElementType();
2864 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2865 Type destElemType = destVecType.getElementType();
// Full-width packed container: vector<4xi8> for the 8-bit formats and
// vector<8xi4> for fp4. Both are 32 bits wide.
2867 VectorType packedVecType;
2868 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2869 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2870 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2871 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
2872 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2873 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2875 llvm_unreachable(
"invalid element type for scaled ext");
// A short source is copied element by element into a zero vector of the
// packed width. Presumably that vector then replaces `source` on a line
// not shown here.
// NOTE(review): sourceVecType comes from cast<VectorType> above and so is
// never null. The `!sourceVecType` branch therefore looks unreachable;
// confirm whether a different condition was intended.
2879 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2880 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2881 if (!sourceVecType) {
2882 longVec = LLVM::InsertElementOp::create(
2885 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2887 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2889 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
// Reinterpret the packed bits as one i32 operand for the ROCDL op.
2894 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
// Dispatch table: source bf8 (E5M2) / fp8 (E4M3FN) / fp4 (E2M1FN) crossed
// with destination f32 / f16 / bf16. Every variant takes the same
// (i32Source, scale, index) operands.
2896 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
2897 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2898 op, destVecType, i32Source, scale, op.getIndex());
2899 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
2900 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2901 op, destVecType, i32Source, scale, op.getIndex());
2902 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
2903 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2904 op, destVecType, i32Source, scale, op.getIndex());
2905 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
2906 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2907 op, destVecType, i32Source, scale, op.getIndex());
2908 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
2909 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2910 op, destVecType, i32Source, scale, op.getIndex());
2911 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
2912 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2913 op, destVecType, i32Source, scale, op.getIndex());
2914 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
2915 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2916 op, destVecType, i32Source, scale, op.getIndex());
2917 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
2918 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2919 op, destVecType, i32Source, scale, op.getIndex());
2920 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
2921 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2922 op, destVecType, i32Source, scale, op.getIndex());
// Lowers amdgpu PackedScaledTruncOp to one of the ROCDL CvtScaleF32Pk*
// truncation ops. The ops write into an integer container (`existing`),
// which is bitcast back to the op's result type at the end.
// NOTE(review): this listing omits several source lines (the chipset guard,
// the resultElemType definition, the existing/no-existing condition, and
// the constants c0/c1).
2929LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2930 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2931 ConversionPatternRewriter &rewriter)
const {
2932 Location loc = op.getLoc();
2934 return rewriter.notifyMatchFailure(
2935 loc,
"Scaled fp conversion instructions are not available on target "
2936 "architecture and their emulation is not implemented");
2937 Type v2i16 = getTypeConverter()->convertType(
2938 VectorType::get(2, rewriter.getI16Type()));
2939 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2941 Type resultType = op.getResult().getType();
2943 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2944 Type sourceElemType = sourceVecType.getElementType();
// Integer container the intrinsic works on: i32 for fp4 results and
// vector<2xi16> otherwise.
2946 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
2948 Value source = adaptor.getSource();
2949 Value scale = adaptor.getScale();
2950 Value existing = adaptor.getExisting();
// Reuse the caller-provided `existing` bits (bitcast) or start from zero.
// The selecting condition is not visible here; presumably it checks
// whether `existing` was provided.
2952 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2954 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
// A single-element source is padded into a zero 2-element vector, because
// the ops consume two inputs.
2956 if (sourceVecType.getNumElements() < 2) {
2958 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2959 VectorType v2 = VectorType::get(2, sourceElemType);
2960 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2961 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
// f32 variants take two scalar operands. The f16/bf16 variants take the
// whole 2-element vector.
2964 Value sourceA, sourceB;
2965 if (sourceElemType.
isF32()) {
2968 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2969 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
// Dispatch: source f32 / f16 / bf16 crossed with result bf8 (E5M2) /
// fp8 (E4M3FN) / fp4 (E2M1FN).
2973 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
2974 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2975 existing, sourceA, sourceB,
2976 scale, op.getIndex());
2977 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
2978 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2979 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2980 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
2981 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2982 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2983 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
2984 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2985 existing, sourceA, sourceB,
2986 scale, op.getIndex());
2987 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
2988 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2989 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2990 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
2991 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2992 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2993 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
2994 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2995 existing, sourceA, sourceB,
2996 scale, op.getIndex());
2997 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
2998 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2999 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
3000 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
3001 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
3002 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
// Reinterpret the integer container as the converted result type and
// replace the op with it.
3006 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
3007 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu PackedTrunc2xFp8Op to ROCDL CvtPkBf8F32Op or CvtPkFp8F32Op.
// Two f32 inputs are packed into the word of an i32 selected by
// op.getWordIndex(), and the i32 is bitcast back to the result type.
// NOTE(review): the guard and the conditions choosing each branch are on
// lines not shown in this listing.
3011LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
3012 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
3013 ConversionPatternRewriter &rewriter)
const {
3014 Location loc = op.getLoc();
3016 return rewriter.notifyMatchFailure(
3017 loc,
"Fp8 conversion instructions are not available on target "
3018 "architecture and their emulation is not implemented");
3019 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
3021 Type resultType = op.getResult().getType();
3024 Value sourceA = adaptor.getSourceA();
3025 Value sourceB = adaptor.getSourceB();
// A missing second input is replaced by undef of sourceA's type
// (presumably guarded by a check for an absent sourceB).
3027 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.
getType());
3028 Value existing = adaptor.getExisting();
// `existing` is reinterpreted as i32, or undef when not supplied (the
// condition is not shown).
3030 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
3032 existing = LLVM::UndefOp::create(rewriter, loc, i32);
// bf8 vs fp8 intrinsic, presumably chosen by the result element type.
3036 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
3037 existing, op.getWordIndex());
3039 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
3040 existing, op.getWordIndex());
3042 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
3043 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu PackedStochRoundFp8Op to ROCDL CvtSrBf8F32Op or
// CvtSrFp8F32Op, i.e. the stochastic-rounding conversions. The op takes the
// f32 source, the stochastic parameter, the existing i32 and
// op.getStoreIndex(); the i32 result is bitcast back to the op's result
// type.
// NOTE(review): the guard and the branch conditions are on lines not shown.
3047LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
3048 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
3049 ConversionPatternRewriter &rewriter)
const {
3050 Location loc = op.getLoc();
3052 return rewriter.notifyMatchFailure(
3053 loc,
"Fp8 conversion instructions are not available on target "
3054 "architecture and their emulation is not implemented");
3055 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
3057 Type resultType = op.getResult().getType();
3060 Value source = adaptor.getSource();
3061 Value stoch = adaptor.getStochiasticParam();
3062 Value existing = adaptor.getExisting();
// Reuse `existing` as i32 when present, else undef (condition not shown).
3064 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
3066 existing = LLVM::UndefOp::create(rewriter, loc, i32);
// bf8 vs fp8 intrinsic, presumably chosen by the result element type.
3070 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
3071 existing, op.getStoreIndex());
3073 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
3074 existing, op.getStoreIndex());
3076 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
3077 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu DPPOp to ROCDL DPPUpdateOp. Operands of 16 bits or fewer are
// widened into a 32-bit value before the DPP move, and the result is
// narrowed back afterwards.
// NOTE(review): many lines of this struct are missing from this listing
// (type-selection conditions, most enum values, break statements, the final
// replaceOp). Comments describe only what is visible.
3083struct AMDGPUDPPLowering :
public ConvertOpToLLVMPattern<DPPOp> {
3084 AMDGPUDPPLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3085 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
3089 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
3090 ConversionPatternRewriter &rewriter)
const override {
3093 Location loc = DppOp.getLoc();
3094 Value src = adaptor.getSrc();
3095 Value old = adaptor.getOld();
// Type the DPP update operates on. It is i32 in one case whose condition is
// not shown, f32 or f64 for floats, and i32 or i64 for integers (the
// ternary conditions are not shown; presumably by bit width).
3098 Type llvmType =
nullptr;
3100 llvmType = rewriter.getI32Type();
3101 }
else if (isa<FloatType>(srcType)) {
3103 ? rewriter.getF32Type()
3104 : rewriter.getF64Type();
3105 }
else if (isa<IntegerType>(srcType)) {
3107 ? rewriter.getI32Type()
3108 : rewriter.getI64Type();
3110 auto llvmSrcIntType = typeConverter->convertType(
// Widens an operand of at most 16 bits. Floats are first bitcast to
// llvmSrcIntType. The value is then inserted into an undef vector of
// (32 / width) llvmSrcIntType elements, and that vector is bitcast to
// llvmType.
3114 auto convertOperand = [&](Value operand, Type operandType) {
3115 if (operandType.getIntOrFloatBitWidth() <= 16) {
3116 if (llvm::isa<FloatType>(operandType)) {
3118 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
3120 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
3121 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
3122 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
3124 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
3126 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
3131 src = convertOperand(src, srcType);
3132 old = convertOperand(old, oldType);
// DPP control encodings (most enumerators are elided in this listing).
3135 enum DppCtrl :
unsigned {
3144 ROW_HALF_MIRROR = 0x141,
3149 auto kind = DppOp.getKind();
3150 auto permArgument = DppOp.getPermArgument();
// NOTE(review): this local shares its name with the enum above. Qualified
// names like `DppCtrl::ROW_SHL0` still refer to the enum, because lookup of
// the name before `::` considers only namespaces and types. Renaming one of
// the two would make this easier to read.
3151 uint32_t DppCtrl = 0;
// quad_perm: each permutation entry is OR-ed in at bit offset 2*i (the
// counter `i` is declared/incremented on lines not shown).
3155 case DPPPerm::quad_perm: {
3156 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
3158 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
3159 uint32_t num = elem.getInt();
3160 DppCtrl |= num << (i * 2);
// row_shl / row_shr / row_ror: base encoding plus the shift amount taken
// from the integer perm argument.
3165 case DPPPerm::row_shl: {
3166 auto intAttr = cast<IntegerAttr>(*permArgument);
3167 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
3170 case DPPPerm::row_shr: {
3171 auto intAttr = cast<IntegerAttr>(*permArgument);
3172 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
3175 case DPPPerm::row_ror: {
3176 auto intAttr = cast<IntegerAttr>(*permArgument);
3177 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
// The remaining permutations map to fixed encodings.
3180 case DPPPerm::wave_shl:
3181 DppCtrl = DppCtrl::WAVE_SHL1;
3183 case DPPPerm::wave_shr:
3184 DppCtrl = DppCtrl::WAVE_SHR1;
3186 case DPPPerm::wave_rol:
3187 DppCtrl = DppCtrl::WAVE_ROL1;
3189 case DPPPerm::wave_ror:
3190 DppCtrl = DppCtrl::WAVE_ROR1;
3192 case DPPPerm::row_mirror:
3193 DppCtrl = DppCtrl::ROW_MIRROR;
3195 case DPPPerm::row_half_mirror:
3196 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
3198 case DPPPerm::row_bcast_15:
3199 DppCtrl = DppCtrl::BCAST15;
3201 case DPPPerm::row_bcast_31:
3202 DppCtrl = DppCtrl::BCAST31;
// Row/bank masks and bound_ctrl are read from the op's attributes and
// forwarded unchanged.
3208 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
3209 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
3210 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>(
"bound_ctrl").getValue();
3214 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
3215 rowMask, bankMask, boundCtrl);
3217 Value
result = dppMovOp.getRes();
// Narrow back to the original small type: truncate to llvmSrcIntType, then
// bitcast for non-integer source types. The guarding condition is not
// shown; presumably it is the same <=16-bit check as in convertOperand.
3219 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType,
result);
3220 if (!llvm::isa<IntegerType>(srcType)) {
3221 result = LLVM::BitcastOp::create(rewriter, loc, srcType,
result);
// Lowers amdgpu SwizzleBitModeOp to ROCDL DsSwizzleOp. The source is split
// into i32 pieces, each piece is swizzled with the same bit-mode mask, and
// the pieces are recombined into the source type.
3232struct AMDGPUSwizzleBitModeLowering
3233 :
public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
3237 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
3238 ConversionPatternRewriter &rewriter)
const override {
3239 Location loc = op.getLoc();
3240 Type i32 = rewriter.getI32Type();
3241 Value src = adaptor.getSrc();
// Values that cannot be split into i32 pieces are not matched.
3242 SmallVector<Value> decomposed;
3243 if (
failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
3244 return rewriter.notifyMatchFailure(op,
3245 "failed to decompose value to i32");
3246 unsigned andMask = op.getAndMask();
3247 unsigned orMask = op.getOrMask();
3248 unsigned xorMask = op.getXorMask();
// Swizzle pattern: andMask at bit 0, orMask shifted to bit 5, xorMask
// shifted to bit 10. No masking is applied to the individual fields here.
3252 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
3254 SmallVector<Value> swizzled;
3255 for (Value v : decomposed) {
3257 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
3258 swizzled.emplace_back(res);
3261 Value
result = LLVM::composeValue(rewriter, loc, swizzled, src.
getType());
3262 rewriter.replaceOp(op,
result);
// Lowers amdgpu PermlaneSwapOp to ROCDL Permlane16SwapOp or
// Permlane32SwapOp, depending on the row length. The source is processed in
// i32 pieces and recombined into the source type.
3267struct AMDGPUPermlaneLowering :
public ConvertOpToLLVMPattern<PermlaneSwapOp> {
3270 AMDGPUPermlaneLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3271 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
3275 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
3276 ConversionPatternRewriter &rewriter)
const override {
// Chipset check (condition not shown in this listing) produces a hard error.
3278 return op->emitOpError(
"permlane_swap is only supported on gfx950+");
3280 Location loc = op.getLoc();
3281 Type i32 = rewriter.getI32Type();
3282 Value src = adaptor.getSrc();
3283 unsigned rowLength = op.getRowLength();
3284 bool fi = op.getFetchInactive();
3285 bool boundctrl = op.getBoundCtrl();
3287 SmallVector<Value> decomposed;
3288 if (
failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
3289 return rewriter.notifyMatchFailure(op,
3290 "failed to decompose value to i32");
3292 SmallVector<Value> permuted;
3293 for (Value v : decomposed) {
// The swap intrinsics return a literal struct of two values of v's type.
// The same value is passed as both inputs.
3295 Type i32pair = LLVM::LLVMStructType::getLiteral(
3296 rewriter.getContext(), {v.getType(), v.getType()});
// Only row lengths 16 and 32 are handled; anything else is unreachable.
3298 if (rowLength == 16)
3299 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
3301 else if (rowLength == 32)
3302 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
3305 llvm_unreachable(
"unsupported row length");
3307 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
3308 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
// If the first output still equals the input, take the second output;
// otherwise take the first. Presumably only one of the two outputs carries
// the swapped-in value for a given lane. TODO confirm against the ISA doc.
3310 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
3311 LLVM::ICmpPredicate::eq, vdst0, v);
3316 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
3317 permuted.emplace_back(vdstNew);
3320 Value
result = LLVM::composeValue(rewriter, loc, permuted, src.
getType());
3321 rewriter.replaceOp(op,
result);
// Lowers amdgpu PermlaneVarOp to ROCDL PermlaneX16VarOp or Permlane16VarOp.
// Each i32 piece of the source is permuted with the same lane selector, and
// the pieces are recombined into the source type.
3326struct AMDGPUPermlaneVarLowering
3327 :
public ConvertOpToLLVMPattern<PermlaneVarOp> {
3330 AMDGPUPermlaneVarLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3331 : ConvertOpToLLVMPattern<PermlaneVarOp>(converter), chipset(chipset) {}
3335 matchAndRewrite(PermlaneVarOp op, OpAdaptor adaptor,
3336 ConversionPatternRewriter &rewriter)
const override {
// Chipset check (condition not shown in this listing) produces a hard error.
3338 return op->emitOpError(
"permlane_var is only supported on GFX12+");
3340 Location loc = op.getLoc();
3341 Type i32 = rewriter.getI32Type();
3342 Value src = adaptor.getSrc();
3343 Value selector = adaptor.getSelector();
3344 bool cross = op.getCross();
3345 bool fi = op.getFetchInactive();
3346 bool boundCtrl = op.getBoundCtrl();
3348 SmallVector<Value> decomposed;
3349 if (
failed(LLVM::decomposeValue(rewriter, loc, src, i32, decomposed)))
3350 return rewriter.notifyMatchFailure(op,
3351 "failed to decompose value to i32");
// The X16 ("cross") variant vs the plain 16-lane variant. The selecting
// condition is not shown; presumably it is `cross`.
3353 SmallVector<Value> permuted;
3354 for (Value v : decomposed) {
3357 res = ROCDL::PermlaneX16VarOp::create(rewriter, loc, i32, v, v,
3358 selector, fi, boundCtrl);
3360 res = ROCDL::Permlane16VarOp::create(rewriter, loc, i32, v, v, selector,
3362 permuted.emplace_back(res);
3365 Value
result = LLVM::composeValue(rewriter, loc, permuted, src.
getType());
3366 rewriter.replaceOp(op,
result);
// Bit layout of the 64-bit DS barrier state word used by the lowerings
// below. The pending count occupies the low 29 bits, the phase field starts
// at bit 29, and the init count starts at bit 32.
3379constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
3380constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
3381constexpr int32_t kDsBarrierInitCountPos = 32;
// Mask selecting the pending-count bits, i.e. the low 29 bits.
3382constexpr int32_t kDsBarrierPendingCountMask =
3383 (1 << kDsBarrierPendingCountBitWidth) - 1;
// Lowers amdgpu DsBarrierInitOp to an atomic release store of an
// initialized 64-bit barrier state word to the barrier's memory location.
// NOTE(review): several lines are missing from this listing (the chipset
// condition, pointer computation, the subtrahend, countMask and the shift
// amount).
3385struct DsBarrierInitOpLowering
3386 :
public ConvertOpToLLVMPattern<DsBarrierInitOp> {
3389 DsBarrierInitOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3390 : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}
3393 matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
3394 ConversionPatternRewriter &rewriter)
const override {
3396 return op->emitOpError(
"only supported on gfx1250+");
3398 Location loc = op.getLoc();
3399 Type i64 = rewriter.getI64Type();
// Element address of the barrier within the memref (base + indices).
3401 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3403 adaptor.getBase(), adaptor.getIndices());
// The count is derived from the participant count by subtraction (the
// other operand is not visible here).
3410 LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
// The masked count is written twice: once shifted left (presumably by
// kDsBarrierInitCountPos, as the init count) and once unshifted in the low
// bits (the pending count). If countMask is kDsBarrierPendingCountMask, the
// phase bits start at zero. TODO confirm.
3417 Value maskedCount32 =
3418 LLVM::AndOp::create(rewriter, loc, initCount, countMask);
3419 Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);
3421 Value initCountShifted = LLVM::ShlOp::create(
3422 rewriter, loc, maskedCount,
3424 Value barrierState =
3425 LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);
// Atomic store with release ordering, so earlier writes are visible to
// threads that acquire-load the barrier state.
3427 LLVM::StoreOp::create(
3428 rewriter, loc, barrierState, ptr, 8,
false,
3430 false, LLVM::AtomicOrdering::release,
3433 rewriter.eraseOp(op);
// Lowers amdgpu DsBarrierPollStateOp to an atomic acquire i64 load of the
// barrier state word. It is the counterpart of the release store performed
// at init.
// NOTE(review): the chipset condition, the pointer computation and the
// tail of the load arguments are not shown in this listing.
3438struct DsBarrierPollStateOpLowering
3439 :
public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
3442 DsBarrierPollStateOpLowering(
const LLVMTypeConverter &converter,
3444 : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
3448 matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
3449 ConversionPatternRewriter &rewriter)
const override {
3451 return op->emitOpError(
"only supported on gfx1250+");
3453 Location loc = op.getLoc();
3454 Type i64 = rewriter.getI64Type();
3456 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3458 adaptor.getBase(), adaptor.getIndices());
3462 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
3463 op, i64, ptr, 8,
false,
3465 false, LLVM::AtomicOrdering::acquire,
// Lowers amdgpu DsAsyncBarrierArriveOp to ROCDL
// DsAtomicAsyncBarrierArriveOp on the barrier's element address. The
// trailing builder arguments are null.
// NOTE(review): the chipset condition and pointer computation are not shown.
3471struct DsAsyncBarrierArriveOpLowering
3472 :
public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
3475 DsAsyncBarrierArriveOpLowering(
const LLVMTypeConverter &converter,
3477 : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
3481 matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
3482 ConversionPatternRewriter &rewriter)
const override {
3484 return op->emitOpError(
"only supported on gfx1250+");
3486 Location loc = op.getLoc();
3488 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3490 adaptor.getBase(), adaptor.getIndices());
3492 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
3493 op, ptr,
nullptr,
nullptr,
// Lowers amdgpu DsBarrierArriveOp to ROCDL DsAtomicBarrierArriveRtnOp. The
// op arrives on the barrier with adaptor.getCount() and returns an i64
// value (presumably the barrier state).
// NOTE(review): the chipset condition and pointer computation are not shown.
3499struct DsBarrierArriveOpLowering
3500 :
public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
3503 DsBarrierArriveOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3504 : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
3508 matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
3509 ConversionPatternRewriter &rewriter)
const override {
3511 return op->emitOpError(
"only supported on gfx1250+");
3513 Location loc = op.getLoc();
3514 Type i64 = rewriter.getI64Type();
3516 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
3518 adaptor.getBase(), adaptor.getIndices());
3520 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
3521 op, i64, ptr, adaptor.getCount(),
nullptr,
// Extracts the phase from a 64-bit barrier state. Truncating to i32 drops
// the init-count bits (32 and up); the logical right shift then removes the
// pending-count bits. The shift amount is on a line not shown, presumably
// kDsBarrierPhasePos. No mask is applied after the shift.
3527struct DsBarrierStatePhaseOpLowering
3528 :
public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
3532 matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
3533 ConversionPatternRewriter &rewriter)
const override {
3534 Location loc = op.getLoc();
3535 Type i32 = rewriter.getI32Type();
3537 Value state = adaptor.getState();
3539 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3540 Value phase = LLVM::LShrOp::create(
3541 rewriter, loc, noInitCount,
3544 rewriter.replaceOp(op, phase);
// Extracts the pending count from a 64-bit barrier state. The state is
// truncated to i32 and then AND-ed with kDsBarrierPendingCountMask (the low
// 29 bits).
3549struct DsBarrierStatePendingCountOpLowering
3550 :
public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
3554 matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
3555 ConversionPatternRewriter &rewriter)
const override {
3556 Location loc = op.getLoc();
3557 Type i32 = rewriter.getI32Type();
3559 Value state = adaptor.getState();
3561 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
3562 Value pendingCount = LLVM::AndOp::create(
3563 rewriter, loc, noInitCount,
3565 static_cast<uint32_t
>(kDsBarrierPendingCountMask)));
3567 rewriter.replaceOp(op, pendingCount);
// Extracts the init count from a 64-bit barrier state. The full i64 is
// shifted right (by a constant on a line not shown, presumably
// kDsBarrierInitCountPos) and truncated to i32.
3572struct DsBarrierStateInitCountOpLowering
3573 :
public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
3577 matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
3578 ConversionPatternRewriter &rewriter)
const override {
3579 Location loc = op.getLoc();
3580 Type i32 = rewriter.getI32Type();
3582 Value state = adaptor.getState();
3584 Value initCountI64 = LLVM::LShrOp::create(
3585 rewriter, loc, state,
3587 Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
3589 rewriter.replaceOp(op, initCount);
// Extracts the phase parity (lowest phase bit) from a 64-bit barrier state.
// The state is truncated to i32, shifted right (shift constant not shown,
// presumably kDsBarrierPhasePos), and truncated to i1, which keeps only
// bit 0.
3594struct DsBarrierStatePhaseParityLowering
3595 :
public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
3599 matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
3600 ConversionPatternRewriter &rewriter)
const override {
3601 Location loc = op.getLoc();
3602 Type i1 = rewriter.getI1Type();
3604 Value state = adaptor.getState();
3607 LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
3608 Value phase = LLVM::LShrOp::create(
3609 rewriter, loc, noInitCount,
3611 Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
3613 rewriter.replaceOp(op, parity);
// Shifts `value` left by `shift` bits and ORs it into `accumulator`,
// returning the combined word. Used to assemble descriptor words field by
// field.
// NOTE(review): the OR is emitted with the `disjoint` flag. Per the LLVM
// LangRef, `or disjoint` is poison if both operands share a set bit, so
// callers must only write fields whose bits are still zero in
// `accumulator`. Part of the body (how shiftAmount is built and any special
// handling of `shift`) is missing from this listing.
3622static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
3623 Value accumulator, Value value, int64_t shift) {
3628 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
3634 constexpr bool isDisjoint =
true;
3635 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
// Lowers a make_dma_base-style op (BaseOp) to a vector<4xi32> built from
// four 32-bit words:
//   word 0: flags (bit 31, plus bit 30 and index-size handling for gather ops)
//   word 1: LDS address as i32
//   word 2: low 32 bits of the global address
//   word 3: masked high half of the global address, with consts[2] at bit 30
// NOTE(review): several lines are missing from this listing (the contents
// of consts[], the index-size/else branch around bit 31, pointer
// computations, the shift amount and the mask).
3638template <
typename BaseOp>
3639struct AMDGPUMakeDmaBaseLowering :
public ConvertOpToLLVMPattern<BaseOp> {
3640 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
3643 AMDGPUMakeDmaBaseLowering(
const LLVMTypeConverter &converter, Chipset chipset)
3644 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
3648 matchAndRewrite(BaseOp op, Adaptor adaptor,
3649 ConversionPatternRewriter &rewriter)
const override {
3651 return op->emitOpError(
"make_dma_base is only supported on gfx1250");
3653 Location loc = op.getLoc();
// Four constant values are created here. They serve both as field values
// and as insert indices below; presumably consts[i] is the i32 constant i.
3655 constexpr int32_t constlen = 4;
3656 Value consts[constlen];
3657 for (int64_t i = 0; i < constlen; ++i)
// All words start as consts[0]; word 0 is then reset to consts[1].
3660 constexpr int32_t sgprslen = constlen;
3661 Value sgprs[sgprslen];
3662 for (int64_t i = 0; i < sgprslen; ++i) {
3663 sgprs[i] = consts[0];
3666 sgprs[0] = consts[1];
// Gather variants set bit 30 of word 0 and derive idx = indexSize/16 - 1
// (0 for 16-bit, 1 for 32-bit indices). How idx feeds bit 31 is on lines
// not shown in this listing.
3668 if constexpr (BaseOp::isGather()) {
3669 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
3671 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
3672 Type indexType = type.getIndexType();
3674 assert(llvm::is_contained({16u, 32u}, indexSize) &&
3675 "expected index_size to be 16 or 32");
3676 unsigned idx = (indexSize / 16) - 1;
3679 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
// Element addresses for the LDS and global memrefs (base + indices).
3682 ValueRange ldsIndices = adaptor.getLdsIndices();
3683 Value lds = adaptor.getLds();
3684 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
3687 rewriter, loc, ldsMemRefType, lds, ldsIndices);
3689 ValueRange globalIndices = adaptor.getGlobalIndices();
3690 Value global = adaptor.getGlobal();
3691 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
3694 rewriter, loc, globalMemRefType, global, globalIndices);
3696 Type i32 = rewriter.getI32Type();
3697 Type i64 = rewriter.getI64Type();
// The LDS pointer is converted directly to i32. The global pointer goes
// through i64 and is split into low (word 2) and high (word 3) parts.
3699 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
3700 Value castForGlobalAddr =
3701 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
3703 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
3705 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
3708 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
3711 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
3713 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
// Assemble the four words into a vector<4xi32>, inserting each word at the
// index given by the matching consts[] entry.
3715 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3716 assert(v4i32 &&
"expected type conversion to succeed");
3717 Value
result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3719 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
3721 LLVM::InsertElementOp::create(rewriter, loc,
result, sgpr, constant);
3723 rewriter.replaceOp(op,
result);
// Lowering pattern for TDM descriptor ops (DescriptorOp). The helper
// methods each set one field of a descriptor word via setValueAtOffset and
// return the updated word. The pattern records the target chipset.
3728template <
typename DescriptorOp>
3729struct AMDGPULowerDescriptor :
public ConvertOpToLLVMPattern<DescriptorOp> {
3730 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
3733 AMDGPULowerDescriptor(
const LLVMTypeConverter &converter, Chipset chipset)
3734 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
// Descriptor group 0 is the already-lowered base operand, returned as-is.
3737 Value getDGroup0(OpAdaptor adaptor)
const {
return adaptor.getBase(); }
// Writes the workgroup mask into bits starting at 0 of sgpr0. The mask is
// bitcast to i16 and zero-extended to i32.
// NOTE(review): the mask is read from `op` rather than `adaptor`. Unless a
// hidden line re-reads it from the adaptor, confirm that using the
// unconverted operand is intended in this conversion pattern. The
// early-return for an absent mask is on lines not shown.
3739 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
3740 ConversionPatternRewriter &rewriter, Location loc,
3741 Value sgpr0)
const {
3742 Value mask = op.getWorkgroupMask();
3746 Type i16 = rewriter.getI16Type();
3747 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
3748 Type i32 = rewriter.getI32Type();
3749 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
3750 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
// Encodes the element size as log2(bytes) at bit 16 of sgpr0: 8 bits -> 0,
// 16 -> 1, 32 -> 2, 64 -> 3. The value comes from consts[idx], presumably
// the i32 constant idx.
3753 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
3754 ConversionPatternRewriter &rewriter, Location loc,
3755 Value sgpr0, ArrayRef<Value> consts)
const {
3756 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
3757 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3758 "expected type width to be 8, 16, 32, or 64.");
3759 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3760 Value size = consts[idx];
3761 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
// Sets the atomic-barrier enable bit (bit 18 of sgpr0) when an atomic
// barrier address operand is present. The early return for the absent case
// is on lines not shown.
3764 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
3765 ConversionPatternRewriter &rewriter, Location loc,
3766 Value sgpr0, ArrayRef<Value> consts)
const {
3767 if (!adaptor.getAtomicBarrierAddress())
3770 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
// Sets the iterate-enable bit (bit 19 of sgpr0) when a global increment
// operand is present. The absent-case handling is on lines not shown.
3773 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
3774 ConversionPatternRewriter &rewriter, Location loc,
3775 Value sgpr0, ArrayRef<Value> consts)
const {
3776 if (!adaptor.getGlobalIncrement())
3781 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
// Sets the pad-enable bit (bit 20 of sgpr0) when a pad amount is present.
// The absent-case handling is on lines not shown.
3784 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
3785 ConversionPatternRewriter &rewriter, Location loc,
3786 Value sgpr0, ArrayRef<Value> consts)
const {
3787 if (!op.getPadAmount())
3790 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
// Sets the early-timeout bit (bit 21 of sgpr0). Note that it is keyed on the
// presence of a workgroup mask, not on a dedicated operand. The
// absent-case handling is on lines not shown.
3793 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
3794 ConversionPatternRewriter &rewriter, Location loc,
3795 Value sgpr0, ArrayRef<Value> consts)
const {
3796 if (!op.getWorkgroupMask())
3799 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
// Encodes the pad interval at bit 22 of sgpr0 as cttz(padInterval) - 1.
// That equals log2(interval) - 1 only for power-of-two intervals;
// presumably this is verified elsewhere. TODO confirm. The `false` passed to
// CountTrailingZerosOp is presumably its is_zero_poison flag. The early
// return when there is no pad amount is on lines not shown.
3802 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
3803 ConversionPatternRewriter &rewriter, Location loc,
3804 Value sgpr0, ArrayRef<Value> consts)
const {
3805 if (!op.getPadAmount())
3814 IntegerType i32 = rewriter.getI32Type();
3815 Value padInterval = adaptor.getPadInterval();
3816 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
3817 padInterval,
false);
3818 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
3820 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
// Encodes the pad amount minus one at bit 25 of sgpr0. The early return
// when there is no pad amount is on lines not shown.
3823 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
3824 ConversionPatternRewriter &rewriter, Location loc,
3825 Value sgpr0, ArrayRef<Value> consts)
const {
3826 if (!op.getPadAmount())
3835 Value padAmount = adaptor.getPadAmount();
// Field stores amount - 1 (consts[1] is presumably the constant 1).
3836 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
3838 return setValueAtOffset(rewriter, loc, sgpr1, padAmount, 25);
// Writes the atomic barrier address field into sgpr1. The barrier's element
// address is converted to i32, shifted right by consts[3], masked (the mask
// is defined on a line not shown), and placed at offset 32.
// NOTE(review): offset 32 is passed for a 32-bit word. This presumably
// relies on setValueAtOffset treating offsets as descriptor-global bit
// positions, e.g. reducing them modulo 32. That logic is not visible here;
// confirm. The early return for an absent address is also not shown.
3841 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
3842 ConversionPatternRewriter &rewriter,
3843 Location loc, Value sgpr1,
3844 ArrayRef<Value> consts)
const {
3845 if (!adaptor.getAtomicBarrierAddress())
3848 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
3849 auto barrierAddressTy =
3850 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
3851 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
3853 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
3854 atomicBarrierIndices);
3855 IntegerType i32 = rewriter.getI32Type();
3861 atomicBarrierAddress =
3862 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
3863 atomicBarrierAddress =
3864 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
3866 atomicBarrierAddress =
3867 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
3868 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
3871 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3872 ConversionPatternRewriter &rewriter,
3873 Location loc, Value sgpr1, Value sgpr2,
3874 ArrayRef<Value> consts, uint64_t dimX,
3875 uint32_t offset)
const {
3876 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3877 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3878 SmallVector<OpFoldResult> mixedGlobalSizes =
3880 if (mixedGlobalSizes.size() <= dimX)
3881 return {sgpr1, sgpr2};
3883 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3890 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3894 IntegerType i32 = rewriter.getI32Type();
3895 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3896 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3899 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
3902 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
3903 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
3904 return {sgpr1, sgpr2};
3907 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
3908 ConversionPatternRewriter &rewriter,
3909 Location loc, Value sgpr1, Value sgpr2,
3910 ArrayRef<Value> consts)
const {
3911 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
3915 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
3916 ConversionPatternRewriter &rewriter,
3917 Location loc, Value sgpr2, Value sgpr3,
3918 ArrayRef<Value> consts)
const {
3919 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
3923 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
3924 ConversionPatternRewriter &rewriter, Location loc,
3925 Value sgpr, ArrayRef<Value> consts,
size_t dimX,
3926 int64_t offset)
const {
3927 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
3928 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
3929 SmallVector<OpFoldResult> mixedSharedSizes =
3931 if (mixedSharedSizes.size() <= dimX)
3934 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
3943 if (
auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
3947 IntegerType i32 = rewriter.getI32Type();
3948 tileDimX = cast<Value>(tileDimXOpFoldResult);
3949 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
3952 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
3955 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
3956 ConversionPatternRewriter &rewriter, Location loc,
3957 Value sgpr3, ArrayRef<Value> consts)
const {
3958 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
3961 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
3962 ConversionPatternRewriter &rewriter, Location loc,
3963 Value sgpr4, ArrayRef<Value> consts)
const {
3964 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
3967 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3968 ConversionPatternRewriter &rewriter, Location loc,
3969 Value sgpr4, ArrayRef<Value> consts)
const {
3970 auto type = cast<VectorType>(op.getIndices().getType());
3971 ArrayRef<int64_t> shape = type.getShape();
3972 assert(shape.size() == 1 &&
"expected shape to be of rank 1.");
3973 unsigned length = shape.back();
3974 assert(0 < length && length <= 16 &&
"expected length to be at most 16.");
3976 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
3979 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
3980 ConversionPatternRewriter &rewriter,
3981 Location loc, Value sgpr4,
3982 ArrayRef<Value> consts)
const {
3983 if constexpr (DescriptorOp::isGather())
3984 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
3985 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
3988 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
3989 ConversionPatternRewriter &rewriter, Location loc,
3990 Value sgpr4, ArrayRef<Value> consts)
const {
3992 if constexpr (DescriptorOp::isGather())
3994 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
3997 std::pair<Value, Value>
3998 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
3999 ConversionPatternRewriter &rewriter, Location loc,
4000 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
4001 size_t dimX, int64_t offset)
const {
4002 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
4003 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
4004 SmallVector<OpFoldResult> mixedGlobalStrides =
4005 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
4007 if (mixedGlobalStrides.size() <= (dimX + 1))
4008 return {sgprY, sgprZ};
4010 OpFoldResult tensorDimXStrideOpFoldResult =
4011 *(mixedGlobalStrides.rbegin() + dimX + 1);
4016 Value tensorDimXStride;
4017 if (
auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
4021 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
4023 constexpr int64_t first48bits = (1ll << 48) - 1;
4026 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
4027 IntegerType i32 = rewriter.getI32Type();
4028 Value tensorDimXStrideLow =
4029 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
4030 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
4032 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
4034 Value tensorDimXStrideHigh =
4035 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
4036 tensorDimXStrideHigh =
4037 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
4038 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
4040 return {sgprY, sgprZ};
4043 std::pair<Value, Value>
4044 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
4045 ConversionPatternRewriter &rewriter, Location loc,
4046 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
4047 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
4051 std::pair<Value, Value>
4052 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
4053 ConversionPatternRewriter &rewriter, Location loc,
4054 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
4056 if constexpr (DescriptorOp::isGather())
4057 return {sgpr5, sgpr6};
4058 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
4062 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
4063 ConversionPatternRewriter &rewriter, Location loc,
4064 ArrayRef<Value> consts)
const {
4066 for (int64_t i = 0; i < 8; ++i) {
4067 sgprs[i] = consts[0];
4070 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
4071 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
4072 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
4073 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
4074 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
4075 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
4076 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
4077 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
4080 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
4081 std::tie(sgprs[1], sgprs[2]) =
4082 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
4083 std::tie(sgprs[2], sgprs[3]) =
4084 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
4086 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
4088 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
4089 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
4090 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
4091 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
4092 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
4093 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
4095 IntegerType i32 = rewriter.getI32Type();
4096 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
4097 assert(v8i32 &&
"expected type conversion to succeed");
4098 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
4100 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
4102 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
4108 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
4109 ConversionPatternRewriter &rewriter, Location loc,
4110 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
4111 int64_t offset)
const {
4112 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
4113 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
4114 SmallVector<OpFoldResult> mixedGlobalSizes =
4116 if (mixedGlobalSizes.size() <=
static_cast<unsigned long>(dimX))
4119 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
4121 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
4125 IntegerType i32 = rewriter.getI32Type();
4126 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
4127 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
4130 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
4133 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
4134 ConversionPatternRewriter &rewriter, Location loc,
4135 Value sgpr0, ArrayRef<Value> consts)
const {
4136 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
4139 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
4140 Location loc, Value accumulator,
4141 Value value, int64_t shift)
const {
4143 IntegerType i32 = rewriter.getI32Type();
4144 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
4145 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
4148 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4149 ConversionPatternRewriter &rewriter, Location loc,
4150 Value sgpr1, ArrayRef<Value> consts,
4151 int64_t offset)
const {
4152 Value ldsAddrIncrement = adaptor.getLdsIncrement();
4153 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
4156 std::pair<Value, Value>
4157 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4158 ConversionPatternRewriter &rewriter, Location loc,
4159 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
4160 int64_t offset)
const {
4161 Value globalAddrIncrement = adaptor.getGlobalIncrement();
4162 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
4163 globalAddrIncrement, offset);
4165 globalAddrIncrement =
4166 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
4167 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
4168 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
4169 globalAddrIncrement, offset + 32);
4171 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
4172 return {sgpr2, sgpr3};
4175 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
4176 ConversionPatternRewriter &rewriter,
4177 Location loc, Value sgpr1,
4178 ArrayRef<Value> consts)
const {
4179 Value ldsIncrement = op.getLdsIncrement();
4180 constexpr int64_t dim = 3;
4181 constexpr int64_t offset = 32;
4183 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
4185 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
4189 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
4190 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
4191 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts)
const {
4192 Value globalIncrement = op.getGlobalIncrement();
4193 constexpr int32_t dim = 2;
4194 constexpr int32_t offset = 64;
4195 if (!globalIncrement)
4196 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
4197 consts, dim, offset);
4198 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
4202 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
4203 ConversionPatternRewriter &rewriter, Location loc,
4204 Value sgpr3, ArrayRef<Value> consts,
4205 int32_t offset)
const {
4206 Value iterationCount = adaptor.getIterationCount();
4207 IntegerType i32 = rewriter.getI32Type();
4214 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
4216 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
4217 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
4220 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
4221 ConversionPatternRewriter &rewriter,
4222 Location loc, Value sgpr3,
4223 ArrayRef<Value> consts)
const {
4224 Value iterateCount = op.getIterationCount();
4225 constexpr int32_t dim = 2;
4226 constexpr int32_t offset = 112;
4228 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
4231 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
4234 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
4235 ConversionPatternRewriter &rewriter, Location loc,
4236 ArrayRef<Value> consts)
const {
4237 if constexpr (DescriptorOp::isGather())
4238 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
4239 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
4242 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
4243 ConversionPatternRewriter &rewriter, Location loc,
4244 ArrayRef<Value> consts)
const {
4245 IntegerType i32 = rewriter.getI32Type();
4246 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4247 assert(v4i32 &&
"expected type conversion to succeed.");
4249 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4250 if (onlyNeedsTwoDescriptors)
4251 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4253 constexpr int64_t sgprlen = 4;
4254 Value sgprs[sgprlen];
4255 for (
int i = 0; i < sgprlen; ++i)
4256 sgprs[i] = consts[0];
4258 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
4259 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
4261 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
4262 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
4264 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
4266 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4267 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
4269 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
4274 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
4275 ConversionPatternRewriter &rewriter, Location loc,
4276 ArrayRef<Value> consts,
bool firstHalf)
const {
4277 IntegerType i32 = rewriter.getI32Type();
4278 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4279 assert(v4i32 &&
"expected type conversion to succeed.");
4281 Value
indices = adaptor.getIndices();
4282 auto vectorType = cast<VectorType>(
indices.getType());
4283 unsigned length = vectorType.getShape().back();
4284 Type elementType = vectorType.getElementType();
4285 unsigned maxLength = elementType == i32 ? 4 : 8;
4286 int32_t offset = firstHalf ? 0 : maxLength;
4287 unsigned discountedLength =
4288 std::max(
static_cast<int32_t
>(length - offset), 0);
4290 unsigned targetSize = std::min(maxLength, discountedLength);
4292 SmallVector<Value> indicesVector;
4293 for (
unsigned i = offset; i < targetSize + offset; ++i) {
4295 if (i < consts.size())
4299 Value elem = LLVM::ExtractElementOp::create(rewriter, loc,
indices, idx);
4300 indicesVector.push_back(elem);
4303 SmallVector<Value> indicesI32Vector;
4304 if (elementType == i32) {
4305 indicesI32Vector = indicesVector;
4307 for (
unsigned i = 0; i < targetSize; ++i) {
4308 Value index = indicesVector[i];
4309 indicesI32Vector.push_back(
4310 LLVM::ZExtOp::create(rewriter, loc, i32, index));
4312 if ((targetSize % 2) != 0)
4314 indicesI32Vector.push_back(consts[0]);
4317 SmallVector<Value> indicesToInsert;
4318 if (elementType == i32) {
4319 indicesToInsert = indicesI32Vector;
4321 unsigned size = indicesI32Vector.size() / 2;
4322 for (
unsigned i = 0; i < size; ++i) {
4323 Value first = indicesI32Vector[2 * i];
4324 Value second = indicesI32Vector[2 * i + 1];
4325 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
4326 indicesToInsert.push_back(joined);
4330 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4331 for (
auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
4333 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
4338 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
4339 ConversionPatternRewriter &rewriter, Location loc,
4340 ArrayRef<Value> consts)
const {
4341 return getGatherIndices(op, adaptor, rewriter, loc, consts,
true);
4344 std::pair<Value, Value>
4345 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
4346 ConversionPatternRewriter &rewriter, Location loc,
4347 Value sgpr0, Value sgpr1, ArrayRef<Value> consts)
const {
4348 constexpr int32_t dim = 3;
4349 constexpr int32_t offset = 0;
4350 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
4354 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
4355 ConversionPatternRewriter &rewriter,
4356 Location loc, Value sgpr1, Value sgpr2,
4357 ArrayRef<Value> consts)
const {
4358 constexpr int32_t dim = 4;
4359 constexpr int32_t offset = 48;
4360 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
4364 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
4365 ConversionPatternRewriter &rewriter, Location loc,
4366 Value sgpr2, ArrayRef<Value> consts)
const {
4367 constexpr int32_t dim = 4;
4368 constexpr int32_t offset = 80;
4369 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
4372 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
4373 ConversionPatternRewriter &rewriter, Location loc,
4374 ArrayRef<Value> consts)
const {
4375 if constexpr (DescriptorOp::isGather())
4376 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
4377 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
4380 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
4381 ConversionPatternRewriter &rewriter, Location loc,
4382 ArrayRef<Value> consts)
const {
4383 IntegerType i32 = rewriter.getI32Type();
4384 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
4385 assert(v4i32 &&
"expected type conversion to succeed.");
4386 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
4387 if (onlyNeedsTwoDescriptors)
4388 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
4390 constexpr int32_t sgprlen = 4;
4391 Value sgprs[sgprlen];
4392 for (
int i = 0; i < sgprlen; ++i)
4393 sgprs[i] = consts[0];
4395 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
4396 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
4397 std::tie(sgprs[1], sgprs[2]) =
4398 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
4399 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
4401 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
4402 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
4404 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
4409 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
4410 ConversionPatternRewriter &rewriter, Location loc,
4411 ArrayRef<Value> consts)
const {
4412 return getGatherIndices(op, adaptor, rewriter, loc, consts,
false);
4416 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
4417 ConversionPatternRewriter &rewriter)
const override {
4419 return op->emitOpError(
4420 "make_dma_descriptor is only supported on gfx1250");
4422 Location loc = op.getLoc();
4424 SmallVector<Value> consts;
4425 for (int64_t i = 0; i < 8; ++i)
4428 Value dgroup0 = this->getDGroup0(adaptor);
4429 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
4430 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
4431 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
4432 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
4433 rewriter.replaceOpWithMultiple(op, {results});
4438template <
typename SourceOp,
typename TargetOp>
4439struct AMDGPUTensorLoadStoreOpLowering
4440 :
public ConvertOpToLLVMPattern<SourceOp> {
4441 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
4443 AMDGPUTensorLoadStoreOpLowering(
const LLVMTypeConverter &converter,
4445 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
4449 matchAndRewrite(SourceOp op, Adaptor adaptor,
4450 ConversionPatternRewriter &rewriter)
const override {
4452 return op->emitOpError(
"is only supported on gfx1250");
4457 auto v8i32 = VectorType::get(8, rewriter.getI32Type());
4458 Value dgroup4 = LLVM::ZeroOp::create(rewriter, op.getLoc(), v8i32);
4459 Attribute cachePolicy = rewriter.getI32IntegerAttr(0);
4460 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
4461 desc[3], dgroup4, cachePolicy,
4469struct GlobalPrefetchOpLowering
4470 :
public ConvertOpToLLVMPattern<GlobalPrefetchOp> {
4471 GlobalPrefetchOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
4472 : ConvertOpToLLVMPattern<GlobalPrefetchOp>(converter), chipset(chipset) {}
4475 matchAndRewrite(GlobalPrefetchOp op, GlobalPrefetchOpAdaptor adaptor,
4476 ConversionPatternRewriter &rewriter)
const override {
4478 return op->emitOpError(
"is only supported on gfx1250+");
4480 const bool isSpeculative = op.getSpeculative();
4482 op.getTemporalHint(), op.getCacheScope(), isSpeculative);
4485 Attribute cachePolicy = ROCDL::Gfx12CachePolicyAttr::get(
4486 rewriter.getContext(),
4487 static_cast<ROCDL::Gfx12CachePolicy
>(immArgValue));
4490 Value memRef = adaptor.getSrc();
4491 MemRefDescriptor descriptor(memRef);
4492 MemRefType memRefType = op.getSrc().getType();
4493 Location loc = op->getLoc();
4494 auto inboundsFlags = isSpeculative ? LLVM::GEPNoWrapFlags::none
4495 : LLVM::GEPNoWrapFlags::inbounds |
4496 LLVM::GEPNoWrapFlags::nuw;
4498 rewriter, loc, memRefType, descriptor,
indices, inboundsFlags);
4500 rewriter.replaceOpWithNewOp<ROCDL::GlobalPrefetchOp>(
4501 op, prefetchPtr, cachePolicy, mlir::ArrayAttr{}, mlir::ArrayAttr{},
4510struct ConvertAMDGPUToROCDLPass
4511 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
4514 void runOnOperation()
override {
4517 if (
failed(maybeChipset)) {
4518 emitError(UnknownLoc::get(ctx),
"Invalid chipset name: " + chipset);
4519 return signalPassFailure();
4522 RewritePatternSet patterns(ctx);
4523 LLVMTypeConverter converter(ctx);
4526 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
4528 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
4529 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
4530 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
4531 if (
failed(applyPartialConversion(getOperation(),
target,
4532 std::move(patterns))))
4533 signalPassFailure();
4541 typeConverter, [](gpu::AddressSpace space) {
4543 case gpu::AddressSpace::Global:
4544 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
4545 case gpu::AddressSpace::Workgroup:
4546 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
4547 case gpu::AddressSpace::Private:
4548 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
4549 case gpu::AddressSpace::Constant:
4550 return ROCDL::ROCDLDialect::kConstantMemoryAddressSpace;
4552 llvm_unreachable(
"unknown address space enum value");
4555 return LLVM::LLVMPointerType::get(
4556 type.getContext(), ROCDL::ROCDLDialect::kSharedMemoryAddressSpace);
4562 typeConverter.addTypeAttributeConversion(
4564 -> TypeConverter::AttributeConversionResult {
4566 Type i64 = IntegerType::get(ctx, 64);
4567 switch (as.getValue()) {
4568 case amdgpu::AddressSpace::FatRawBuffer:
4569 return IntegerAttr::get(i64, 7);
4570 case amdgpu::AddressSpace::BufferRsrc:
4571 return IntegerAttr::get(i64, 8);
4572 case amdgpu::AddressSpace::FatStructuredBuffer:
4573 return IntegerAttr::get(i64, 9);
4575 return TypeConverter::AttributeConversionResult::abort();
4577 typeConverter.addConversion([&](DsBarrierStateType type) ->
Type {
4578 return IntegerType::get(type.
getContext(), 64);
4580 typeConverter.addConversion([&](TDMBaseType type) ->
Type {
4582 return typeConverter.convertType(VectorType::get(4, i32));
4584 typeConverter.addConversion([&](TDMGatherBaseType type) ->
Type {
4586 return typeConverter.convertType(VectorType::get(4, i32));
4588 typeConverter.addConversion(
4589 [&](TDMDescriptorType type,
4592 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
4593 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
4594 llvm::append_values(
result, v4i32, v8i32, v4i32, v4i32);
4604 if (inputs.size() != 1)
4607 if (!isa<TDMDescriptorType>(inputs[0].
getType()))
4610 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
4611 return cast.getResults();
4614 typeConverter.addTargetMaterialization(addUnrealizedCast);
4622 .
add<FatRawBufferCastLowering,
4623 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
4624 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
4625 RawBufferOpLowering<RawBufferAtomicFaddOp,
4626 ROCDL::RawPtrBufferAtomicFaddOp>,
4627 RawBufferOpLowering<RawBufferAtomicFmaxOp,
4628 ROCDL::RawPtrBufferAtomicFmaxOp>,
4629 RawBufferOpLowering<RawBufferAtomicSmaxOp,
4630 ROCDL::RawPtrBufferAtomicSmaxOp>,
4631 RawBufferOpLowering<RawBufferAtomicUminOp,
4632 ROCDL::RawPtrBufferAtomicUminOp>,
4633 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
4634 ROCDL::RawPtrBufferAtomicCmpSwap>,
4635 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
4636 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
4637 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
4638 SparseWMMAOpLowering, DotOpLowering, ExtPackedFp8OpLowering,
4639 ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering,
4640 PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
4641 PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
4642 GlobalLoadAsyncToLDSOpLowering, TransposeLoadOpLowering,
4643 GlobalTransposeLoadOpLowering, AMDGPUPermlaneLowering,
4644 AMDGPUPermlaneVarLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
4645 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
4646 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
4647 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
4648 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
4649 ROCDL::TensorLoadToLDSOp>,
4650 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
4651 ROCDL::TensorStoreFromLDSOp>,
4652 DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
4653 DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering,
4654 GlobalPrefetchOpLowering>(converter, chipset);
4655 patterns.
add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
4656 DsBarrierStatePendingCountOpLowering,
4657 DsBarrierStateInitCountOpLowering,
4658 DsBarrierStatePhaseParityLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static bool hasDot10Insts(const Chipset &chipset)
static bool hasDot7Insts(const Chipset &chipset)
static std::optional< SparseWMMAOpInfo > sparseWMMAOpToIntrinsic(SparseWMMAOp swmmac, Chipset chipset)
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
static std::optional< ScaledMFMAIntrinsic > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static bool hasDot11Insts(const Chipset &chipset)
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static bool hasDot9Insts(const Chipset &chipset)
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
constexpr Chipset kGfx1200
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< ROCDL::WMMAMatrixScaleFormat > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value convertPackedVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts packed vector operands to the expected ROCDL types.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static bool hasDot8Insts(const Chipset &chipset)
static bool hasDot2Insts(const Chipset &chipset)
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static std::optional< ROCDL::MatrixFormat > smallFloatTypeToMatrixFormat(Type mlirElemType)
std::tuple< StringRef, ROCDL::MatrixFormat, ROCDL::MatrixFormat > ScaledMFMAIntrinsic
If there is a scaled MFMA instruction for the input element types aType and bType,...
static bool hasDot12Insts(const Chipset &chipset)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
static bool hasDot1Insts(const Chipset &chipset)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
typename SourceOp::Adaptor OpAdaptor
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
int32_t getGlobalPrefetchLLVMEncoding(amdgpu::LoadTemporalHint hint, amdgpu::Scope scope, bool isSpeculative)
bool hasOcpFp8(const Chipset &chipset)
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Returns the rocdl intrinsic corresponding to a SparseWMMA operation swmmac if one exists.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.