29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Casting.h"
32#include "llvm/Support/ErrorHandling.h"
36#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDLPASS
37#include "mlir/Conversion/Passes.h.inc"
53 IntegerType i32 = rewriter.getI32Type();
55 auto valTy = cast<IntegerType>(val.
getType());
58 return valTy.getWidth() > 32
59 ?
Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
60 :
Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
65 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), value);
71 IntegerType i64 = rewriter.getI64Type();
73 auto valTy = cast<IntegerType>(val.
getType());
76 return valTy.getWidth() > 64
77 ?
Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
78 :
Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
83 return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
90 IntegerType i32 = rewriter.getI32Type();
92 for (
auto [i, increment, stride] : llvm::enumerate(
indices, strides)) {
95 ShapedType::isDynamic(stride)
97 memRefDescriptor.
stride(rewriter, loc, i))
98 : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
99 increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
111 MemRefType memrefType,
115 if (chipset >=
kGfx1250 && !boundsCheck) {
116 constexpr int64_t first45bits = (1ll << 45) - 1;
119 if (memrefType.hasStaticShape() &&
120 !llvm::any_of(strides, ShapedType::isDynamic)) {
121 int64_t size = memrefType.getRank() == 0 ? 1 : 0;
123 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
124 size = std::max(
shape[i] * strides[i], size);
125 size = size * elementByteWidth;
129 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
130 Value size = memrefDescriptor.
size(rewriter, loc, i);
131 Value stride = memrefDescriptor.
stride(rewriter, loc, i);
132 Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
134 ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
139 return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
145 Value cacheSwizzleStride =
nullptr,
146 unsigned addressSpace = 8) {
150 Type i16 = rewriter.getI16Type();
153 Value cacheStrideZext =
154 LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
155 Value swizzleBit = LLVM::ConstantOp::create(
156 rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
157 stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
160 stride = LLVM::ConstantOp::create(rewriter, loc, i16,
161 rewriter.getI16IntegerAttr(0));
190 flags |= (7 << 12) | (4 << 15);
193 uint32_t oob = boundsCheck ? 3 : 2;
194 flags |= (oob << 28);
199 LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
200 Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
201 loc, rsrcType, basePointer, stride, numRecords, flagsConst);
206struct FatRawBufferCastLowering
208 FatRawBufferCastLowering(
const LLVMTypeConverter &converter, Chipset chipset)
209 : ConvertOpToLLVMPattern<FatRawBufferCastOp>(converter),
215 matchAndRewrite(FatRawBufferCastOp op, FatRawBufferCastOpAdaptor adaptor,
216 ConversionPatternRewriter &rewriter)
const override {
217 Location loc = op.getLoc();
218 Value memRef = adaptor.getSource();
219 Value unconvertedMemref = op.getSource();
220 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
221 MemRefDescriptor descriptor(memRef);
223 DataLayout dataLayout = DataLayout::closest(op);
224 int64_t elementByteWidth =
227 int64_t unusedOffset = 0;
228 SmallVector<int64_t, 5> strideVals;
229 if (
failed(memrefType.getStridesAndOffset(strideVals, unusedOffset)))
230 return op.emitOpError(
"Can't lower non-stride-offset memrefs");
232 Value numRecords = adaptor.getValidBytes();
235 getNumRecords(rewriter, loc, memrefType, descriptor, strideVals,
236 elementByteWidth, chipset, adaptor.getBoundsCheck());
239 adaptor.getResetOffset()
240 ? descriptor.bufferPtr(rewriter, loc, *getTypeConverter(),
242 : descriptor.alignedPtr(rewriter, loc);
244 Value offset = adaptor.getResetOffset()
245 ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
246 rewriter.getIndexAttr(0))
247 : descriptor.offset(rewriter, loc);
249 bool hasSizes = memrefType.getRank() > 0;
252 Value sizes = hasSizes
253 ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
257 hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
262 rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
263 chipset, adaptor.getCacheSwizzleStride(), 7);
265 Value
result = MemRefDescriptor::poison(
267 getTypeConverter()->convertType(op.getResult().getType()));
269 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr, pos);
270 result = LLVM::InsertValueOp::create(rewriter, loc,
result, fatPtr,
272 result = LLVM::InsertValueOp::create(rewriter, loc,
result, offset,
275 result = LLVM::InsertValueOp::create(rewriter, loc,
result, sizes,
277 result = LLVM::InsertValueOp::create(rewriter, loc,
result, strides,
280 rewriter.replaceOp(op,
result);
286template <
typename GpuOp,
typename Intrinsic>
288 RawBufferOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
289 : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}
292 static constexpr uint32_t maxVectorOpWidth = 128;
295 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
296 ConversionPatternRewriter &rewriter)
const override {
297 Location loc = gpuOp.getLoc();
298 Value memref = adaptor.getMemref();
299 Value unconvertedMemref = gpuOp.getMemref();
300 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
302 if (chipset.majorVersion < 9)
303 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
305 Value storeData = adaptor.getODSOperands(0)[0];
306 if (storeData == memref)
310 wantedDataType = storeData.
getType();
312 wantedDataType = gpuOp.getODSResults(0)[0].getType();
314 Value atomicCmpData = Value();
317 Value maybeCmpData = adaptor.getODSOperands(1)[0];
318 if (maybeCmpData != memref)
319 atomicCmpData = maybeCmpData;
322 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
324 Type i32 = rewriter.getI32Type();
327 DataLayout dataLayout = DataLayout::closest(gpuOp);
328 int64_t elementByteWidth =
337 Type llvmBufferValType = llvmWantedDataType;
339 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
340 llvmBufferValType = this->getTypeConverter()->convertType(
341 rewriter.getIntegerType(floatType.getWidth()));
343 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
344 uint32_t vecLen = dataVector.getNumElements();
347 uint32_t totalBits = elemBits * vecLen;
349 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
350 if (totalBits > maxVectorOpWidth)
351 return gpuOp.emitOpError(
352 "Total width of loads or stores must be no more than " +
353 Twine(maxVectorOpWidth) +
" bits, but we call for " +
355 " bits. This should've been caught in validation");
356 if (!usePackedFp16 && elemBits < 32) {
357 if (totalBits > 32) {
358 if (totalBits % 32 != 0)
359 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
360 "doesn't fit into words. Can't happen\n");
361 llvmBufferValType = this->typeConverter->convertType(
362 VectorType::get(totalBits / 32, i32));
364 llvmBufferValType = this->typeConverter->convertType(
365 rewriter.getIntegerType(totalBits));
369 if (
auto vecType = dyn_cast<VectorType>(llvmBufferValType)) {
372 if (vecType.getNumElements() == 1)
373 llvmBufferValType = vecType.getElementType();
376 SmallVector<Value, 6> args;
378 if (llvmBufferValType != llvmWantedDataType) {
379 Value castForStore = LLVM::BitcastOp::create(
380 rewriter, loc, llvmBufferValType, storeData);
381 args.push_back(castForStore);
383 args.push_back(storeData);
388 if (llvmBufferValType != llvmWantedDataType) {
389 Value castForCmp = LLVM::BitcastOp::create(
390 rewriter, loc, llvmBufferValType, atomicCmpData);
391 args.push_back(castForCmp);
393 args.push_back(atomicCmpData);
399 SmallVector<int64_t, 5> strides;
400 if (
failed(memrefType.getStridesAndOffset(strides, offset)))
401 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
403 MemRefDescriptor memrefDescriptor(memref);
405 Value ptr = memrefDescriptor.bufferPtr(
406 rewriter, loc, *this->getTypeConverter(), memrefType);
408 getNumRecords(rewriter, loc, memrefType, memrefDescriptor, strides,
409 elementByteWidth, chipset, adaptor.getBoundsCheck());
411 adaptor.getBoundsCheck(), chipset);
412 args.push_back(resource);
416 adaptor.getIndices(), strides);
417 if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
418 indexOffset && *indexOffset > 0) {
420 voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
424 voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
425 args.push_back(voffset);
428 Value sgprOffset = adaptor.getSgprOffset();
431 sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
432 args.push_back(sgprOffset);
439 llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
441 Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
442 ArrayRef<NamedAttribute>());
445 if (llvmBufferValType != llvmWantedDataType) {
446 replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
451 rewriter.eraseOp(gpuOp);
468static FailureOr<unsigned> encodeWaitcnt(
Chipset chipset,
unsigned vmcnt,
469 unsigned expcnt,
unsigned lgkmcnt) {
471 vmcnt = std::min(15u, vmcnt);
472 expcnt = std::min(7u, expcnt);
473 lgkmcnt = std::min(15u, lgkmcnt);
474 return vmcnt | (expcnt << 4) | (lgkmcnt << 8);
477 vmcnt = std::min(63u, vmcnt);
478 expcnt = std::min(7u, expcnt);
479 lgkmcnt = std::min(15u, lgkmcnt);
480 unsigned lowBits = vmcnt & 0xF;
481 unsigned highBits = (vmcnt >> 4) << 14;
482 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
483 return lowBits | highBits | otherCnts;
486 vmcnt = std::min(63u, vmcnt);
487 expcnt = std::min(7u, expcnt);
488 lgkmcnt = std::min(63u, lgkmcnt);
489 unsigned lowBits = vmcnt & 0xF;
490 unsigned highBits = (vmcnt >> 4) << 14;
491 unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8);
492 return lowBits | highBits | otherCnts;
495 vmcnt = std::min(63u, vmcnt);
496 expcnt = std::min(7u, expcnt);
497 lgkmcnt = std::min(63u, lgkmcnt);
498 return (vmcnt << 10) | expcnt | (lgkmcnt << 4);
503struct MemoryCounterWaitOpLowering
505 MemoryCounterWaitOpLowering(
const LLVMTypeConverter &converter,
507 : ConvertOpToLLVMPattern<MemoryCounterWaitOp>(converter),
513 matchAndRewrite(MemoryCounterWaitOp op, OpAdaptor adaptor,
514 ConversionPatternRewriter &rewriter)
const override {
515 if (chipset.majorVersion >= 12) {
516 Location loc = op.getLoc();
517 if (std::optional<int> ds = adaptor.getDs())
518 ROCDL::WaitDscntOp::create(rewriter, loc, *ds);
520 if (std::optional<int>
load = adaptor.getLoad())
521 ROCDL::WaitLoadcntOp::create(rewriter, loc, *
load);
523 if (std::optional<int> store = adaptor.getStore())
524 ROCDL::WaitStorecntOp::create(rewriter, loc, *store);
526 if (std::optional<int> exp = adaptor.getExp())
527 ROCDL::WaitExpcntOp::create(rewriter, loc, *exp);
529 if (std::optional<int> tensor = adaptor.getTensor())
530 ROCDL::WaitTensorcntOp::create(rewriter, loc, *tensor);
532 rewriter.eraseOp(op);
536 if (adaptor.getTensor())
537 return op.emitOpError(
"unsupported chipset");
539 auto getVal = [](Attribute attr) ->
unsigned {
541 return cast<IntegerAttr>(attr).getInt();
546 unsigned ds = getVal(adaptor.getDsAttr());
547 unsigned exp = getVal(adaptor.getExpAttr());
549 unsigned vmcnt = 1024;
550 Attribute
load = adaptor.getLoadAttr();
551 Attribute store = adaptor.getStoreAttr();
553 vmcnt = getVal(
load) + getVal(store);
555 vmcnt = getVal(
load);
557 vmcnt = getVal(store);
560 FailureOr<unsigned> waitcnt = encodeWaitcnt(chipset, vmcnt, exp, ds);
562 return op.emitOpError(
"unsupported chipset");
564 rewriter.replaceOpWithNewOp<ROCDL::SWaitcntOp>(op, *waitcnt);
570 LDSBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
571 : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}
576 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
577 ConversionPatternRewriter &rewriter)
const override {
578 Location loc = op.getLoc();
581 bool requiresInlineAsm = chipset <
kGfx90a;
584 rewriter.getAttr<LLVM::MMRATagAttr>(
"amdgpu-synchronize-as",
"local");
593 StringRef scope =
"workgroup";
595 auto relFence = LLVM::FenceOp::create(rewriter, loc,
596 LLVM::AtomicOrdering::release, scope);
597 relFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
598 if (requiresInlineAsm) {
599 auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
600 LLVM::AsmDialect::AD_ATT);
601 const char *asmStr =
";;;WARNING: BREAKS DEBUG WATCHES\ns_barrier";
602 const char *constraints =
"";
603 LLVM::InlineAsmOp::create(
606 asmStr, constraints,
true,
607 false, LLVM::TailCallKind::None,
610 }
else if (chipset.majorVersion < 12) {
611 ROCDL::SBarrierOp::create(rewriter, loc);
613 ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
614 ROCDL::BarrierWaitOp::create(rewriter, loc, -1);
617 auto acqFence = LLVM::FenceOp::create(rewriter, loc,
618 LLVM::AtomicOrdering::acquire, scope);
619 acqFence->setDiscardableAttr(LLVM::LLVMDialect::getMmraAttrName(), mmra);
620 rewriter.replaceOp(op, acqFence);
626 SchedBarrierOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
627 : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}
632 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
633 ConversionPatternRewriter &rewriter)
const override {
634 rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
635 (uint32_t)op.getOpts());
659 bool allowBf16 =
true) {
661 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
662 if (vectorType.getElementType().isBF16() && !allowBf16)
663 return LLVM::BitcastOp::create(
664 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
665 if (vectorType.getElementType().isInteger(8) &&
666 vectorType.getNumElements() <= 8)
667 return LLVM::BitcastOp::create(
669 rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
670 if (isa<IntegerType>(vectorType.getElementType()) &&
671 vectorType.getElementTypeBitWidth() <= 8) {
672 int64_t numWords = llvm::divideCeil(
673 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
675 return LLVM::BitcastOp::create(
676 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
686 bool allowBf16 =
true) {
688 auto vectorType = cast<VectorType>(inputType);
690 if (vectorType.getElementType().isBF16() && !allowBf16)
691 return LLVM::BitcastOp::create(
692 rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
694 if (isa<IntegerType>(vectorType.getElementType()) &&
695 vectorType.getElementTypeBitWidth() <= 8) {
696 int64_t numWords = llvm::divideCeil(
697 vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32);
698 return LLVM::BitcastOp::create(
699 rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input);
717 .Case([&](IntegerType) {
719 return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
722 .Case([&](VectorType vectorType) {
724 int64_t numElements = vectorType.getNumElements();
725 assert((numElements == 4 || numElements == 8) &&
726 "scale operand must be a vector of length 4 or 8");
727 IntegerType outputType =
728 (numElements == 4) ? rewriter.getI32Type() : rewriter.getI64Type();
729 return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
731 .DefaultUnreachable(
"unexpected input type for scale operand");
737 .Case([](Float8E8M0FNUType) {
return 0; })
738 .Case([](Float8E4M3FNType) {
return 2; })
739 .Default(std::nullopt);
744static std::optional<StringRef>
746 if (m == 16 && n == 16 && k == 128)
748 ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
749 : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
751 if (m == 32 && n == 16 && k == 128)
752 return isScale16 ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
753 : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
767 ConversionPatternRewriter &rewriter,
Location loc,
772 auto vectorType = dyn_cast<VectorType>(inputType);
774 operands.push_back(llvmInput);
777 Type elemType = vectorType.getElementType();
779 operands.push_back(llvmInput);
786 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
787 bool isInputInteger = mlirInputType.getElementType().isInteger();
788 if (isInputInteger) {
790 bool localIsUnsigned = isUnsigned;
792 localIsUnsigned =
true;
794 localIsUnsigned =
false;
797 NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
802 Type i32 = rewriter.getI32Type();
803 Type intrinsicInType = numBits <= 32
804 ? (
Type)rewriter.getIntegerType(numBits)
805 : (
Type)VectorType::get(numBits / 32, i32);
806 auto llvmIntrinsicInType = typeConverter->convertType(intrinsicInType);
807 Value castInput = rewriter.createOrFold<LLVM::BitcastOp>(
808 loc, llvmIntrinsicInType, llvmInput);
813 castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
814 operands.push_back(castInput);
827 Value output, int32_t subwordOffset,
831 auto vectorType = dyn_cast<VectorType>(inputType);
832 Type elemType = vectorType.getElementType();
833 operands.push_back(output);
845 return (chipset ==
kGfx942 && isa<Float8E5M2FNUZType>(type)) ||
846 (
hasOcpFp8(chipset) && isa<Float8E5M2Type>(type));
852 return (chipset ==
kGfx942 && isa<Float8E4M3FNUZType>(type)) ||
853 (
hasOcpFp8(chipset) && isa<Float8E4M3FNType>(type));
861 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
862 b = mfma.getBlocks();
867 if (mfma.getReducePrecision() && chipset >=
kGfx942) {
868 if (m == 32 && n == 32 && k == 4 &&
b == 1)
869 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
870 if (m == 16 && n == 16 && k == 8 &&
b == 1)
871 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
873 if (m == 32 && n == 32 && k == 1 &&
b == 2)
874 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
875 if (m == 16 && n == 16 && k == 1 &&
b == 4)
876 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
877 if (m == 4 && n == 4 && k == 1 &&
b == 16)
878 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
879 if (m == 32 && n == 32 && k == 2 &&
b == 1)
880 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
881 if (m == 16 && n == 16 && k == 4 &&
b == 1)
882 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
887 if (m == 32 && n == 32 && k == 16 &&
b == 1)
888 return ROCDL::mfma_f32_32x32x16_f16::getOperationName();
889 if (m == 16 && n == 16 && k == 32 &&
b == 1)
890 return ROCDL::mfma_f32_16x16x32_f16::getOperationName();
892 if (m == 32 && n == 32 && k == 4 &&
b == 2)
893 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
894 if (m == 16 && n == 16 && k == 4 &&
b == 4)
895 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
896 if (m == 4 && n == 4 && k == 4 &&
b == 16)
897 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
898 if (m == 32 && n == 32 && k == 8 &&
b == 1)
899 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
900 if (m == 16 && n == 16 && k == 16 &&
b == 1)
901 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
906 if (m == 32 && n == 32 && k == 16 &&
b == 1)
907 return ROCDL::mfma_f32_32x32x16_bf16::getOperationName();
908 if (m == 16 && n == 16 && k == 32 &&
b == 1)
909 return ROCDL::mfma_f32_16x16x32_bf16::getOperationName();
912 if (m == 32 && n == 32 && k == 4 &&
b == 2)
913 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
914 if (m == 16 && n == 16 && k == 4 &&
b == 4)
915 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
916 if (m == 4 && n == 4 && k == 4 &&
b == 16)
917 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
918 if (m == 32 && n == 32 && k == 8 &&
b == 1)
919 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
920 if (m == 16 && n == 16 && k == 16 &&
b == 1)
921 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
923 if (m == 32 && n == 32 && k == 2 &&
b == 2)
924 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
925 if (m == 16 && n == 16 && k == 2 &&
b == 4)
926 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
927 if (m == 4 && n == 4 && k == 2 &&
b == 16)
928 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
929 if (m == 32 && n == 32 && k == 4 &&
b == 1)
930 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
931 if (m == 16 && n == 16 && k == 8 &&
b == 1)
932 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
937 if (m == 32 && n == 32 && k == 32 &&
b == 1)
938 return ROCDL::mfma_i32_32x32x32_i8::getOperationName();
939 if (m == 16 && n == 16 && k == 64 &&
b == 1)
940 return ROCDL::mfma_i32_16x16x64_i8::getOperationName();
942 if (m == 32 && n == 32 && k == 4 &&
b == 2)
943 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
944 if (m == 16 && n == 16 && k == 4 &&
b == 4)
945 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
946 if (m == 4 && n == 4 && k == 4 &&
b == 16)
947 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
948 if (m == 32 && n == 32 && k == 8 &&
b == 1)
949 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
950 if (m == 16 && n == 16 && k == 16 &&
b == 1)
951 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
952 if (m == 32 && n == 32 && k == 16 &&
b == 1 && chipset >=
kGfx942)
953 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
954 if (m == 16 && n == 16 && k == 32 &&
b == 1 && chipset >=
kGfx942)
955 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
959 if (m == 16 && n == 16 && k == 4 &&
b == 1)
960 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
961 if (m == 4 && n == 4 && k == 4 &&
b == 4)
962 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
969 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
970 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
972 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
974 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
976 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
978 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
980 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
986 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
987 if (m == 16 && n == 16 && k == 32 &&
b == 1) {
989 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
991 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
993 if (m == 32 && n == 32 && k == 16 &&
b == 1) {
995 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
997 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
1001 return std::nullopt;
1006 .Case([](Float8E4M3FNType) {
return 0u; })
1007 .Case([](Float8E5M2Type) {
return 1u; })
1008 .Case([](Float6E2M3FNType) {
return 2u; })
1009 .Case([](Float6E3M2FNType) {
return 3u; })
1010 .Case([](Float4E2M1FNType) {
return 4u; })
1011 .Default(std::nullopt);
1021static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1023 uint32_t n, uint32_t k, uint32_t
b,
Chipset chipset) {
1029 return std::nullopt;
1030 if (!isa<Float32Type>(destType))
1031 return std::nullopt;
1035 if (!aTypeCode || !bTypeCode)
1036 return std::nullopt;
1038 if (m == 32 && n == 32 && k == 64 &&
b == 1)
1039 return std::tuple{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
1040 *aTypeCode, *bTypeCode};
1041 if (m == 16 && n == 16 && k == 128 &&
b == 1)
1043 ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), *aTypeCode,
1046 return std::nullopt;
1049static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1052 mfma.getSourceA().getType(), mfma.getSourceB().getType(),
1053 mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
1054 mfma.getBlocks(), chipset);
1057static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1060 smfma.getSourceB().getType(),
1061 smfma.getDestC().getType(), smfma.getM(),
1062 smfma.getN(), smfma.getK(), 1u, chipset);
1067static std::optional<StringRef>
1069 Type elemDestType, uint32_t k,
bool isRDNA3) {
1070 using fp8 = Float8E4M3FNType;
1071 using bf8 = Float8E5M2Type;
1076 if (elemSourceType.
isF16() && elemDestType.
isF32())
1077 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
1078 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1079 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
1080 if (elemSourceType.
isF16() && elemDestType.
isF16())
1081 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
1083 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
1085 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1090 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1091 return std::nullopt;
1095 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1096 elemDestType.
isF32())
1097 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
1098 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1099 elemDestType.
isF32())
1100 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
1101 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
1102 elemDestType.
isF32())
1103 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
1104 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1105 elemDestType.
isF32())
1106 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1108 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1110 return std::nullopt;
1114 if (k == 32 && !isRDNA3) {
1116 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1119 return std::nullopt;
1125 Type elemBSourceType,
1128 using fp8 = Float8E4M3FNType;
1129 using bf8 = Float8E5M2Type;
1132 if (elemSourceType.
isF32() && elemDestType.
isF32())
1133 return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1135 return std::nullopt;
1139 if (elemSourceType.
isF16() && elemDestType.
isF32())
1140 return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
1141 if (elemSourceType.
isBF16() && elemDestType.
isF32())
1142 return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
1143 if (elemSourceType.
isF16() && elemDestType.
isF16())
1144 return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
1146 return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
1148 return std::nullopt;
1152 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1153 if (elemDestType.
isF32())
1154 return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
1155 if (elemDestType.
isF16())
1156 return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
1158 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1159 if (elemDestType.
isF32())
1160 return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
1161 if (elemDestType.
isF16())
1162 return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
1164 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1165 if (elemDestType.
isF32())
1166 return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
1167 if (elemDestType.
isF16())
1168 return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
1170 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1171 if (elemDestType.
isF32())
1172 return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
1173 if (elemDestType.
isF16())
1174 return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
1177 return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
1179 return std::nullopt;
1183 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1184 if (elemDestType.
isF32())
1185 return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
1186 if (elemDestType.
isF16())
1187 return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
1189 if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1190 if (elemDestType.
isF32())
1191 return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
1192 if (elemDestType.
isF16())
1193 return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
1195 if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1196 if (elemDestType.
isF32())
1197 return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
1198 if (elemDestType.
isF16())
1199 return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
1201 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1202 if (elemDestType.
isF32())
1203 return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
1204 if (elemDestType.
isF16())
1205 return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
1208 return std::nullopt;
1211 return std::nullopt;
1219 bool isGfx950 = chipset >=
kGfx950;
1223 uint32_t m = op.getM(), n = op.getN(), k = op.getK();
1228 if (m == 16 && n == 16 && k == 32) {
1230 return ROCDL::smfmac_f32_16x16x32_f16::getOperationName();
1232 return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName();
1235 if (m == 16 && n == 16 && k == 64) {
1238 return ROCDL::smfmac_f32_16x16x64_f16::getOperationName();
1240 return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName();
1244 return ROCDL::smfmac_i32_16x16x64_i8::getOperationName();
1245 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1246 return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName();
1247 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1248 return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName();
1249 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1250 return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName();
1251 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1252 return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName();
1255 if (m == 16 && n == 16 && k == 128 && isGfx950) {
1258 return ROCDL::smfmac_i32_16x16x128_i8::getOperationName();
1259 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1260 return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName();
1261 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1262 return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName();
1263 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1264 return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName();
1265 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1266 return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName();
1269 if (m == 32 && n == 32 && k == 16) {
1271 return ROCDL::smfmac_f32_32x32x16_f16::getOperationName();
1273 return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName();
1276 if (m == 32 && n == 32 && k == 32) {
1279 return ROCDL::smfmac_f32_32x32x32_f16::getOperationName();
1281 return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName();
1285 return ROCDL::smfmac_i32_32x32x32_i8::getOperationName();
1286 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1287 return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName();
1288 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1289 return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName();
1290 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1291 return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName();
1292 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1293 return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName();
1296 if (m == 32 && n == 32 && k == 64 && isGfx950) {
1299 return ROCDL::smfmac_i32_32x32x64_i8::getOperationName();
1300 if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1301 return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName();
1302 if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1303 return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName();
1304 if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.
isF32())
1305 return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName();
1306 if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.
isF32())
1307 return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName();
1310 return std::nullopt;
1318 auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1319 auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1320 auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1321 Type elemSourceType = sourceVectorType.getElementType();
1322 Type elemBSourceType = sourceBVectorType.getElementType();
1323 Type elemDestType = destVectorType.getElementType();
1325 const uint32_t k = wmma.getK();
1330 if (isRDNA3 || isRDNA4)
1339 return std::nullopt;
1344 MFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1345 : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
1350 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
1351 ConversionPatternRewriter &rewriter)
const override {
1352 Location loc = op.getLoc();
1353 Type outType = typeConverter->convertType(op.getDestD().getType());
1354 Type intrinsicOutType = outType;
1355 if (
auto outVecType = dyn_cast<VectorType>(outType))
1356 if (outVecType.getElementType().isBF16())
1357 intrinsicOutType = outVecType.clone(rewriter.getI16Type());
1359 if (chipset.majorVersion != 9 || chipset <
kGfx908)
1360 return op->emitOpError(
"MFMA only supported on gfx908+");
1361 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
1362 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
1364 return op.emitOpError(
"negation unsupported on older than gfx942");
1366 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
1369 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1371 if (!maybeIntrinsic.has_value() && !maybeScaledIntrinsic.has_value())
1372 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
1375 !maybeIntrinsic.has_value() && maybeScaledIntrinsic.has_value();
1377 (adaptor.getAbid() > 0 || getBlgpField > 0 || op.getCbsz() > 0)) {
1378 return op.emitOpError(
1379 "non-default abid, blgp, and cbsz aren't supported on MFMAs that can "
1380 "be scaled as those fields are used for type information");
1383 StringRef intrinsicName =
1384 isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
1387 bool allowBf16 = [&]() {
1392 return intrinsicName.contains(
"16x16x32.bf16") ||
1393 intrinsicName.contains(
"32x32x16.bf16");
1395 OperationState loweredOp(loc, intrinsicName);
1396 loweredOp.addTypes(intrinsicOutType);
1398 rewriter, loc, adaptor.getSourceA(), allowBf16),
1400 rewriter, loc, adaptor.getSourceB(), allowBf16),
1401 adaptor.getDestC()});
1404 auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1405 loweredOp.addOperands({zero, zero});
1406 loweredOp.addAttributes({{
"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1407 {
"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1408 {
"opselA", rewriter.getI32IntegerAttr(0)},
1409 {
"opselB", rewriter.getI32IntegerAttr(0)}});
1411 loweredOp.addAttributes(
1412 {{
"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1413 {
"abid", rewriter.getI32IntegerAttr(op.getAbid())},
1414 {
"blgp", rewriter.getI32IntegerAttr(getBlgpField)}});
1416 Value lowered = rewriter.create(loweredOp)->getResult(0);
1417 if (outType != intrinsicOutType)
1418 lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
1419 rewriter.replaceOp(op, lowered);
1425 ScaledMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1426 : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
1431 matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
1432 ConversionPatternRewriter &rewriter)
const override {
1433 Location loc = op.getLoc();
1434 Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
1436 if (chipset.majorVersion != 9 || chipset <
kGfx950)
1437 return op->emitOpError(
"scaled MFMA only supported on gfx908+");
1438 std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1440 if (!maybeScaledIntrinsic.has_value())
1441 return op.emitOpError(
1442 "no intrinsic matching scaled MFMA size on given chipset");
1444 auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1445 OperationState loweredOp(loc, intrinsicName);
1446 loweredOp.addTypes(intrinsicOutType);
1447 loweredOp.addOperands(
1450 adaptor.getDestC()});
1451 loweredOp.addOperands(
1456 loweredOp.addAttributes(
1457 {{
"cbsz", rewriter.getI32IntegerAttr(aTypeCode)},
1458 {
"blgp", rewriter.getI32IntegerAttr(bTypeCode)},
1459 {
"opselA", rewriter.getI32IntegerAttr(adaptor.getScalesIdxA())},
1460 {
"opselB", rewriter.getI32IntegerAttr(adaptor.getScalesIdxB())}});
1462 Value lowered = rewriter.create(loweredOp)->getResult(0);
1463 rewriter.replaceOp(op, lowered);
1469 SparseMFMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1470 : ConvertOpToLLVMPattern<SparseMFMAOp>(converter), chipset(chipset) {}
1475 matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor,
1476 ConversionPatternRewriter &rewriter)
const override {
1477 Location loc = op.getLoc();
1479 typeConverter->convertType<VectorType>(op.getDestC().
getType());
1481 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1484 if (chipset.majorVersion != 9 || chipset <
kGfx942)
1485 return op->emitOpError(
"sparse MFMA (smfmac) only supported on gfx942+");
1486 bool isGfx950 = chipset >=
kGfx950;
1489 adaptor.getSourceA(), isGfx950);
1491 adaptor.getSourceB(), isGfx950);
1492 Value c = adaptor.getDestC();
1495 if (!maybeIntrinsic.has_value())
1496 return op.emitOpError(
1497 "no intrinsic matching sparse MFMA on the given chipset");
1500 Value sparseIdx = LLVM::BitcastOp::create(
1501 rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx());
1503 OperationState loweredOp(loc, maybeIntrinsic.value());
1504 loweredOp.addTypes(outType);
1505 loweredOp.addOperands({a,
b, c, sparseIdx});
1506 loweredOp.addAttributes(
1507 {{
"cbsz", rewriter.getI32IntegerAttr(op.getCbsz())},
1508 {
"abid", rewriter.getI32IntegerAttr(op.getAbid())}});
1509 Value lowered = rewriter.create(loweredOp)->getResult(0);
1510 rewriter.replaceOp(op, lowered);
1516 WMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1517 : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
1522 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
1523 ConversionPatternRewriter &rewriter)
const override {
1524 Location loc = op.getLoc();
1526 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1528 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1530 if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
1531 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
1533 bool isGFX1250 = chipset >=
kGfx1250;
1538 auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1539 auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1540 auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1541 bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1542 bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1543 bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1544 bool castOutToI16 = outType.getElementType().
isBF16() && !isGFX1250;
1545 VectorType rawOutType = outType;
1547 rawOutType = outType.clone(rewriter.getI16Type());
1548 Value a = adaptor.getSourceA();
1550 a = LLVM::BitcastOp::create(rewriter, loc,
1551 aType.clone(rewriter.getI16Type()), a);
1552 Value
b = adaptor.getSourceB();
1554 b = LLVM::BitcastOp::create(rewriter, loc,
1555 bType.clone(rewriter.getI16Type()),
b);
1556 Value destC = adaptor.getDestC();
1558 destC = LLVM::BitcastOp::create(
1559 rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
1563 if (!maybeIntrinsic.has_value())
1564 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
1566 if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
1567 return op.emitOpError(
"subwordOffset not supported on gfx12+");
1569 SmallVector<Value, 4> operands;
1570 SmallVector<NamedAttribute, 4> attrs;
1572 op.getSourceA(), operands, attrs,
"signA");
1574 op.getSourceB(), operands, attrs,
"signB");
1576 op.getSubwordOffset(), op.getClamp(), operands,
1579 OperationState loweredOp(loc, *maybeIntrinsic);
1580 loweredOp.addTypes(rawOutType);
1581 loweredOp.addOperands(operands);
1582 loweredOp.addAttributes(attrs);
1583 Operation *lowered = rewriter.create(loweredOp);
1585 Operation *maybeCastBack = lowered;
1586 if (rawOutType != outType)
1587 maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
1589 rewriter.replaceOp(op, maybeCastBack->
getResults());
1596 ScaledWMMAOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1597 : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
1602 matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
1603 ConversionPatternRewriter &rewriter)
const override {
1604 Location loc = op.getLoc();
1606 typeConverter->convertType<VectorType>(op.getDestD().
getType());
1608 return rewriter.notifyMatchFailure(op,
"type conversion failed");
1611 return op->emitOpError(
"WMMA scale only supported on gfx1250+");
1613 int64_t m = op.getM();
1614 int64_t n = op.getN();
1615 int64_t k = op.getK();
1623 if (!aFmtCode || !bFmtCode)
1624 return op.emitOpError(
"unsupported element types for scaled_wmma");
1627 auto scaleAVecType = cast<VectorType>(op.getScaleA().getType());
1628 auto scaleBVecType = cast<VectorType>(op.getScaleB().getType());
1630 if (scaleAVecType.getNumElements() != scaleBVecType.getNumElements())
1631 return op.emitOpError(
"scaleA and scaleB must have equal vector length");
1634 Type scaleAElemType = scaleAVecType.getElementType();
1635 Type scaleBElemType = scaleBVecType.getElementType();
1640 if (!scaleAFmt || !scaleBFmt)
1641 return op.emitOpError(
"unsupported scale element types");
1644 bool isScale16 = (scaleAVecType.getNumElements() == 8);
1645 std::optional<StringRef> intrinsicName =
1648 return op.emitOpError(
"unsupported scaled_wmma dimensions: ")
1649 << m <<
"x" << n <<
"x" << k;
1651 SmallVector<NamedAttribute, 8> attrs;
1654 bool is32x16 = (m == 32 && n == 16 && k == 128);
1656 attrs.emplace_back(
"fmtA", rewriter.getI32IntegerAttr(*aFmtCode));
1657 attrs.emplace_back(
"fmtB", rewriter.getI32IntegerAttr(*bFmtCode));
1661 attrs.emplace_back(
"modC", rewriter.getI16IntegerAttr(0));
1666 "scaleAType", rewriter.getI32IntegerAttr(op.getAFirstScaleLane() / 16));
1667 attrs.emplace_back(
"fmtScaleA", rewriter.getI32IntegerAttr(*scaleAFmt));
1669 "scaleBType", rewriter.getI32IntegerAttr(op.getBFirstScaleLane() / 16));
1670 attrs.emplace_back(
"fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));
1673 attrs.emplace_back(
"reuseA", rewriter.getBoolAttr(
false));
1674 attrs.emplace_back(
"reuseB", rewriter.getBoolAttr(
false));
1687 OperationState loweredOp(loc, *intrinsicName);
1688 loweredOp.addTypes(outType);
1689 loweredOp.addOperands(
1690 {sourceA, sourceB, adaptor.getDestC(), packedScaleA, packedScaleB});
1691 loweredOp.addAttributes(attrs);
1693 Operation *lowered = rewriter.create(loweredOp);
1694 rewriter.replaceOp(op, lowered->
getResults());
1700struct TransposeLoadOpLowering
1702 TransposeLoadOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1703 : ConvertOpToLLVMPattern<TransposeLoadOp>(converter), chipset(chipset) {}
1708 matchAndRewrite(TransposeLoadOp op, TransposeLoadOpAdaptor adaptor,
1709 ConversionPatternRewriter &rewriter)
const override {
1711 return op.emitOpError(
"Non-gfx950 chipset not supported");
1713 Location loc = op.getLoc();
1714 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1718 size_t srcElementSize =
1719 srcMemRefType.getElementType().getIntOrFloatBitWidth();
1720 if (srcElementSize < 8)
1721 return op.emitOpError(
"Expect source memref to have at least 8 bits "
1722 "element size, got ")
1725 auto resultType = cast<VectorType>(op.getResult().getType());
1728 (adaptor.getSrcIndices()));
1730 size_t numElements = resultType.getNumElements();
1731 size_t elementTypeSize =
1732 resultType.getElementType().getIntOrFloatBitWidth();
1736 Type rocdlResultType = VectorType::get((numElements * elementTypeSize) / 32,
1737 rewriter.getIntegerType(32));
1738 Type llvmResultType = typeConverter->convertType(resultType);
1740 switch (elementTypeSize) {
1742 assert(numElements == 16);
1743 auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
1744 rocdlResultType, srcPtr);
1745 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1749 assert(numElements == 16);
1750 auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
1751 rocdlResultType, srcPtr);
1752 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1756 assert(numElements == 8);
1757 auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
1758 rocdlResultType, srcPtr);
1759 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
1763 assert(numElements == 4);
1764 rewriter.replaceOpWithNewOp<ROCDL::ds_read_tr16_b64>(op, llvmResultType,
1769 return op.emitOpError(
"Unsupported element size for transpose load");
1776 GatherToLDSOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1777 : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
1782 matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
1783 ConversionPatternRewriter &rewriter)
const override {
1784 if (chipset.majorVersion < 9 || chipset.majorVersion > 10)
1785 return op.emitOpError(
"pre-gfx9 and post-gfx10 not supported");
1787 Location loc = op.getLoc();
1789 auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
1790 auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
1795 Type transferType = op.getTransferType();
1796 int loadWidth = [&]() ->
int {
1797 if (
auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1798 return (transferVectorType.getNumElements() *
1799 transferVectorType.getElementTypeBitWidth()) /
1806 if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
1807 return op.emitOpError(
"chipset unsupported element size");
1809 if (chipset !=
kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1810 return op.emitOpError(
"Gather to LDS instructions with 12-byte and "
1811 "16-byte load widths are only supported on gfx950");
1815 (adaptor.getSrcIndices()));
1818 (adaptor.getDstIndices()));
1820 if (op.getAsync()) {
1821 rewriter.replaceOpWithNewOp<ROCDL::LoadAsyncToLDSOp>(
1822 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1823 rewriter.getI32IntegerAttr(0),
1827 rewriter.replaceOpWithNewOp<ROCDL::LoadToLDSOp>(
1828 op, srcPtr, dstPtr, rewriter.getI32IntegerAttr(loadWidth),
1829 rewriter.getI32IntegerAttr(0),
1839struct ExtPackedFp8OpLowering final
1841 ExtPackedFp8OpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1842 : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
1847 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1848 ConversionPatternRewriter &rewriter)
const override;
1851struct ScaledExtPackedMatrixOpLowering final
1853 ScaledExtPackedMatrixOpLowering(
const LLVMTypeConverter &converter,
1855 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
1860 matchAndRewrite(ScaledExtPackedMatrixOp op,
1861 ScaledExtPackedMatrixOpAdaptor adaptor,
1862 ConversionPatternRewriter &rewriter)
const override;
1865struct PackedTrunc2xFp8OpLowering final
1867 PackedTrunc2xFp8OpLowering(
const LLVMTypeConverter &converter,
1869 : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
1874 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
1875 ConversionPatternRewriter &rewriter)
const override;
1878struct PackedStochRoundFp8OpLowering final
1880 PackedStochRoundFp8OpLowering(
const LLVMTypeConverter &converter,
1882 : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
1887 matchAndRewrite(PackedStochRoundFp8Op op,
1888 PackedStochRoundFp8OpAdaptor adaptor,
1889 ConversionPatternRewriter &rewriter)
const override;
1892struct ScaledExtPackedOpLowering final
1894 ScaledExtPackedOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
1895 : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedOp>(converter),
1900 matchAndRewrite(ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
1901 ConversionPatternRewriter &rewriter)
const override;
1904struct PackedScaledTruncOpLowering final
1906 PackedScaledTruncOpLowering(
const LLVMTypeConverter &converter,
1908 : ConvertOpToLLVMPattern<amdgpu::PackedScaledTruncOp>(converter),
1913 matchAndRewrite(PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
1914 ConversionPatternRewriter &rewriter)
const override;
1919LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1920 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
1921 ConversionPatternRewriter &rewriter)
const {
1922 Location loc = op.getLoc();
1924 return rewriter.notifyMatchFailure(
1925 loc,
"Fp8 conversion instructions are not available on target "
1926 "architecture and their emulation is not implemented");
1928 getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
1929 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
1930 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
1932 Value source = adaptor.getSource();
1933 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1934 auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
1937 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
1938 Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
1939 if (!sourceVecType) {
1940 longVec = LLVM::InsertElementOp::create(
1943 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
1945 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
1947 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
1952 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
1953 if (resultVecType) {
1955 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
1958 rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Fp8Op>(op, f32, i32Source,
1963 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
1966 rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
1973int32_t getScaleSel(int32_t blockSize,
unsigned bitWidth, int32_t scaleWaveHalf,
1974 int32_t firstScaleByte) {
1980 assert(llvm::is_contained({16, 32}, blockSize));
1981 assert(llvm::is_contained({4u, 6u, 8u}, bitWidth));
1983 const bool isFp8 = bitWidth == 8;
1984 const bool isBlock16 = blockSize == 16;
1987 int32_t bit0 = isBlock16;
1988 assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
1989 int32_t bit1 = (firstScaleByte == 2) << 1;
1990 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1991 int32_t bit2 = scaleWaveHalf << 2;
1992 return bit2 | bit1 | bit0;
1995 int32_t bit0 = isBlock16;
1997 assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
1998 int32_t bits2and1 = firstScaleByte << 1;
1999 assert(llvm::is_contained({0, 1}, scaleWaveHalf));
2000 int32_t bit3 = scaleWaveHalf << 3;
2001 int32_t bits = bit3 | bits2and1 | bit0;
2003 assert(!llvm::is_contained(
2004 {0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
2008static std::optional<StringRef>
2009scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
2010 using fp4 = Float4E2M1FNType;
2011 using fp8 = Float8E4M3FNType;
2012 using bf8 = Float8E5M2Type;
2013 using fp6 = Float6E2M3FNType;
2014 using bf6 = Float6E3M2FNType;
2015 if (isa<fp4>(srcElemType)) {
2016 if (destElemType.
isF16())
2017 return ROCDL::CvtPkScalePk8F16Fp4Op::getOperationName();
2018 if (destElemType.
isBF16())
2019 return ROCDL::CvtPkScalePk8Bf16Fp4Op::getOperationName();
2020 if (destElemType.
isF32())
2021 return ROCDL::CvtPkScalePk8F32Fp4Op::getOperationName();
2022 return std::nullopt;
2024 if (isa<fp8>(srcElemType)) {
2025 if (destElemType.
isF16())
2026 return ROCDL::CvtPkScalePk8F16Fp8Op::getOperationName();
2027 if (destElemType.
isBF16())
2028 return ROCDL::CvtPkScalePk8Bf16Fp8Op::getOperationName();
2029 if (destElemType.
isF32())
2030 return ROCDL::CvtPkScalePk8F32Fp8Op::getOperationName();
2031 return std::nullopt;
2033 if (isa<bf8>(srcElemType)) {
2034 if (destElemType.
isF16())
2035 return ROCDL::CvtPkScalePk8F16Bf8Op::getOperationName();
2036 if (destElemType.
isBF16())
2037 return ROCDL::CvtPkScalePk8Bf16Bf8Op::getOperationName();
2038 if (destElemType.
isF32())
2039 return ROCDL::CvtPkScalePk8F32Bf8Op::getOperationName();
2040 return std::nullopt;
2042 if (isa<fp6>(srcElemType)) {
2043 if (destElemType.
isF16())
2044 return ROCDL::CvtPkScalePk16F16Fp6Op::getOperationName();
2045 if (destElemType.
isBF16())
2046 return ROCDL::CvtPkScalePk16Bf16Fp6Op::getOperationName();
2047 if (destElemType.
isF32())
2048 return ROCDL::CvtPkScalePk16F32Fp6Op::getOperationName();
2049 return std::nullopt;
2051 if (isa<bf6>(srcElemType)) {
2052 if (destElemType.
isF16())
2053 return ROCDL::CvtPkScalePk16F16Bf6Op::getOperationName();
2054 if (destElemType.
isBF16())
2055 return ROCDL::CvtPkScalePk16Bf16Bf6Op::getOperationName();
2056 if (destElemType.
isF32())
2057 return ROCDL::CvtPkScalePk16F32Bf6Op::getOperationName();
2058 return std::nullopt;
2060 llvm_unreachable(
"invalid combination of element types for packed conversion "
2064LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
2065 ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
2066 ConversionPatternRewriter &rewriter)
const {
2067 using fp4 = Float4E2M1FNType;
2068 using fp8 = Float8E4M3FNType;
2069 using bf8 = Float8E5M2Type;
2070 using fp6 = Float6E2M3FNType;
2071 using bf6 = Float6E3M2FNType;
2072 Location loc = op.getLoc();
2074 return rewriter.notifyMatchFailure(
2076 "Scaled fp packed conversion instructions are not available on target "
2077 "architecture and their emulation is not implemented");
2081 int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
2082 int32_t firstScaleByte = op.getFirstScaleByte();
2083 int32_t blockSize = op.getBlockSize();
2084 auto sourceType = cast<VectorType>(op.getSource().getType());
2085 auto srcElemType = cast<FloatType>(sourceType.getElementType());
2086 unsigned bitWidth = srcElemType.getWidth();
2088 auto targetType = cast<VectorType>(op.getResult().getType());
2089 auto destElemType = cast<FloatType>(targetType.getElementType());
2091 IntegerType i32 = rewriter.getI32Type();
2092 Value source = adaptor.getSource();
2093 Type llvmResultType = typeConverter->convertType(op.getResult().getType());
2094 Type packedType =
nullptr;
2095 if (isa<fp4>(srcElemType)) {
2097 packedType = getTypeConverter()->convertType(packedType);
2098 }
else if (isa<fp8, bf8>(srcElemType)) {
2099 packedType = VectorType::get(2, i32);
2100 packedType = getTypeConverter()->convertType(packedType);
2101 }
else if (isa<fp6, bf6>(srcElemType)) {
2102 packedType = VectorType::get(3, i32);
2103 packedType = getTypeConverter()->convertType(packedType);
2105 llvm_unreachable(
"invalid element type for packed scaled ext");
2108 if (!packedType || !llvmResultType) {
2109 return rewriter.notifyMatchFailure(op,
"type conversion failed");
2112 std::optional<StringRef> maybeIntrinsic =
2113 scaledExtPacked816ToIntrinsic(srcElemType, destElemType);
2114 if (!maybeIntrinsic.has_value())
2115 return op.emitOpError(
2116 "no intrinsic matching packed scaled conversion on the given chipset");
2119 getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
2121 LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
2122 Value castedSource =
2123 LLVM::BitcastOp::create(rewriter, loc, packedType, source);
2125 OperationState loweredOp(loc, *maybeIntrinsic);
2126 loweredOp.addTypes({llvmResultType});
2127 loweredOp.addOperands({castedSource, castedScale});
2129 SmallVector<NamedAttribute, 1> attrs;
2131 NamedAttribute(
"scaleSel", rewriter.getI32IntegerAttr(scaleSel)));
2133 loweredOp.addAttributes(attrs);
2134 Operation *lowered = rewriter.create(loweredOp);
2135 rewriter.replaceOp(op, lowered);
2140LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
2141 ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
2142 ConversionPatternRewriter &rewriter)
const {
2143 Location loc = op.getLoc();
2145 return rewriter.notifyMatchFailure(
2146 loc,
"Scaled fp conversion instructions are not available on target "
2147 "architecture and their emulation is not implemented");
2148 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2150 Value source = adaptor.getSource();
2151 Value scale = adaptor.getScale();
2153 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2154 Type sourceElemType = sourceVecType.getElementType();
2155 VectorType destVecType = cast<VectorType>(op.getResult().getType());
2156 Type destElemType = destVecType.getElementType();
2158 VectorType packedVecType;
2159 if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
2160 VectorType v4i8 = VectorType::get(4, rewriter.getI8Type());
2161 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v4i8));
2162 }
else if (isa<Float4E2M1FNType>(sourceElemType)) {
2163 VectorType v8i4 = VectorType::get(8, rewriter.getI4Type());
2164 packedVecType = cast<VectorType>(getTypeConverter()->convertType(v8i4));
2166 llvm_unreachable(
"invalid element type for scaled ext");
2170 if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
2171 Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
2172 if (!sourceVecType) {
2173 longVec = LLVM::InsertElementOp::create(
2176 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
2178 Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
2180 LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
2185 Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
2187 if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF32())
2188 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
2189 op, destVecType, i32Source, scale, op.getIndex());
2190 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isF16())
2191 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Bf8Op>(
2192 op, destVecType, i32Source, scale, op.getIndex());
2193 else if (isa<Float8E5M2Type>(sourceElemType) && destElemType.
isBF16())
2194 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Bf8Op>(
2195 op, destVecType, i32Source, scale, op.getIndex());
2196 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF32())
2197 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
2198 op, destVecType, i32Source, scale, op.getIndex());
2199 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isF16())
2200 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp8Op>(
2201 op, destVecType, i32Source, scale, op.getIndex());
2202 else if (isa<Float8E4M3FNType>(sourceElemType) && destElemType.
isBF16())
2203 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp8Op>(
2204 op, destVecType, i32Source, scale, op.getIndex());
2205 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF32())
2206 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp4Op>(
2207 op, destVecType, i32Source, scale, op.getIndex());
2208 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isF16())
2209 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF16Fp4Op>(
2210 op, destVecType, i32Source, scale, op.getIndex());
2211 else if (isa<Float4E2M1FNType>(sourceElemType) && destElemType.
isBF16())
2212 rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkBf16Fp4Op>(
2213 op, destVecType, i32Source, scale, op.getIndex());
2220LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
2221 PackedScaledTruncOp op, PackedScaledTruncOpAdaptor adaptor,
2222 ConversionPatternRewriter &rewriter)
const {
2223 Location loc = op.getLoc();
2225 return rewriter.notifyMatchFailure(
2226 loc,
"Scaled fp conversion instructions are not available on target "
2227 "architecture and their emulation is not implemented");
2228 Type v2i16 = getTypeConverter()->convertType(
2229 VectorType::get(2, rewriter.getI16Type()));
2230 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2232 Type resultType = op.getResult().getType();
2234 VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
2235 Type sourceElemType = sourceVecType.getElementType();
2237 Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
2239 Value source = adaptor.getSource();
2240 Value scale = adaptor.getScale();
2241 Value existing = adaptor.getExisting();
2243 existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing);
2245 existing = LLVM::ZeroOp::create(rewriter, loc, intResultType);
2247 if (sourceVecType.getNumElements() < 2) {
2249 Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2250 VectorType v2 = VectorType::get(2, sourceElemType);
2251 source = LLVM::ZeroOp::create(rewriter, loc, v2);
2252 source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0);
2255 Value sourceA, sourceB;
2256 if (sourceElemType.
isF32()) {
2259 sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0);
2260 sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1);
2264 if (sourceElemType.
isF32() && isa<Float8E5M2Type>(resultElemType))
2265 result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType,
2266 existing, sourceA, sourceB,
2267 scale, op.getIndex());
2268 else if (sourceElemType.
isF16() && isa<Float8E5M2Type>(resultElemType))
2269 result = ROCDL::CvtScaleF32PkBf8F16Op::create(
2270 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2271 else if (sourceElemType.
isBF16() && isa<Float8E5M2Type>(resultElemType))
2272 result = ROCDL::CvtScaleF32PkBf8Bf16Op::create(
2273 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2274 else if (sourceElemType.
isF32() && isa<Float8E4M3FNType>(resultElemType))
2275 result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType,
2276 existing, sourceA, sourceB,
2277 scale, op.getIndex());
2278 else if (sourceElemType.
isF16() && isa<Float8E4M3FNType>(resultElemType))
2279 result = ROCDL::CvtScaleF32PkFp8F16Op::create(
2280 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2281 else if (sourceElemType.
isBF16() && isa<Float8E4M3FNType>(resultElemType))
2282 result = ROCDL::CvtScaleF32PkFp8Bf16Op::create(
2283 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2284 else if (sourceElemType.
isF32() && isa<Float4E2M1FNType>(resultElemType))
2285 result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType,
2286 existing, sourceA, sourceB,
2287 scale, op.getIndex());
2288 else if (sourceElemType.
isF16() && isa<Float4E2M1FNType>(resultElemType))
2289 result = ROCDL::CvtScaleF32PkFp4F16Op::create(
2290 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2291 else if (sourceElemType.
isBF16() && isa<Float4E2M1FNType>(resultElemType))
2292 result = ROCDL::CvtScaleF32PkFp4Bf16Op::create(
2293 rewriter, loc, intResultType, existing, source, scale, op.getIndex());
2297 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2298 op, getTypeConverter()->convertType(resultType),
result);
2302LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
2303 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
2304 ConversionPatternRewriter &rewriter)
const {
2305 Location loc = op.getLoc();
2307 return rewriter.notifyMatchFailure(
2308 loc,
"Fp8 conversion instructions are not available on target "
2309 "architecture and their emulation is not implemented");
2310 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2312 Type resultType = op.getResult().getType();
2315 Value sourceA = adaptor.getSourceA();
2316 Value sourceB = adaptor.getSourceB();
2318 sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.
getType());
2319 Value existing = adaptor.getExisting();
2321 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2323 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2327 result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2328 existing, op.getWordIndex());
2330 result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB,
2331 existing, op.getWordIndex());
2333 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2334 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu.packed_stoch_round_fp8 to ROCDL cvt_sr_{bf8,fp8}_f32:
// stochastic-rounding conversion of one f32 `source` using the random
// `stoch` operand, storing into byte `op.getStoreIndex()` of an i32 word.
// NOTE(review): guard condition and the if/else selecting bf8 vs fp8 (and
// the existing-value presence checks) are elided in this listing -- confirm
// against the original file.
2338LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
2339 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
2340 ConversionPatternRewriter &rewriter)
const {
2341 Location loc = op.getLoc();
// Fail the match (no emulation path) when the target lacks the instructions.
2343 return rewriter.notifyMatchFailure(
2344 loc,
"Fp8 conversion instructions are not available on target "
2345 "architecture and their emulation is not implemented");
2346 Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
2348 Type resultType = op.getResult().getType();
2351 Value source = adaptor.getSource();
2352 Value stoch = adaptor.getStochiasticParam();
2353 Value existing = adaptor.getExisting();
// Reuse the caller-provided packed word when present, otherwise undef.
2355 existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing);
2357 existing = LLVM::UndefOp::create(rewriter, loc, i32);
2361 result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch,
2362 existing, op.getStoreIndex());
2364 result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch,
2365 existing, op.getStoreIndex());
// Reinterpret the packed i32 as the op's converted result type.
2367 result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
2368 op, getTypeConverter()->convertType(resultType),
result);
// Lowers amdgpu.dpp to rocdl.update.dpp. Operands narrower than 32 bits are
// widened to a 32-bit value first (floats via bitcast to an int of the same
// width, then insertion into a <32/w x iw> vector and bitcast to the 32-bit
// type); the permutation kind + argument are encoded into the hardware
// DPP_CTRL immediate.
// NOTE(review): many lines are elided in this listing (the if-conditions
// building llvmType, the switch header, break statements, and the tail that
// replaces the op) -- verify the control flow against the original file.
2374struct AMDGPUDPPLowering :
public ConvertOpToLLVMPattern<DPPOp> {
2375 AMDGPUDPPLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2376 : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
2380 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
2381 ConversionPatternRewriter &rewriter)
const override {
2384 Location loc = DppOp.getLoc();
2385 Value src = adaptor.getSrc();
2386 Value old = adaptor.getOld();
// Pick the 32- or 64-bit LLVM type the DPP intrinsic will operate on,
// based on the source element type's category and width.
2389 Type llvmType =
nullptr;
2391 llvmType = rewriter.getI32Type();
2392 }
else if (isa<FloatType>(srcType)) {
2394 ? rewriter.getF32Type()
2395 : rewriter.getF64Type();
2396 }
else if (isa<IntegerType>(srcType)) {
2398 ? rewriter.getI32Type()
2399 : rewriter.getI64Type();
2401 auto llvmSrcIntType = typeConverter->convertType(
// Widen sub-32-bit operands: bitcast floats to same-width ints, pad into a
// 32-bit vector lane, then bitcast the whole thing to the 32-bit type.
2405 auto convertOperand = [&](Value operand, Type operandType) {
2406 if (operandType.getIntOrFloatBitWidth() <= 16) {
2407 if (llvm::isa<FloatType>(operandType)) {
2409 LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand);
2411 auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
2412 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
2413 Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType);
2415 LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand,
2417 operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand);
2422 src = convertOperand(src, srcType);
2423 old = convertOperand(old, oldType);
// Hardware DPP_CTRL immediate values (partial listing here; the full enum
// is elided in this extraction).
2426 enum DppCtrl :
unsigned {
2435 ROW_HALF_MIRROR = 0x141,
2440 auto kind = DppOp.getKind();
2441 auto permArgument = DppOp.getPermArgument();
2442 uint32_t DppCtrl = 0;
// Encode the permutation into DppCtrl. quad_perm packs four 2-bit lane
// selectors; the shift/rotate kinds add their amount to the base opcode.
2446 case DPPPerm::quad_perm: {
2447 auto quadPermAttr = cast<ArrayAttr>(*permArgument);
2449 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
2450 uint32_t num = elem.getInt();
2451 DppCtrl |= num << (i * 2);
2456 case DPPPerm::row_shl: {
2457 auto intAttr = cast<IntegerAttr>(*permArgument);
2458 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
2461 case DPPPerm::row_shr: {
2462 auto intAttr = cast<IntegerAttr>(*permArgument);
2463 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
2466 case DPPPerm::row_ror: {
2467 auto intAttr = cast<IntegerAttr>(*permArgument);
2468 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
2471 case DPPPerm::wave_shl:
2472 DppCtrl = DppCtrl::WAVE_SHL1;
2474 case DPPPerm::wave_shr:
2475 DppCtrl = DppCtrl::WAVE_SHR1;
2477 case DPPPerm::wave_rol:
2478 DppCtrl = DppCtrl::WAVE_ROL1;
2480 case DPPPerm::wave_ror:
2481 DppCtrl = DppCtrl::WAVE_ROR1;
2483 case DPPPerm::row_mirror:
2484 DppCtrl = DppCtrl::ROW_MIRROR;
2486 case DPPPerm::row_half_mirror:
2487 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
2489 case DPPPerm::row_bcast_15:
2490 DppCtrl = DppCtrl::BCAST15;
2492 case DPPPerm::row_bcast_31:
2493 DppCtrl = DppCtrl::BCAST31;
// Remaining DPP modifiers are read straight off the op's attributes.
2499 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
2500 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
2501 bool boundCtrl = DppOp->getAttrOfType<BoolAttr>(
"bound_ctrl").getValue();
2505 ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl,
2506 rowMask, bankMask, boundCtrl);
// Undo the widening: truncate back to the original integer width and, for
// non-integer sources, bitcast back to the source type.
2508 Value
result = dppMovOp.getRes();
2510 result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType,
result);
2511 if (!llvm::isa<IntegerType>(srcType)) {
2512 result = LLVM::BitcastOp::create(rewriter, loc, srcType,
result);
// Lowers amdgpu.swizzle_bitmode to rocdl.ds_swizzle: the source value is
// decomposed into i32 words, each word is swizzled with a single packed
// bit-mode mask (and | or << 5 | xor << 10), and the words are recomposed
// into the original type.
// NOTE(review): a few lines are elided here (the maskValue constant
// materialization and the closing braces) -- verify against the original.
2523struct AMDGPUSwizzleBitModeLowering
2524 :
public ConvertOpToLLVMPattern<SwizzleBitModeOp> {
2528 matchAndRewrite(SwizzleBitModeOp op, OpAdaptor adaptor,
2529 ConversionPatternRewriter &rewriter)
const override {
2530 Location loc = op.getLoc();
2531 Type i32 = rewriter.getI32Type();
2532 Value src = adaptor.getSrc();
// ds_swizzle operates on 32-bit lanes; split arbitrary-width input first.
2533 SmallVector<Value> decomposed =
2534 LLVM::decomposeValue(rewriter, loc, src, i32);
2535 unsigned andMask = op.getAndMask();
2536 unsigned orMask = op.getOrMask();
2537 unsigned xorMask = op.getXorMask();
// Pack the three 5-bit masks into the single ds_swizzle offset encoding.
2541 unsigned mask = andMask | (orMask << 5) | (xorMask << 10);
2543 SmallVector<Value> swizzled;
2544 for (Value v : decomposed) {
2546 ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue);
2547 swizzled.emplace_back(res);
// Reassemble the swizzled i32 words back into the source's type.
2550 Value
result = LLVM::composeValue(rewriter, loc, swizzled, src.
getType());
2551 rewriter.replaceOp(op,
result);
// Lowers amdgpu.permlane_swap to rocdl.permlane{16,32}.swap. The source is
// split into i32 words; each word is fed twice into the swap intrinsic
// (which returns a struct of two i32 results), and one of the two results
// is selected before recomposing the original value.
// NOTE(review): the gfx950 chipset check guarding emitOpError and the
// trailing boundctrl arguments/braces are elided in this listing.
2556struct AMDGPUPermlaneLowering :
public ConvertOpToLLVMPattern<PermlaneSwapOp> {
2559 AMDGPUPermlaneLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2560 : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
2564 matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
2565 ConversionPatternRewriter &rewriter)
const override {
2567 return op->emitOpError(
"permlane_swap is only supported on gfx950+");
2569 Location loc = op.getLoc();
2570 Type i32 = rewriter.getI32Type();
2571 Value src = adaptor.getSrc();
2572 unsigned rowLength = op.getRowLength();
2573 bool fi = op.getFetchInactive();
2574 bool boundctrl = op.getBoundCtrl();
// Intrinsics operate on 32-bit lanes; split arbitrary-width input first.
2576 SmallVector<Value> decomposed =
2577 LLVM::decomposeValue(rewriter, loc, src, i32);
2579 SmallVector<Value> permuted;
2580 for (Value v : decomposed) {
// Result type of the swap intrinsic: a literal struct of two i32 values.
2582 Type i32pair = LLVM::LLVMStructType::getLiteral(
2583 rewriter.getContext(), {v.getType(), v.getType()});
2585 if (rowLength == 16)
2586 res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2588 else if (rowLength == 32)
2589 res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
2592 llvm_unreachable(
"unsupported row length");
2594 Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
2595 Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
// Pick the swapped-in word: when vdst0 still equals the input, the value
// from the other row landed in vdst1.
2597 Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
2598 LLVM::ICmpPredicate::eq, vdst0, v);
2603 LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
2604 permuted.emplace_back(vdstNew);
// Reassemble the permuted i32 words back into the source's type.
2607 Value
result = LLVM::composeValue(rewriter, loc, permuted, src.
getType());
2608 rewriter.replaceOp(op,
result);
// Bit layout of the 64-bit DS barrier state word used by the lowerings
// below: bits [0, 29) hold the pending count, bit 29 holds the phase, and
// the initial participant count starts at bit 32.
2621constexpr int32_t kDsBarrierPendingCountBitWidth = 29;
2622constexpr int32_t kDsBarrierPhasePos = kDsBarrierPendingCountBitWidth;
2623constexpr int32_t kDsBarrierInitCountPos = 32;
// Mask selecting the low 29 pending-count bits.
2624constexpr int32_t kDsBarrierPendingCountMask =
2625 (1 << kDsBarrierPendingCountBitWidth) - 1;
// Lowers amdgpu.ds_barrier_init: computes the barrier memory location from
// the base memref + indices, builds the packed 64-bit barrier state (masked
// count in both the pending-count and init-count fields), and stores it
// with release ordering.
// NOTE(review): several lines are elided (the gfx1250 chipset check, the
// ptr/initCount/countMask definitions, the shift-amount constant and the
// store's trailing arguments) -- verify against the original file.
2627struct DsBarrierInitOpLowering
2628 :
public ConvertOpToLLVMPattern<DsBarrierInitOp> {
2631 DsBarrierInitOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2632 : ConvertOpToLLVMPattern<DsBarrierInitOp>(converter), chipset(chipset) {}
2635 matchAndRewrite(DsBarrierInitOp op, OpAdaptor adaptor,
2636 ConversionPatternRewriter &rewriter)
const override {
2638 return op->emitOpError(
"only supported on gfx1250+");
2640 Location loc = op.getLoc();
2641 Type i64 = rewriter.getI64Type();
// Resolve the strided element pointer for the barrier word.
2643 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2645 adaptor.getBase(), adaptor.getIndices());
2652 LLVM::SubOp::create(rewriter, loc, adaptor.getParticipants(),
// Keep only the pending-count field's bits of the initial count.
2659 Value maskedCount32 =
2660 LLVM::AndOp::create(rewriter, loc, initCount, countMask);
2661 Value maskedCount = LLVM::ZExtOp::create(rewriter, loc, i64, maskedCount32);
// Duplicate the count into the init-count field (high bits) and OR it with
// the pending-count field (low bits) to form the initial barrier state.
2663 Value initCountShifted = LLVM::ShlOp::create(
2664 rewriter, loc, maskedCount,
2666 Value barrierState =
2667 LLVM::OrOp::create(rewriter, loc, initCountShifted, maskedCount);
// Publish the state with release semantics so other waves observe it.
2669 LLVM::StoreOp::create(
2670 rewriter, loc, barrierState, ptr, 8,
false,
2672 false, LLVM::AtomicOrdering::release,
2675 rewriter.eraseOp(op);
// Lowers amdgpu.ds_barrier_poll_state: reads the 64-bit barrier state word
// at base[indices] with an acquire-ordered atomic load.
// NOTE(review): the gfx1250 guard, the constructor's chipset parameter line
// and the pointer computation are partially elided in this listing.
2680struct DsBarrierPollStateOpLowering
2681 :
public ConvertOpToLLVMPattern<DsBarrierPollStateOp> {
2684 DsBarrierPollStateOpLowering(
const LLVMTypeConverter &converter,
2686 : ConvertOpToLLVMPattern<DsBarrierPollStateOp>(converter),
2690 matchAndRewrite(DsBarrierPollStateOp op, OpAdaptor adaptor,
2691 ConversionPatternRewriter &rewriter)
const override {
2693 return op->emitOpError(
"only supported on gfx1250+");
2695 Location loc = op.getLoc();
2696 Type i64 = rewriter.getI64Type();
// Resolve the strided element pointer for the barrier word.
2698 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2700 adaptor.getBase(), adaptor.getIndices());
// Acquire load pairs with the release store done by ds_barrier_init.
2704 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
2705 op, i64, ptr, 8,
false,
2707 false, LLVM::AtomicOrdering::acquire,
// Lowers amdgpu.ds_async_barrier_arrive to the ROCDL
// ds_atomic_async_barrier_arrive intrinsic on the resolved barrier pointer.
// NOTE(review): the gfx1250 guard condition, the constructor's chipset
// parameter and the intrinsic's trailing operands are elided here.
2713struct DsAsyncBarrierArriveOpLowering
2714 :
public ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp> {
2717 DsAsyncBarrierArriveOpLowering(
const LLVMTypeConverter &converter,
2719 : ConvertOpToLLVMPattern<DsAsyncBarrierArriveOp>(converter),
2723 matchAndRewrite(DsAsyncBarrierArriveOp op, OpAdaptor adaptor,
2724 ConversionPatternRewriter &rewriter)
const override {
2726 return op->emitOpError(
"only supported on gfx1250+");
2728 Location loc = op.getLoc();
// Resolve the strided element pointer for the barrier word.
2730 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2732 adaptor.getBase(), adaptor.getIndices());
2734 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicAsyncBarrierArriveOp>(
2735 op, ptr,
nullptr,
nullptr,
// Lowers amdgpu.ds_barrier_arrive to the ROCDL
// ds_atomic_barrier_arrive_rtn intrinsic, which returns the i64 barrier
// state after the arrival.
// NOTE(review): the gfx1250 guard condition and the intrinsic's trailing
// operands are elided in this listing.
2741struct DsBarrierArriveOpLowering
2742 :
public ConvertOpToLLVMPattern<DsBarrierArriveOp> {
2745 DsBarrierArriveOpLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2746 : ConvertOpToLLVMPattern<DsBarrierArriveOp>(converter), chipset(chipset) {
2750 matchAndRewrite(DsBarrierArriveOp op, OpAdaptor adaptor,
2751 ConversionPatternRewriter &rewriter)
const override {
2753 return op->emitOpError(
"only supported on gfx1250+");
2755 Location loc = op.getLoc();
2756 Type i64 = rewriter.getI64Type();
// Resolve the strided element pointer for the barrier word.
2758 MemRefType memrefType = cast<MemRefType>(op.getBase().getType());
2760 adaptor.getBase(), adaptor.getIndices());
2762 rewriter.replaceOpWithNewOp<ROCDL::DsAtomicBarrierArriveRtnOp>(
2763 op, i64, ptr, adaptor.getCount(),
nullptr,
// Lowers amdgpu.ds_barrier_state_phase: truncates the 64-bit state to its
// low 32 bits (dropping the init-count field) and shifts right so the
// phase bit (bit kDsBarrierPhasePos) lands in bit 0.
// NOTE(review): the shift-amount constant line is elided in this listing.
2769struct DsBarrierStatePhaseOpLowering
2770 :
public ConvertOpToLLVMPattern<DsBarrierStatePhaseOp> {
2774 matchAndRewrite(DsBarrierStatePhaseOp op, OpAdaptor adaptor,
2775 ConversionPatternRewriter &rewriter)
const override {
2776 Location loc = op.getLoc();
2777 Type i32 = rewriter.getI32Type();
2779 Value state = adaptor.getState();
// Low 32 bits hold pending count + phase; the init count is above bit 32.
2781 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
2782 Value phase = LLVM::LShrOp::create(
2783 rewriter, loc, noInitCount,
2786 rewriter.replaceOp(op, phase);
// Lowers amdgpu.ds_barrier_state_pending_count: truncates the 64-bit state
// to i32 and masks with kDsBarrierPendingCountMask to isolate the low-29-bit
// pending count.
2791struct DsBarrierStatePendingCountOpLowering
2792 :
public ConvertOpToLLVMPattern<DsBarrierStatePendingCountOp> {
2796 matchAndRewrite(DsBarrierStatePendingCountOp op, OpAdaptor adaptor,
2797 ConversionPatternRewriter &rewriter)
const override {
2798 Location loc = op.getLoc();
2799 Type i32 = rewriter.getI32Type();
2801 Value state = adaptor.getState();
// Drop the init-count field (bits >= 32) before masking.
2803 Value noInitCount = LLVM::TruncOp::create(rewriter, loc, i32, state);
2804 Value pendingCount = LLVM::AndOp::create(
2805 rewriter, loc, noInitCount,
2807 static_cast<uint32_t
>(kDsBarrierPendingCountMask)));
2809 rewriter.replaceOp(op, pendingCount);
// Lowers amdgpu.ds_barrier_state_init_count: shifts the 64-bit state right
// (by kDsBarrierInitCountPos, i.e. 32 -- the shift constant line is elided
// in this listing) and truncates to i32 to extract the initial count field.
2814struct DsBarrierStateInitCountOpLowering
2815 :
public ConvertOpToLLVMPattern<DsBarrierStateInitCountOp> {
2819 matchAndRewrite(DsBarrierStateInitCountOp op, OpAdaptor adaptor,
2820 ConversionPatternRewriter &rewriter)
const override {
2821 Location loc = op.getLoc();
2822 Type i32 = rewriter.getI32Type();
2824 Value state = adaptor.getState();
// Move the init-count field down to bit 0, then narrow to 32 bits.
2826 Value initCountI64 = LLVM::LShrOp::create(
2827 rewriter, loc, state,
2829 Value initCount = LLVM::TruncOp::create(rewriter, loc, i32, initCountI64);
2831 rewriter.replaceOp(op, initCount);
// Lowers amdgpu.ds_barrier_state_phase_parity: extracts the phase bit the
// same way as the phase lowering above, then truncates to i1 so only the
// parity of the phase remains.
2836struct DsBarrierStatePhaseParityLowering
2837 :
public ConvertOpToLLVMPattern<DsBarrierStatePhaseParity> {
2841 matchAndRewrite(DsBarrierStatePhaseParity op, OpAdaptor adaptor,
2842 ConversionPatternRewriter &rewriter)
const override {
2843 Location loc = op.getLoc();
2844 Type i1 = rewriter.getI1Type();
2846 Value state = adaptor.getState();
// Drop the init-count field, then shift the phase bit down to bit 0.
2849 LLVM::TruncOp::create(rewriter, loc, rewriter.getI32Type(), state);
2850 Value phase = LLVM::LShrOp::create(
2851 rewriter, loc, noInitCount,
2853 Value parity = LLVM::TruncOp::create(rewriter, loc, i1, phase);
2855 rewriter.replaceOp(op, parity);
// Shifts `value` left by `shift` bits and ORs it into `accumulator`. The OR
// is marked disjoint: callers guarantee the target bit range of the
// accumulator is still zero, so the OR behaves like a field insert.
// NOTE(review): the shiftAmount constant creation (and likely a shift==0
// fast path) is elided in this listing -- verify against the original.
2864static Value setValueAtOffset(ConversionPatternRewriter &rewriter, Location loc,
2865 Value accumulator, Value value, int64_t shift) {
2870 value = LLVM::ShlOp::create(rewriter, loc, value, shiftAmount);
2876 constexpr bool isDisjoint =
true;
2877 return LLVM::OrOp::create(rewriter, loc, accumulator, value, isDisjoint);
// Lowers the make_dma_base ops (gather and non-gather variants, selected by
// BaseOp::isGather()) into a v4i32 TDM base descriptor:
//   sgprs[0]: flag word (gather bit at offset 30, index-size bit at 31),
//   sgprs[1]: LDS address as i32,
//   sgprs[2]: low 32 bits of the global address,
//   sgprs[3]: high bits of the global address plus a flag at offset 30.
// The four words are inserted into a poison v4i32 vector element by element.
// NOTE(review): this listing elides the gfx1250 chipset guard, the consts[]
// initialization loop body, and the getStridedElementPtr calls' first lines
// -- verify details against the original file.
2880template <
typename BaseOp>
2881struct AMDGPUMakeDmaBaseLowering :
public ConvertOpToLLVMPattern<BaseOp> {
2882 using ConvertOpToLLVMPattern<BaseOp>::ConvertOpToLLVMPattern;
2885 AMDGPUMakeDmaBaseLowering(
const LLVMTypeConverter &converter, Chipset chipset)
2886 : ConvertOpToLLVMPattern<BaseOp>(converter), chipset(chipset) {}
2890 matchAndRewrite(BaseOp op, Adaptor adaptor,
2891 ConversionPatternRewriter &rewriter)
const override {
2893 return op->emitOpError(
"make_dma_base is only supported on gfx1250");
2895 Location loc = op.getLoc();
// consts[i] holds the i32 constant i (0..3); reused both as field values
// and as insertelement indices below.
2897 constexpr int32_t constlen = 4;
2898 Value consts[constlen];
2899 for (int64_t i = 0; i < constlen; ++i)
// Start every descriptor word zeroed.
2902 constexpr int32_t sgprslen = constlen;
2903 Value sgprs[sgprslen];
2904 for (int64_t i = 0; i < sgprslen; ++i) {
2905 sgprs[i] = consts[0];
2908 sgprs[0] = consts[1];
// Gather variant: set the gather flag (bit 30) and encode the index size
// (16 -> 0, 32 -> 1) in bit 31 of the flag word.
2910 if constexpr (BaseOp::isGather()) {
2911 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 30);
2913 auto type = cast<TDMGatherBaseType>(op.getResult().getType());
2914 Type indexType = type.getIndexType();
2916 assert(llvm::is_contained({16u, 32u}, indexSize) &&
2917 "expected index_size to be 16 or 32");
2918 unsigned idx = (indexSize / 16) - 1;
2921 sgprs[0] = setValueAtOffset(rewriter, loc, sgprs[0], consts[1], 31);
// Resolve the LDS-side and global-side element pointers.
2924 ValueRange ldsIndices = adaptor.getLdsIndices();
2925 Value lds = adaptor.getLds();
2926 auto ldsMemRefType = cast<MemRefType>(op.getLds().getType());
2929 rewriter, loc, ldsMemRefType, lds, ldsIndices);
2931 ValueRange globalIndices = adaptor.getGlobalIndices();
2932 Value global = adaptor.getGlobal();
2933 auto globalMemRefType = cast<MemRefType>(op.getGlobal().getType());
2936 rewriter, loc, globalMemRefType, global, globalIndices);
2938 Type i32 = rewriter.getI32Type();
2939 Type i64 = rewriter.getI64Type();
// LDS addresses fit in 32 bits; the global address is split lo/hi.
2941 sgprs[1] = LLVM::PtrToIntOp::create(rewriter, loc, i32, ldsPtr);
2942 Value castForGlobalAddr =
2943 LLVM::PtrToIntOp::create(rewriter, loc, i64, globalPtr);
2945 sgprs[2] = LLVM::TruncOp::create(rewriter, loc, i32, castForGlobalAddr);
2947 Value shift = LLVM::LShrOp::create(rewriter, loc, castForGlobalAddr,
2950 Value highHalf = LLVM::TruncOp::create(rewriter, loc, i32, shift);
2953 highHalf = LLVM::AndOp::create(rewriter, loc, highHalf, mask);
// Word 3: masked high address bits with consts[2] inserted at offset 30.
2955 sgprs[3] = setValueAtOffset(rewriter, loc, highHalf, consts[2], 30);
// Assemble the v4i32 descriptor, using consts[] as the lane indices.
2957 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
2958 assert(v4i32 &&
"expected type conversion to succeed");
2959 Value
result = LLVM::PoisonOp::create(rewriter, loc, v4i32);
2961 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts))
2963 LLVM::InsertElementOp::create(rewriter, loc,
result, sgpr, constant);
2965 rewriter.replaceOp(op,
result);
2970template <
typename DescriptorOp>
2971struct AMDGPULowerDescriptor :
public ConvertOpToLLVMPattern<DescriptorOp> {
2972 using ConvertOpToLLVMPattern<DescriptorOp>::ConvertOpToLLVMPattern;
2975 AMDGPULowerDescriptor(
const LLVMTypeConverter &converter, Chipset chipset)
2976 : ConvertOpToLLVMPattern<DescriptorOp>(converter), chipset(chipset) {}
2979 Value getDGroup0(OpAdaptor adaptor)
const {
return adaptor.getBase(); }
2981 Value setWorkgroupMask(DescriptorOp op, OpAdaptor adaptor,
2982 ConversionPatternRewriter &rewriter, Location loc,
2983 Value sgpr0)
const {
2984 Value mask = op.getWorkgroupMask();
2988 Type i16 = rewriter.getI16Type();
2989 mask = LLVM::BitcastOp::create(rewriter, loc, i16, mask);
2990 Type i32 = rewriter.getI32Type();
2991 Value extendedMask = LLVM::ZExtOp::create(rewriter, loc, i32, mask);
2992 return setValueAtOffset(rewriter, loc, sgpr0, extendedMask, 0);
2995 Value setDataSize(DescriptorOp op, OpAdaptor adaptor,
2996 ConversionPatternRewriter &rewriter, Location loc,
2997 Value sgpr0, ArrayRef<Value> consts)
const {
2998 unsigned elementTypeWidthInBits = op.getElementTypeWidth();
2999 assert(llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidthInBits) &&
3000 "expected type width to be 8, 16, 32, or 64.");
3001 int64_t idx = llvm::Log2_32(elementTypeWidthInBits / 8);
3002 Value size = consts[idx];
3003 return setValueAtOffset(rewriter, loc, sgpr0, size, 16);
3006 Value setAtomicBarrier(DescriptorOp op, OpAdaptor adaptor,
3007 ConversionPatternRewriter &rewriter, Location loc,
3008 Value sgpr0, ArrayRef<Value> consts)
const {
3009 if (!adaptor.getAtomicBarrierAddress())
3012 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 18);
3015 Value setIterateEnable(DescriptorOp op, OpAdaptor adaptor,
3016 ConversionPatternRewriter &rewriter, Location loc,
3017 Value sgpr0, ArrayRef<Value> consts)
const {
3018 if (!adaptor.getGlobalIncrement())
3023 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 19);
3026 Value setPadEnable(DescriptorOp op, OpAdaptor adaptor,
3027 ConversionPatternRewriter &rewriter, Location loc,
3028 Value sgpr0, ArrayRef<Value> consts)
const {
3029 if (!op.getPadAmount())
3032 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 20);
3035 Value setEarlyTimeout(DescriptorOp op, OpAdaptor adaptor,
3036 ConversionPatternRewriter &rewriter, Location loc,
3037 Value sgpr0, ArrayRef<Value> consts)
const {
3038 if (!op.getWorkgroupMask())
3041 return setValueAtOffset(rewriter, loc, sgpr0, consts[1], 21);
3044 Value setPadInterval(DescriptorOp op, OpAdaptor adaptor,
3045 ConversionPatternRewriter &rewriter, Location loc,
3046 Value sgpr0, ArrayRef<Value> consts)
const {
3047 if (!op.getPadAmount())
3056 IntegerType i32 = rewriter.getI32Type();
3057 Value padInterval = adaptor.getPadInterval();
3058 padInterval = LLVM::CountTrailingZerosOp::create(rewriter, loc, i32,
3059 padInterval,
false);
3060 padInterval = LLVM::SubOp::create(rewriter, loc, padInterval, consts[1]);
3062 return setValueAtOffset(rewriter, loc, sgpr0, padInterval, 22);
3065 Value setPadAmount(DescriptorOp op, OpAdaptor adaptor,
3066 ConversionPatternRewriter &rewriter, Location loc,
3067 Value sgpr0, ArrayRef<Value> consts)
const {
3068 if (!op.getPadAmount())
3077 Value padAmount = adaptor.getPadAmount();
3078 padAmount = LLVM::SubOp::create(rewriter, loc, padAmount, consts[1]);
3080 return setValueAtOffset(rewriter, loc, sgpr0, padAmount, 25);
3083 Value setAtomicBarrierAddress(DescriptorOp op, OpAdaptor adaptor,
3084 ConversionPatternRewriter &rewriter,
3085 Location loc, Value sgpr1,
3086 ArrayRef<Value> consts)
const {
3087 if (!adaptor.getAtomicBarrierAddress())
3090 Value atomicBarrierAddress = adaptor.getAtomicBarrierAddress();
3091 auto barrierAddressTy =
3092 cast<MemRefType>(op.getAtomicBarrierAddress().getType());
3093 ValueRange atomicBarrierIndices = adaptor.getAtomicBarrierIndices();
3095 rewriter, loc, barrierAddressTy, atomicBarrierAddress,
3096 atomicBarrierIndices);
3097 IntegerType i32 = rewriter.getI32Type();
3103 atomicBarrierAddress =
3104 LLVM::PtrToIntOp::create(rewriter, loc, i32, atomicBarrierAddress);
3105 atomicBarrierAddress =
3106 LLVM::LShrOp::create(rewriter, loc, atomicBarrierAddress, consts[3]);
3108 atomicBarrierAddress =
3109 LLVM::AndOp::create(rewriter, loc, atomicBarrierAddress, mask);
3110 return setValueAtOffset(rewriter, loc, sgpr1, atomicBarrierAddress, 32);
3113 std::pair<Value, Value> setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3114 ConversionPatternRewriter &rewriter,
3115 Location loc, Value sgpr1, Value sgpr2,
3116 ArrayRef<Value> consts, uint64_t dimX,
3117 uint32_t offset)
const {
3118 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3119 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3120 SmallVector<OpFoldResult> mixedGlobalSizes =
3122 if (mixedGlobalSizes.size() <= dimX)
3123 return {sgpr1, sgpr2};
3125 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3132 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3136 IntegerType i32 = rewriter.getI32Type();
3137 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3138 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3141 sgpr1 = setValueAtOffset(rewriter, loc, sgpr1, tensorDimX, offset);
3144 Value tensorDimXHigh = LLVM::LShrOp::create(rewriter, loc, tensorDimX, c16);
3145 sgpr2 = setValueAtOffset(rewriter, loc, sgpr2, tensorDimXHigh, offset + 16);
3146 return {sgpr1, sgpr2};
3149 std::pair<Value, Value> setTensorDim0(DescriptorOp op, OpAdaptor adaptor,
3150 ConversionPatternRewriter &rewriter,
3151 Location loc, Value sgpr1, Value sgpr2,
3152 ArrayRef<Value> consts)
const {
3153 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, 0,
3157 std::pair<Value, Value> setTensorDim1(DescriptorOp op, OpAdaptor adaptor,
3158 ConversionPatternRewriter &rewriter,
3159 Location loc, Value sgpr2, Value sgpr3,
3160 ArrayRef<Value> consts)
const {
3161 return setTensorDimX(op, adaptor, rewriter, loc, sgpr2, sgpr3, consts, 1,
3165 Value setTileDimX(DescriptorOp op, OpAdaptor adaptor,
3166 ConversionPatternRewriter &rewriter, Location loc,
3167 Value sgpr, ArrayRef<Value> consts,
size_t dimX,
3168 int64_t offset)
const {
3169 ArrayRef<int64_t> sharedStaticSizes = adaptor.getSharedStaticSizes();
3170 ValueRange sharedDynamicSizes = adaptor.getSharedDynamicSizes();
3171 SmallVector<OpFoldResult> mixedSharedSizes =
3173 if (mixedSharedSizes.size() <= dimX)
3176 OpFoldResult tileDimXOpFoldResult = *(mixedSharedSizes.rbegin() + dimX);
3185 if (
auto attr = dyn_cast<Attribute>(tileDimXOpFoldResult)) {
3189 IntegerType i32 = rewriter.getI32Type();
3190 tileDimX = cast<Value>(tileDimXOpFoldResult);
3191 tileDimX = LLVM::TruncOp::create(rewriter, loc, i32, tileDimX);
3194 return setValueAtOffset(rewriter, loc, sgpr, tileDimX, offset);
3197 Value setTileDim0(DescriptorOp op, OpAdaptor adaptor,
3198 ConversionPatternRewriter &rewriter, Location loc,
3199 Value sgpr3, ArrayRef<Value> consts)
const {
3200 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, 0, 112);
3203 Value setTileDim1(DescriptorOp op, OpAdaptor adaptor,
3204 ConversionPatternRewriter &rewriter, Location loc,
3205 Value sgpr4, ArrayRef<Value> consts)
const {
3206 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 1, 128);
3209 Value setValidIndices(DescriptorOp op, OpAdaptor adaptor,
3210 ConversionPatternRewriter &rewriter, Location loc,
3211 Value sgpr4, ArrayRef<Value> consts)
const {
3212 auto type = cast<VectorType>(op.getIndices().getType());
3213 ArrayRef<int64_t> shape = type.getShape();
3214 assert(shape.size() == 1 &&
"expected shape to be of rank 1.");
3215 unsigned length = shape.back();
3216 assert(0 < length && length <= 16 &&
"expected length to be at most 16.");
3218 return setValueAtOffset(rewriter, loc, sgpr4, value, 128);
3221 Value setTileDim1OrValidIndices(DescriptorOp op, OpAdaptor adaptor,
3222 ConversionPatternRewriter &rewriter,
3223 Location loc, Value sgpr4,
3224 ArrayRef<Value> consts)
const {
3225 if constexpr (DescriptorOp::isGather())
3226 return setValidIndices(op, adaptor, rewriter, loc, sgpr4, consts);
3227 return setTileDim1(op, adaptor, rewriter, loc, sgpr4, consts);
3230 Value setTileDim2(DescriptorOp op, OpAdaptor adaptor,
3231 ConversionPatternRewriter &rewriter, Location loc,
3232 Value sgpr4, ArrayRef<Value> consts)
const {
3234 if constexpr (DescriptorOp::isGather())
3236 return setTileDimX(op, adaptor, rewriter, loc, sgpr4, consts, 2, 144);
3239 std::pair<Value, Value>
3240 setTensorDimXStride(DescriptorOp op, OpAdaptor adaptor,
3241 ConversionPatternRewriter &rewriter, Location loc,
3242 Value sgprY, Value sgprZ, ArrayRef<Value> consts,
3243 size_t dimX, int64_t offset)
const {
3244 ArrayRef<int64_t> globalStaticStrides = adaptor.getGlobalStaticStrides();
3245 ValueRange globalDynamicStrides = adaptor.getGlobalDynamicStrides();
3246 SmallVector<OpFoldResult> mixedGlobalStrides =
3247 getMixedValues(globalStaticStrides, globalDynamicStrides, rewriter);
3249 if (mixedGlobalStrides.size() <= (dimX + 1))
3250 return {sgprY, sgprZ};
3252 OpFoldResult tensorDimXStrideOpFoldResult =
3253 *(mixedGlobalStrides.rbegin() + dimX + 1);
3258 Value tensorDimXStride;
3259 if (
auto attr = dyn_cast<Attribute>(tensorDimXStrideOpFoldResult))
3263 tensorDimXStride = cast<Value>(tensorDimXStrideOpFoldResult);
3265 constexpr int64_t first48bits = (1ll << 48) - 1;
3268 LLVM::AndOp::create(rewriter, loc, mask, tensorDimXStride);
3269 IntegerType i32 = rewriter.getI32Type();
3270 Value tensorDimXStrideLow =
3271 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStride);
3272 sgprY = setValueAtOffset(rewriter, loc, sgprY, tensorDimXStrideLow, offset);
3274 int64_t shift = (offset % 32) == 0 ? 32 : offset % 32;
3276 Value tensorDimXStrideHigh =
3277 LLVM::LShrOp::create(rewriter, loc, tensorDimXStride, shiftVal);
3278 tensorDimXStrideHigh =
3279 LLVM::TruncOp::create(rewriter, loc, i32, tensorDimXStrideHigh);
3280 sgprZ = setValueAtOffset(rewriter, loc, sgprZ, tensorDimXStrideHigh,
3282 return {sgprY, sgprZ};
3285 std::pair<Value, Value>
3286 setTensorDim0Stride(DescriptorOp op, OpAdaptor adaptor,
3287 ConversionPatternRewriter &rewriter, Location loc,
3288 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3289 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3293 std::pair<Value, Value>
3294 setTensorDim1Stride(DescriptorOp op, OpAdaptor adaptor,
3295 ConversionPatternRewriter &rewriter, Location loc,
3296 Value sgpr5, Value sgpr6, ArrayRef<Value> consts)
const {
3298 if constexpr (DescriptorOp::isGather())
3299 return {sgpr5, sgpr6};
3300 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr5, sgpr6, consts,
3304 Value getDGroup1(DescriptorOp op, OpAdaptor adaptor,
3305 ConversionPatternRewriter &rewriter, Location loc,
3306 ArrayRef<Value> consts)
const {
3308 for (int64_t i = 0; i < 8; ++i) {
3309 sgprs[i] = consts[0];
3312 sgprs[0] = setWorkgroupMask(op, adaptor, rewriter, loc, sgprs[0]);
3313 sgprs[0] = setDataSize(op, adaptor, rewriter, loc, sgprs[0], consts);
3314 sgprs[0] = setAtomicBarrier(op, adaptor, rewriter, loc, sgprs[0], consts);
3315 sgprs[0] = setIterateEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3316 sgprs[0] = setPadEnable(op, adaptor, rewriter, loc, sgprs[0], consts);
3317 sgprs[0] = setEarlyTimeout(op, adaptor, rewriter, loc, sgprs[0], consts);
3318 sgprs[0] = setPadInterval(op, adaptor, rewriter, loc, sgprs[0], consts);
3319 sgprs[0] = setPadAmount(op, adaptor, rewriter, loc, sgprs[0], consts);
3322 setAtomicBarrierAddress(op, adaptor, rewriter, loc, sgprs[1], consts);
3323 std::tie(sgprs[1], sgprs[2]) =
3324 setTensorDim0(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3325 std::tie(sgprs[2], sgprs[3]) =
3326 setTensorDim1(op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3328 sgprs[3] = setTileDim0(op, adaptor, rewriter, loc, sgprs[3], consts);
3330 setTileDim1OrValidIndices(op, adaptor, rewriter, loc, sgprs[4], consts);
3331 sgprs[4] = setTileDim2(op, adaptor, rewriter, loc, sgprs[4], consts);
3332 std::tie(sgprs[5], sgprs[6]) = setTensorDim0Stride(
3333 op, adaptor, rewriter, loc, sgprs[5], sgprs[6], consts);
3334 std::tie(sgprs[6], sgprs[7]) = setTensorDim1Stride(
3335 op, adaptor, rewriter, loc, sgprs[6], sgprs[7], consts);
3337 IntegerType i32 = rewriter.getI32Type();
3338 Type v8i32 = this->typeConverter->convertType(VectorType::get(8, i32));
3339 assert(v8i32 &&
"expected type conversion to succeed");
3340 Value dgroup1 = LLVM::PoisonOp::create(rewriter, loc, v8i32);
3342 for (
auto [sgpr, constant] : llvm::zip_equal(sgprs, consts)) {
3344 LLVM::InsertElementOp::create(rewriter, loc, dgroup1, sgpr, constant);
3350 Value setTensorDimX(DescriptorOp op, OpAdaptor adaptor,
3351 ConversionPatternRewriter &rewriter, Location loc,
3352 Value sgpr0, ArrayRef<Value> consts, int64_t dimX,
3353 int64_t offset)
const {
3354 ArrayRef<int64_t> globalStaticSizes = adaptor.getGlobalStaticSizes();
3355 ValueRange globalDynamicSizes = adaptor.getGlobalDynamicSizes();
3356 SmallVector<OpFoldResult> mixedGlobalSizes =
3358 if (mixedGlobalSizes.size() <=
static_cast<unsigned long>(dimX))
3361 OpFoldResult tensorDimXOpFoldResult = *(mixedGlobalSizes.rbegin() + dimX);
3363 if (
auto attr = dyn_cast<Attribute>(tensorDimXOpFoldResult)) {
3367 IntegerType i32 = rewriter.getI32Type();
3368 tensorDimX = cast<Value>(tensorDimXOpFoldResult);
3369 tensorDimX = LLVM::TruncOp::create(rewriter, loc, i32, tensorDimX);
3372 return setValueAtOffset(rewriter, loc, sgpr0, tensorDimX, offset);
3375 Value setTensorDim2(DescriptorOp op, OpAdaptor adaptor,
3376 ConversionPatternRewriter &rewriter, Location loc,
3377 Value sgpr0, ArrayRef<Value> consts)
const {
3378 return setTensorDimX(op, adaptor, rewriter, loc, sgpr0, consts, 2, 0);
3381 Value truncateAndSetValueAtOffset(ConversionPatternRewriter &rewriter,
3382 Location loc, Value accumulator,
3383 Value value, int64_t shift)
const {
3385 IntegerType i32 = rewriter.getI32Type();
3386 value = LLVM::TruncOp::create(rewriter, loc, i32, value);
3387 return setValueAtOffset(rewriter, loc, accumulator, value, shift);
3390 Value setLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3391 ConversionPatternRewriter &rewriter, Location loc,
3392 Value sgpr1, ArrayRef<Value> consts,
3393 int64_t offset)
const {
3394 Value ldsAddrIncrement = adaptor.getLdsIncrement();
3395 return setValueAtOffset(rewriter, loc, sgpr1, ldsAddrIncrement, offset);
3398 std::pair<Value, Value>
3399 setGlobalAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3400 ConversionPatternRewriter &rewriter, Location loc,
3401 Value sgpr2, Value sgpr3, ArrayRef<Value> consts,
3402 int64_t offset)
const {
3403 Value globalAddrIncrement = adaptor.getGlobalIncrement();
3404 sgpr2 = truncateAndSetValueAtOffset(rewriter, loc, sgpr2,
3405 globalAddrIncrement, offset);
3407 globalAddrIncrement =
3408 LLVM::LShrOp::create(rewriter, loc, globalAddrIncrement, shift);
3409 constexpr int64_t first16BitsHigh = (1ll << 16) - 1;
3410 sgpr3 = truncateAndSetValueAtOffset(rewriter, loc, sgpr3,
3411 globalAddrIncrement, offset + 32);
3413 sgpr3 = LLVM::AndOp::create(rewriter, loc, sgpr3, mask);
3414 return {sgpr2, sgpr3};
3417 Value setTensorDim3OrLDSAddrIncrement(DescriptorOp op, OpAdaptor adaptor,
3418 ConversionPatternRewriter &rewriter,
3419 Location loc, Value sgpr1,
3420 ArrayRef<Value> consts)
const {
3421 Value ldsIncrement = op.getLdsIncrement();
3422 constexpr int64_t dim = 3;
3423 constexpr int64_t offset = 32;
3425 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, consts, dim,
3427 return setLDSAddrIncrement(op, adaptor, rewriter, loc, sgpr1, consts,
3431 std::pair<Value, Value> setTensorDim2StrideOrGlobalAddrIncrement(
3432 DescriptorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter,
3433 Location loc, Value sgpr2, Value sgpr3, ArrayRef<Value> consts)
const {
3434 Value globalIncrement = op.getGlobalIncrement();
3435 constexpr int32_t dim = 2;
3436 constexpr int32_t offset = 64;
3437 if (!globalIncrement)
3438 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr2, sgpr3,
3439 consts, dim, offset);
3440 return setGlobalAddrIncrement(op, adaptor, rewriter, loc, sgpr2, sgpr3,
// Encodes the iteration count (as count - 1, via the SubOp below) into sgpr3
// at `offset`.
// NOTE(review): original lines 3450-3455 and 3457 are missing from this
// extraction — they likely handled the iteration-count bit width before the
// truncation, and the assignment target of the SubOp; confirm upstream.
3444 Value setIterateCount(DescriptorOp op, OpAdaptor adaptor,
3445 ConversionPatternRewriter &rewriter, Location loc,
3446 Value sgpr3, ArrayRef<Value> consts,
3447 int32_t offset)
const {
3448 Value iterationCount = adaptor.getIterationCount();
3449 IntegerType i32 = rewriter.getI32Type();
// Narrow to i32 before encoding.
3456 iterationCount = LLVM::TruncOp::create(rewriter, loc, i32, iterationCount);
// Subtract 1 (consts[1] is the constant 1): hardware expects count - 1.
3458 LLVM::SubOp::create(rewriter, loc, iterationCount, consts[1]);
3459 return setValueAtOffset(rewriter, loc, sgpr3, iterationCount, offset);
// The sgpr3 field at bit 112 is overloaded: it holds either tile dim 3 or
// the iterate count, depending on whether the op carries an iteration count.
// NOTE(review): the branch condition line (original 3469) and the
// continuation of the setTileDimX call (3471-3472) are missing from this
// extraction — presumably `if (!iterateCount)`; confirm upstream.
3462 Value setTileDim3OrIterateCount(DescriptorOp op, OpAdaptor adaptor,
3463 ConversionPatternRewriter &rewriter,
3464 Location loc, Value sgpr3,
3465 ArrayRef<Value> consts)
const {
3466 Value iterateCount = op.getIterationCount();
3467 constexpr int32_t dim = 2;
3468 constexpr int32_t offset = 112;
// Branch 1: encode tile dimension (call continuation line missing).
3470 return setTileDimX(op, adaptor, rewriter, loc, sgpr3, consts, dim,
// Branch 2: encode the iterate count.
3473 return setIterateCount(op, adaptor, rewriter, loc, sgpr3, consts, offset);
3476 Value getDGroup2(DescriptorOp op, OpAdaptor adaptor,
3477 ConversionPatternRewriter &rewriter, Location loc,
3478 ArrayRef<Value> consts)
const {
3479 if constexpr (DescriptorOp::isGather())
3480 return getDGroup2Gather(op, adaptor, rewriter, loc, consts);
3481 return getDGroup2NonGather(op, adaptor, rewriter, loc, consts);
// Builds descriptor group 2 (a v4i32) for the non-gather case by filling
// four scalar sgpr slots and inserting them into a vector.
// NOTE(review): several original lines are missing from this extraction
// (e.g. 3502 `sgprs[1], consts);`, 3510 assignment of the insert result,
// and the final `return dgroup2;`) — confirm upstream.
3484 Value getDGroup2NonGather(DescriptorOp op, OpAdaptor adaptor,
3485 ConversionPatternRewriter &rewriter, Location loc,
3486 ArrayRef<Value> consts)
const {
3487 IntegerType i32 = rewriter.getI32Type();
3488 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3489 assert(v4i32 &&
"expected type conversion to succeed.");
// Rank <= 2 with no LDS increment needs only two descriptor groups, so
// group 2 can be all-zero.
3491 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3492 if (onlyNeedsTwoDescriptors)
3493 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
// Start each sgpr slot at zero (consts[0] is the constant 0).
3495 constexpr int64_t sgprlen = 4;
3496 Value sgprs[sgprlen];
3497 for (
int i = 0; i < sgprlen; ++i)
3498 sgprs[i] = consts[0];
// Fill the overloaded fields slot by slot.
3500 sgprs[0] = setTensorDim2(op, adaptor, rewriter, loc, sgprs[0], consts);
3501 sgprs[1] = setTensorDim3OrLDSAddrIncrement(op, adaptor, rewriter, loc,
3503 std::tie(sgprs[2], sgprs[3]) = setTensorDim2StrideOrGlobalAddrIncrement(
3504 op, adaptor, rewriter, loc, sgprs[2], sgprs[3], consts);
3506 setTileDim3OrIterateCount(op, adaptor, rewriter, loc, sgprs[3], consts);
// Assemble the v4i32 by inserting each sgpr at its lane index (consts[i]).
3508 Value dgroup2 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3509 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
3511 LLVM::InsertElementOp::create(rewriter, loc, dgroup2, sgpr, constant);
// Packs one half of the gather indices into a v4i32 descriptor group.
// i32 indices occupy one lane each (4 per group); narrower indices are
// zero-extended and packed two per lane (8 per group, low/high 16 bits).
// `firstHalf` selects which half of the index vector this group covers.
// NOTE(review): several original lines are missing from this extraction
// (loop-guard continuation at 3538-3540 including the `idx` definition,
// closing braces, and the final `return dgroup;`) — confirm upstream.
3516 Value getGatherIndices(DescriptorOp op, OpAdaptor adaptor,
3517 ConversionPatternRewriter &rewriter, Location loc,
3518 ArrayRef<Value> consts,
bool firstHalf)
const {
3519 IntegerType i32 = rewriter.getI32Type();
3520 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3521 assert(v4i32 &&
"expected type conversion to succeed.");
3523 Value
indices = adaptor.getIndices();
3524 auto vectorType = cast<VectorType>(
indices.getType());
3525 unsigned length = vectorType.getShape().back();
3526 Type elementType = vectorType.getElementType();
// 4 lanes of i32, or 8 packed sub-i32 indices, fit in one group.
3527 unsigned maxLength = elementType == i32 ? 4 : 8;
3528 int32_t offset = firstHalf ? 0 : maxLength;
// Clamp so the second half never reads past the end of the index vector.
3529 unsigned discountedLength =
3530 std::max(
static_cast<int32_t
>(length - offset), 0);
3532 unsigned targetSize = std::min(maxLength, discountedLength);
// Extract the individual index elements for this half.
3534 SmallVector<Value> indicesVector;
3535 for (
unsigned i = offset; i < targetSize + offset; ++i) {
3537 if (i < consts.size())
3541 Value elem = LLVM::ExtractElementOp::create(rewriter, loc,
indices, idx);
3542 indicesVector.push_back(elem);
// Zero-extend sub-i32 indices to i32; pad to an even count so pairs pack.
3545 SmallVector<Value> indicesI32Vector;
3546 if (elementType == i32) {
3547 indicesI32Vector = indicesVector;
3549 for (
unsigned i = 0; i < targetSize; ++i) {
3550 Value index = indicesVector[i];
3551 indicesI32Vector.push_back(
3552 LLVM::ZExtOp::create(rewriter, loc, i32, index));
3554 if ((targetSize % 2) != 0)
3556 indicesI32Vector.push_back(consts[0]);
// Pack pairs of extended indices into single i32 lanes (second index in
// the high 16 bits) unless the element type was already i32.
3559 SmallVector<Value> indicesToInsert;
3560 if (elementType == i32) {
3561 indicesToInsert = indicesI32Vector;
3563 unsigned size = indicesI32Vector.size() / 2;
3564 for (
unsigned i = 0; i < size; ++i) {
3565 Value first = indicesI32Vector[2 * i];
3566 Value second = indicesI32Vector[2 * i + 1];
3567 Value joined = setValueAtOffset(rewriter, loc, first, second, 16);
3568 indicesToInsert.push_back(joined);
// Insert each packed lane into the v4i32 at lane index consts[i].
3572 Value dgroup = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3573 for (
auto [sgpr, constant] : llvm::zip_first(indicesToInsert, consts))
3575 LLVM::InsertElementOp::create(rewriter, loc, dgroup, sgpr, constant);
3580 Value getDGroup2Gather(DescriptorOp op, OpAdaptor adaptor,
3581 ConversionPatternRewriter &rewriter, Location loc,
3582 ArrayRef<Value> consts)
const {
3583 return getGatherIndices(op, adaptor, rewriter, loc, consts,
true);
3586 std::pair<Value, Value>
3587 setTensorDim3Stride(DescriptorOp op, OpAdaptor adaptor,
3588 ConversionPatternRewriter &rewriter, Location loc,
3589 Value sgpr0, Value sgpr1, ArrayRef<Value> consts)
const {
3590 constexpr int32_t dim = 3;
3591 constexpr int32_t offset = 0;
3592 return setTensorDimXStride(op, adaptor, rewriter, loc, sgpr0, sgpr1, consts,
3596 std::pair<Value, Value> setTensorDim4(DescriptorOp op, OpAdaptor adaptor,
3597 ConversionPatternRewriter &rewriter,
3598 Location loc, Value sgpr1, Value sgpr2,
3599 ArrayRef<Value> consts)
const {
3600 constexpr int32_t dim = 4;
3601 constexpr int32_t offset = 48;
3602 return setTensorDimX(op, adaptor, rewriter, loc, sgpr1, sgpr2, consts, dim,
3606 Value setTileDim4(DescriptorOp op, OpAdaptor adaptor,
3607 ConversionPatternRewriter &rewriter, Location loc,
3608 Value sgpr2, ArrayRef<Value> consts)
const {
3609 constexpr int32_t dim = 4;
3610 constexpr int32_t offset = 80;
3611 return setTileDimX(op, adaptor, rewriter, loc, sgpr2, consts, dim, offset);
3614 Value getDGroup3(DescriptorOp op, OpAdaptor adaptor,
3615 ConversionPatternRewriter &rewriter, Location loc,
3616 ArrayRef<Value> consts)
const {
3617 if constexpr (DescriptorOp::isGather())
3618 return getDGroup3Gather(op, adaptor, rewriter, loc, consts);
3619 return getDGroup3NonGather(op, adaptor, rewriter, loc, consts);
// Builds descriptor group 3 (a v4i32) for the non-gather case: dim-3 stride,
// tensor dim 4, and tile dim 4 packed into four sgpr slots.
// NOTE(review): several original lines are missing from this extraction
// (e.g. the assignment of the insert result at 3645 and the final
// `return dgroup3;`) — confirm upstream.
3622 Value getDGroup3NonGather(DescriptorOp op, OpAdaptor adaptor,
3623 ConversionPatternRewriter &rewriter, Location loc,
3624 ArrayRef<Value> consts)
const {
3625 IntegerType i32 = rewriter.getI32Type();
3626 Type v4i32 = this->typeConverter->convertType(VectorType::get(4, i32));
3627 assert(v4i32 &&
"expected type conversion to succeed.");
// Rank <= 2 with no LDS increment needs only two descriptor groups.
3628 bool onlyNeedsTwoDescriptors = !op.getLdsIncrement() && op.getRank() <= 2;
3629 if (onlyNeedsTwoDescriptors)
3630 return LLVM::ZeroOp::create(rewriter, loc, v4i32);
// Start each sgpr slot at zero (consts[0] is the constant 0).
3632 constexpr int32_t sgprlen = 4;
3633 Value sgprs[sgprlen];
3634 for (
int i = 0; i < sgprlen; ++i)
3635 sgprs[i] = consts[0];
// Fields span sgpr boundaries, hence the pairwise std::tie updates.
3637 std::tie(sgprs[0], sgprs[1]) = setTensorDim3Stride(
3638 op, adaptor, rewriter, loc, sgprs[0], sgprs[1], consts);
3639 std::tie(sgprs[1], sgprs[2]) =
3640 setTensorDim4(op, adaptor, rewriter, loc, sgprs[1], sgprs[2], consts);
3641 sgprs[2] = setTileDim4(op, adaptor, rewriter, loc, sgprs[2], consts);
// Assemble the v4i32 by inserting each sgpr at its lane index (consts[i]).
3643 Value dgroup3 = LLVM::PoisonOp::create(rewriter, loc, v4i32);
3644 for (
auto [sgpr, constant] : llvm::zip(sgprs, consts))
3646 LLVM::InsertElementOp::create(rewriter, loc, dgroup3, sgpr, constant);
3651 Value getDGroup3Gather(DescriptorOp op, OpAdaptor adaptor,
3652 ConversionPatternRewriter &rewriter, Location loc,
3653 ArrayRef<Value> consts)
const {
3654 return getGatherIndices(op, adaptor, rewriter, loc, consts,
false);
// Lowers a make_dma_descriptor-style op into its four v4i32/v8i32 descriptor
// groups and replaces the op with them (1:N replacement).
// NOTE(review): this extraction is missing the chipset guard condition
// (original 3660), the loop body pushing constants 0..7 into `consts`
// (3668-3669), and the final `return success();` — confirm upstream.
3658 matchAndRewrite(DescriptorOp op, OpAdaptor adaptor,
3659 ConversionPatternRewriter &rewriter)
const override {
// Emit an error on unsupported chipsets (guard condition line missing).
3661 return op->emitOpError(
3662 "make_dma_descriptor is only supported on gfx1250");
3664 Location loc = op.getLoc();
// consts[i] holds the i32 constant i; used as lane indices and small values.
3666 SmallVector<Value> consts;
3667 for (int64_t i = 0; i < 8; ++i)
// Build the four descriptor groups and replace the op with them.
3670 Value dgroup0 = this->getDGroup0(adaptor);
3671 Value dgroup1 = this->getDGroup1(op, adaptor, rewriter, loc, consts);
3672 Value dgroup2 = this->getDGroup2(op, adaptor, rewriter, loc, consts);
3673 Value dgroup3 = this->getDGroup3(op, adaptor, rewriter, loc, consts);
3674 SmallVector<Value> results = {dgroup0, dgroup1, dgroup2, dgroup3};
3675 rewriter.replaceOpWithMultiple(op, {results});
// Lowers amdgpu tensor load/store ops (SourceOp) to their ROCDL equivalents
// (TargetOp), forwarding the already-converted descriptor parts.
// NOTE(review): this struct is heavily truncated by the extraction — the
// chipset member declaration, the matchAndRewrite return type, the chipset
// guard condition, and the tail of the replaceOpWithNewOp call are missing;
// confirm against the upstream file.
3680template <
typename SourceOp,
typename TargetOp>
3681struct AMDGPUTensorLoadStoreOpLowering
3682 :
public ConvertOpToLLVMPattern<SourceOp> {
3683 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
// Constructor stores the target chipset alongside the converter.
3685 AMDGPUTensorLoadStoreOpLowering(
const LLVMTypeConverter &converter,
3687 : ConvertOpToLLVMPattern<SourceOp>(converter), chipset(chipset) {}
3691 matchAndRewrite(SourceOp op, Adaptor adaptor,
3692 ConversionPatternRewriter &rewriter)
const override {
// Guard: only gfx1250 supports these ops (condition line missing).
3694 return op->emitOpError(
"is only supported on gfx1250");
// Replace with the ROCDL op, passing the descriptor's converted pieces.
3697 rewriter.replaceOpWithNewOp<TargetOp>(op, desc[0], desc[1], desc[2],
// Pass driver: parses the chipset option, sets up the LLVM type converter
// and conversion target, and applies the AMDGPU->ROCDL patterns.
// NOTE(review): truncated by the extraction — the MLIRContext/maybeChipset
// setup, the pattern-set construction, and the populate* calls between the
// visible lines are missing; confirm against the upstream file.
3706struct ConvertAMDGPUToROCDLPass
3707 :
public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
3710 void runOnOperation()
override {
// Fail the pass early on an unparseable chipset string.
3713 if (
failed(maybeChipset)) {
3714 emitError(UnknownLoc::get(ctx),
"Invalid chipset name: " + chipset);
3715 return signalPassFailure();
3719 LLVMTypeConverter converter(ctx);
3722 amdgpu::populateCommonGPUTypeAndAttributeConversions(converter);
// AMDGPU ops must all be converted; LLVM/ROCDL are the legal targets.
3724 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
3725 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
3726 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
3727 if (
failed(applyPartialConversion(getOperation(),
target,
3729 signalPassFailure();
// Type/attribute conversion registration: maps GPU and AMDGPU memory spaces
// to LLVM address-space integers and converts the AMDGPU TDM types to LLVM
// vector types.
// NOTE(review): the enclosing function header and several lines were dropped
// by the extraction (switch openers, lambda headers, closing braces, early
// returns in addUnrealizedCast) — confirm against the upstream file.
3737 typeConverter, [](gpu::AddressSpace space) {
// gpu.address_space -> ROCDL numeric address spaces.
3739 case gpu::AddressSpace::Global:
3740 return ROCDL::ROCDLDialect::kGlobalMemoryAddressSpace;
3741 case gpu::AddressSpace::Workgroup:
3742 return ROCDL::ROCDLDialect::kSharedMemoryAddressSpace;
3743 case gpu::AddressSpace::Private:
3744 return ROCDL::ROCDLDialect::kPrivateMemoryAddressSpace;
3746 llvm_unreachable(
"unknown address space enum value");
// amdgpu.address_space attribute -> i64 address-space constants 7/8/9.
3752 typeConverter.addTypeAttributeConversion(
3754 -> TypeConverter::AttributeConversionResult {
3756 Type i64 = IntegerType::get(ctx, 64);
3757 switch (as.getValue()) {
3758 case amdgpu::AddressSpace::FatRawBuffer:
3759 return IntegerAttr::get(i64, 7);
3760 case amdgpu::AddressSpace::BufferRsrc:
3761 return IntegerAttr::get(i64, 8);
3762 case amdgpu::AddressSpace::FatStructuredBuffer:
3763 return IntegerAttr::get(i64, 9);
3765 return TypeConverter::AttributeConversionResult::abort();
// Barrier state lowers to a plain i64.
3767 typeConverter.addConversion([&](DsBarrierStateType type) ->
Type {
3768 return IntegerType::get(type.
getContext(), 64);
// TDM base types lower to v4i32.
3770 typeConverter.addConversion([&](TDMBaseType type) ->
Type {
3772 return typeConverter.convertType(VectorType::get(4, i32));
3774 typeConverter.addConversion([&](TDMGatherBaseType type) ->
Type {
3776 return typeConverter.convertType(VectorType::get(4, i32));
// A TDM descriptor decomposes 1:N into {v4i32, v8i32, v4i32, v4i32}.
3778 typeConverter.addConversion(
3779 [&](TDMDescriptorType type,
3782 Type v4i32 = typeConverter.convertType(VectorType::get(4, i32));
3783 Type v8i32 = typeConverter.convertType(VectorType::get(8, i32));
3784 llvm::append_values(
result, v4i32, v8i32, v4i32, v4i32);
// Target materialization: wrap multi-result descriptor conversions in an
// unrealized_conversion_cast (guards' early-return lines are missing).
3794 if (inputs.size() != 1)
3797 if (!isa<TDMDescriptorType>(inputs[0].
getType()))
3800 auto cast = UnrealizedConversionCastOp::create(builder, loc, types, inputs);
3801 return cast.getResults();
3804 typeConverter.addTargetMaterialization(addUnrealizedCast);
// Pattern registration for the AMDGPU->ROCDL conversion. Chipset-dependent
// patterns take (converter, chipset); chipset-independent ones take only the
// converter.
// NOTE(review): the `patterns` receiver line preceding `.add<` and the
// trailing `chipset)` argument were dropped by the extraction — confirm
// against the upstream file.
3812 .add<FatRawBufferCastLowering,
3813 RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
3814 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
3815 RawBufferOpLowering<RawBufferAtomicFaddOp,
3816 ROCDL::RawPtrBufferAtomicFaddOp>,
3817 RawBufferOpLowering<RawBufferAtomicFmaxOp,
3818 ROCDL::RawPtrBufferAtomicFmaxOp>,
3819 RawBufferOpLowering<RawBufferAtomicSmaxOp,
3820 ROCDL::RawPtrBufferAtomicSmaxOp>,
3821 RawBufferOpLowering<RawBufferAtomicUminOp,
3822 ROCDL::RawPtrBufferAtomicUminOp>,
3823 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
3824 ROCDL::RawPtrBufferAtomicCmpSwap>,
3825 AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
3826 SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
3827 SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering,
3828 ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
3829 ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
3830 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
3831 GatherToLDSOpLowering, TransposeLoadOpLowering,
3832 AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering<MakeDmaBaseOp>,
3833 AMDGPUMakeDmaBaseLowering<MakeGatherDmaBaseOp>,
3834 AMDGPULowerDescriptor<MakeDmaDescriptorOp>,
3835 AMDGPULowerDescriptor<MakeGatherDmaDescriptorOp>,
3836 AMDGPUTensorLoadStoreOpLowering<TensorLoadToLDSOp,
3837 ROCDL::TensorLoadToLDSOp>,
3838 AMDGPUTensorLoadStoreOpLowering<TensorStoreFromLDSOp,
3839 ROCDL::TensorStoreFromLDSOp>,
3840 DsBarrierInitOpLowering, DsBarrierPollStateOpLowering,
3841 DsAsyncBarrierArriveOpLowering, DsBarrierArriveOpLowering>(converter,
// These patterns do not depend on the chipset.
3843 patterns.add<AMDGPUSwizzleBitModeLowering, DsBarrierStatePhaseOpLowering,
3844 DsBarrierStatePendingCountOpLowering,
3845 DsBarrierStateInitCountOpLowering,
3846 DsBarrierStatePhaseParityLowering>(converter);
static bool typeIsExpectedFp8ForChipset(Chipset chipset, Type type)
Return true if type is the E4M3FN variant of an 8-bit float that is supported by the _fp8 instruction...
constexpr Chipset kGfx942
static std::optional< StringRef > wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k, bool isRDNA3)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma for RDNA3/4 architectures.
static std::optional< std::tuple< StringRef, uint32_t, uint32_t > > mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m, uint32_t n, uint32_t k, uint32_t b, Chipset chipset)
If there is a scaled MFMA instruction for the input element types aType and bType,...
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
constexpr Chipset kGfx908
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs, StringRef attrName)
Push an input operand.
constexpr Chipset kGfx1250
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts the scaled MFMA/WMMA operands, scalesA and scalesB, from MLIR AMDGPU dialect convention to R...
constexpr Chipset kGfx90a
static std::optional< StringRef > getScaledWmmaIntrinsicName(int64_t m, int64_t n, int64_t k, bool isScale16)
Determines the ROCDL intrinsic name for scaled WMMA based on dimensions and scale block size (16 or 3...
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVectorImpl< Value > &operands, SmallVectorImpl< NamedAttribute > &attrs)
Push the output operand.
static bool typeIsExpectedBf8ForChipset(Chipset chipset, Type type)
Return true if type is the E5M2 variant of an 8-bit float that is supported by the _bf8 instructions ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Returns the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static std::optional< StringRef > smfmacOpToIntrinsic(SparseMFMAOp op, Chipset chipset)
Returns the rocdl intrinsic corresponding to a SparseMFMA (smfmac) operation if one exists.
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value basePointer, Value numRecords, bool boundsCheck, amdgpu::Chipset chipset, Value cacheSwizzleStride=nullptr, unsigned addressSpace=8)
static Value createI64Constant(ConversionPatternRewriter &rewriter, Location loc, int64_t value)
static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Converts sparse MFMA (smfmac) operands to the expected ROCDL types.
static std::optional< StringRef > wmmaOpToIntrinsicGfx1250(Type elemSourceType, Type elemBSourceType, Type elemDestType, uint32_t k)
Return the rocdl intrinsic corresponding to a WMMA operation wmma for the gfx1250 architecture.
static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, MemRefType memrefType, MemRefDescriptor &memrefDescriptor, ArrayRef< int64_t > strides, int64_t elementByteWidth, amdgpu::Chipset chipset, bool boundsCheck)
Compute the contents of the num_records field for a given memref descriptor - that is,...
static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input, bool allowBf16=true)
Pack small float vector operands (fp4/fp6/fp8/bf16) into the format expected by scaled matrix multipl...
static std::optional< uint32_t > getWmmaScaleFormat(Type elemType)
Maps f8 scale element types to WMMA scale format codes.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, Location loc, MemRefDescriptor &memRefDescriptor, ValueRange indices, ArrayRef< int64_t > strides)
Returns the linear index used to access an element in the memref.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i32.
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static std::optional< uint32_t > smallFloatTypeToFormatCode(Type mlirElemType)
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter, Location loc, Value val)
Convert an unsigned number val to i64.
constexpr Chipset kGfx950
If copies could not be generated due to yet-unimplemented cases, copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be placed; the output argument nBegin is set to its replacement (set to `begin` if no invalidation happens), since outgoing copies could have been inserted at `end`.
static constexpr unsigned kSizePosInMemRefDescriptor
static constexpr unsigned kStridePosInMemRefDescriptor
static constexpr unsigned kOffsetPosInMemRefDescriptor
static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
This class provides a shared interface for ranked and unranked memref types.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
typename SourceOp::template GenericAdaptor< ArrayRef< ValueRange > > OneToNOpAdaptor
typename SourceOp::Adaptor OpAdaptor
Value getStridedElementPtr(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none) const
Convenience wrapper for the corresponding helper utility.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
Value stride(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
Value size(OpBuilder &builder, Location loc, unsigned pos)
Builds IR extracting the pos-th size from the descriptor.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value getStridedElementPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type, Value memRefDesc, ValueRange indices, LLVM::GEPNoWrapFlags noWrapFlags=LLVM::GEPNoWrapFlags::none)
Performs the index computation to get to the element at indices of the memory pointed to by memRefDes...
bool hasOcpFp8(const Chipset &chipset)
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: This function will also add conversions for the AMDGPU-specific address spaces and types,...
llvm::TypeSwitch< T, ResultT > TypeSwitch
void populateAMDGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap AMDGPU memory spaces to LLVM address spaces by mapping amdgpu::AddressSpace::fat_raw_buffer to ...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.