#include "llvm/ADT/STLExtras.h"

#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
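// Pulls in the tablegen-generated pass base class
// (impl::ConvertAMDGPUToROCDLBase) that ConvertAMDGPUToROCDLPass derives from
// further down in this file.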
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);

  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
template <typename GpuOp, typename Intrinsic>

  static constexpr uint32_t maxVectorOpWidth = 128;
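  // 128 bits is the widest per-lane access the raw buffer intrinsics take
  // (a dwordx4); matchAndRewrite below rejects anything wider.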
  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();

      wantedDataType = storeData.getType();

      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value maybeCmpData = adaptor.getODSOperands(1)[0];
    if (maybeCmpData != memref)
      atomicCmpData = maybeCmpData;

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
    Type llvmI32 = this->typeConverter->convertType(i32);
    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());

      if (val.getType() == llvmI32)
        return val;
      return rewriter.create<LLVM::TruncOp>(loc, llvmI32, val);

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Type llvmBufferValType = llvmWantedDataType;

      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));

    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
    if (llvmBufferValType != llvmWantedDataType) {
      Value castForStore =
          rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
      args.push_back(castForStore);
    } else {
      args.push_back(storeData);
    }

    if (llvmBufferValType != llvmWantedDataType) {
      Value castForCmp = rewriter.create<LLVM::BitcastOp>(
          loc, llvmBufferValType, atomicCmpData);
      args.push_back(castForCmp);
    } else {
      args.push_back(atomicCmpData);
    }

      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");
    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);

    if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = toI32(memrefDescriptor.size(rewriter, loc, i));
        Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i));
        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
                                                               maxThisDim)
                            : maxThisDim;
      }
      numRecords = maxIndex;
    }
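    // The constants below build word 3 of the buffer resource descriptor. My
    // reading of the bit patterns (worth double-checking against the ISA
    // docs): bits 12-14 and 15-18 carry a number/data format (7 and 4, a
    // 32-bit float pattern that the raw.ptr intrinsics ignore), and the
    // out-of-bounds-check field OR'd in below is chosen from
    // adaptor.getBoundsCheck().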
    uint32_t flags = (7 << 12) | (4 << 15);

      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);

    Value resource = rewriter.create<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);
      size_t i = pair.index();
      Value index = pair.value();

      if (ShapedType::isDynamic(strides[i])) {
        strideOp = rewriter.create<LLVM::MulOp>(
            loc, toI32(memrefDescriptor.stride(rewriter, loc, i)),
            byteWidthConst);
      }

      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);

    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();

      if (ShapedType::isDynamic(offset))
        sgprOffset = rewriter.create<LLVM::AddOp>(
            loc, toI32(memrefDescriptor.offset(rewriter, loc)), sgprOffset);

      sgprOffset = rewriter.create<LLVM::AddOp>(/*...*/);

    args.push_back(sgprOffset);
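    // At this point `args` holds the operands in the order they are handed to
    // the ROCDL raw.ptr.buffer intrinsic: store/atomic data first when
    // present, then the resource descriptor, voffset, and sgprOffset. The
    // intrinsic is created below and its result bitcast back whenever the
    // buffer value type differs from the requested type.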
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());

    if (llvmBufferValType != llvmWantedDataType) {
      replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                     replacement);
    }
  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";

          asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, asmDialectAttr,

      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

        ldsOnlyBits = ldsOnlyBitsGfx11;
        ldsOnlyBits = ldsOnlyBitsGfx10;
        ldsOnlyBits = ldsOnlyBitsGfx6789;

        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);

      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
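      // Recap of the branches above: chips before gfx90a and the gfx11 family
      // go through inline assembly (s_waitcnt lgkmcnt(0); s_barrier), the
      // gfx9/gfx10 path uses rocdl.waitcnt with an LDS-only mask, and the
      // newest chips emit the split wait-dscnt / barrier-signal sequence shown
      // last.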
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
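// convertMFMAVectorOperand (body below): the ROCDL MFMA intrinsics have no
// bf16 or small-integer vector operands, so bf16 vectors are reinterpreted as
// i16 vectors and vectors of 8-bit integers are bitcast to a single integer of
// the same total width.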
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8)) {
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);

    operands.push_back(llvmInput);

  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);

    bool localIsUnsigned = isUnsigned;
      localIsUnsigned = true;
      localIsUnsigned = false;
    operands.push_back(sign);

  int64_t numBytes = vectorType.getNumElements();
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
  Value result = rewriter.create<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);
  operands.push_back(result);
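  // As visible above, sub-32-bit inputs end up packed: for 8-bit integer
  // vectors a signedness flag is pushed ahead of the data, and the operand is
  // bitcast to a vector of 32-bit words before being appended to the intrinsic
  // operand list.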
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);

  operands.push_back(output);
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

    if (mfma.getReducePrecision() && chipset >= kGfx940) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();

    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();

    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();

  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();

    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();

  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();

    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
  if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx940)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");

    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(/*...*/);

    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
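    // Recap: the intrinsic name is chosen by mfmaOpToIntrinsic above; bf16
    // results are produced as i16 vectors and bitcast back to the requested
    // type here, and on gfx940+ the negateA/B/C bits are folded into the blgp
    // immediate.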
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());

      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    loweredOp.addTypes(rawOutType);

    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);

    if (rawOutType != outType)
struct ExtPackedFp8OpLowering final

  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedTrunc2xFp8OpLowering final

  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

struct PackedStochRoundFp8OpLowering final

  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());

  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(/*...*/);
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {

        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);

  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);

    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, op.getWordIndex());

    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, op.getWordIndex());

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);

    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, op.getStoreIndex());

    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, op.getStoreIndex());

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();

    Type llvmType = nullptr;

    } else if (isa<FloatType>(srcType)) {

    } else if (isa<IntegerType>(srcType)) {

    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    enum DppCtrl : unsigned {

      ROW_HALF_MIRROR = 0x141,

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;
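    // dpp_ctrl encoding used below: quad_perm packs four 2-bit lane selectors,
    // the row_shl/row_shr/row_ror cases add their immediate to the matching
    // base opcode (ROW_SHL0 / ROW_SHR0 / ROW_ROR0), and the remaining
    // permutations select the fixed enumerator values defined above.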
    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {

        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);

    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;

    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;

    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;

    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();

      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);

    if (!llvm::isa<IntegerType>(srcType)) {
      result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
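    // Mirrors convertOperand above: operands narrower than 32 bits were
    // widened into a vector of 32-bit words before rocdl.update.dpp, so the
    // result is truncated back to the source width and, for non-integer source
    // types, bitcast back to the original type.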
struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {

    if (failed(maybeChipset)) {

      return signalPassFailure();

    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();

      signalPassFailure();
      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
                                                                      chipset);

  return std::make_unique<ConvertAMDGPUToROCDLPass>();
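// Usage sketch (not part of this file): the factory above is typically wired
// into a lowering pipeline. The pass name and `chipset` option spelling below
// are assumptions taken from the tablegen pass definition pulled in via
// GEN_PASS_DEF_CONVERTAMDGPUTOROCDL and should be checked against it.
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createConvertAMDGPUToROCDLPass());
//   // or, from the command line:
//   //   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx90a input.mlir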