22 #include "llvm/ADT/STLExtras.h"
26 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
27 #include "mlir/Conversion/Passes.h.inc"
36 return rewriter.
create<LLVM::ConstantOp>(loc, llvmI32, value);
42 return rewriter.
create<LLVM::ConstantOp>(loc, llvmI1, value);
52 template <
typename GpuOp,
typename Intrinsic>
58 static constexpr uint32_t maxVectorOpWidth = 128;
61 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
64 Value memref = adaptor.getMemref();
65 Value unconvertedMemref = gpuOp.getMemref();
66 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
69 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
71 Value storeData = adaptor.getODSOperands(0)[0];
72 if (storeData == memref)
76 wantedDataType = storeData.
getType();
78 wantedDataType = gpuOp.getODSResults(0)[0].getType();
83 Value maybeCmpData = adaptor.getODSOperands(1)[0];
84 if (maybeCmpData != memref)
85 atomicCmpData = maybeCmpData;
88 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
91 Type llvmI32 = this->typeConverter->convertType(i32);
92 Type llvmI16 = this->typeConverter->convertType(rewriter.
getI16Type());
94 int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
102 Type llvmBufferValType = llvmWantedDataType;
104 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
105 llvmBufferValType = this->getTypeConverter()->convertType(
108 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
109 uint32_t vecLen = dataVector.getNumElements();
110 uint32_t elemBits = dataVector.getElementTypeBitWidth();
111 uint32_t totalBits = elemBits * vecLen;
113 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
114 if (totalBits > maxVectorOpWidth)
115 return gpuOp.emitOpError(
116 "Total width of loads or stores must be no more than " +
117 Twine(maxVectorOpWidth) +
" bits, but we call for " +
119 " bits. This should've been caught in validation");
120 if (!usePackedFp16 && elemBits < 32) {
121 if (totalBits > 32) {
122 if (totalBits % 32 != 0)
123 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
124 "doesn't fit into words. Can't happen\n");
125 llvmBufferValType = this->typeConverter->convertType(
128 llvmBufferValType = this->typeConverter->convertType(
136 if (llvmBufferValType != llvmWantedDataType) {
138 rewriter.
create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
139 args.push_back(castForStore);
141 args.push_back(storeData);
146 if (llvmBufferValType != llvmWantedDataType) {
147 Value castForCmp = rewriter.
create<LLVM::BitcastOp>(
148 loc, llvmBufferValType, atomicCmpData);
149 args.push_back(castForCmp);
151 args.push_back(atomicCmpData);
159 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
163 Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
169 if (memrefType.hasStaticShape()) {
172 static_cast<int32_t
>(memrefType.getNumElements() * elementByteWidth));
175 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
176 Value size = memrefDescriptor.size(rewriter, loc, i);
177 Value stride = memrefDescriptor.stride(rewriter, loc, i);
178 stride = rewriter.
create<LLVM::MulOp>(loc, stride, byteWidthConst);
179 Value maxThisDim = rewriter.
create<LLVM::MulOp>(loc, size, stride);
180 maxIndex = maxIndex ? rewriter.
create<LLVM::MaximumOp>(loc, maxIndex,
184 numRecords = rewriter.
create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
201 uint32_t flags = (7 << 12) | (4 << 15);
204 uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
205 flags |= (oob << 28);
210 loc, rsrcType, ptr, stride, numRecords, flagsConst);
211 args.push_back(resource);
216 size_t i = pair.index();
217 Value index = pair.value();
219 if (ShapedType::isDynamic(strides[i])) {
220 strideOp = rewriter.
create<LLVM::MulOp>(
221 loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
226 index = rewriter.
create<LLVM::MulOp>(loc, index, strideOp);
227 voffset = rewriter.
create<LLVM::AddOp>(loc, voffset, index);
229 if (adaptor.getIndexOffset()) {
230 int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
233 voffset ? rewriter.
create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
236 args.push_back(voffset);
238 Value sgprOffset = adaptor.getSgprOffset();
241 if (ShapedType::isDynamic(offset))
242 sgprOffset = rewriter.
create<LLVM::AddOp>(
243 loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
245 sgprOffset = rewriter.
create<LLVM::AddOp>(
247 args.push_back(sgprOffset);
256 Operation *lowered = rewriter.
create<Intrinsic>(loc, resultTypes, args,
260 if (llvmBufferValType != llvmWantedDataType) {
261 replacement = rewriter.
create<LLVM::BitcastOp>(loc, llvmWantedDataType,
279 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
281 bool requiresInlineAsm = chipset < kGfx90a || chipset.
majorVersion == 11;
283 if (requiresInlineAsm) {
285 LLVM::AsmDialect::AD_ATT);
287 ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
288 const char *constraints =
"";
292 asmStr, constraints,
true,
293 false, asmDialectAttr,
298 constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
299 constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
302 constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
306 ldsOnlyBits = ldsOnlyBitsGfx11;
308 ldsOnlyBits = ldsOnlyBitsGfx10;
310 ldsOnlyBits = ldsOnlyBitsGfx6789;
312 return op.emitOpError(
313 "don't know how to lower this for chipset major version")
317 rewriter.
create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
321 rewriter.
create<ROCDL::WaitDscntOp>(loc, 0);
322 rewriter.
create<ROCDL::BarrierSignalOp>(loc, -1);
337 matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
340 (uint32_t)op.getOpts());
356 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
357 if (vectorType.getElementType().isBF16())
358 return rewriter.
create<LLVM::BitcastOp>(
359 loc, vectorType.clone(rewriter.
getI16Type()), input);
360 if (vectorType.getElementType().isInteger(8)) {
361 return rewriter.
create<LLVM::BitcastOp>(
362 loc, rewriter.
getIntegerType(vectorType.getNumElements() * 8), input);
376 bool isUnsigned,
Value llvmInput,
380 auto vectorType = dyn_cast<VectorType>(inputType);
381 Type elemType = vectorType.getElementType();
384 llvmInput = rewriter.
create<LLVM::BitcastOp>(
385 loc, vectorType.clone(rewriter.
getI16Type()), llvmInput);
387 operands.push_back(llvmInput);
394 auto mlirInputType = cast<VectorType>(mlirInput.
getType());
395 bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
398 bool localIsUnsigned = isUnsigned;
400 localIsUnsigned =
true;
402 localIsUnsigned =
false;
405 operands.push_back(sign);
408 int64_t numBytes = vectorType.getNumElements();
411 auto llvmVectorType32bits = typeConverter->
convertType(vectorType32bits);
413 loc, llvmVectorType32bits, llvmInput);
414 operands.push_back(result);
426 Value output, int32_t subwordOffset,
429 auto vectorType = dyn_cast<VectorType>(inputType);
430 Type elemType = vectorType.getElementType();
432 output = rewriter.
create<LLVM::BitcastOp>(
433 loc, vectorType.clone(rewriter.
getI16Type()), output);
434 operands.push_back(output);
447 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
448 b = mfma.getBlocks();
449 Type sourceElem = mfma.getSourceA().getType();
450 if (
auto sourceType = dyn_cast<VectorType>(sourceElem))
451 sourceElem = sourceType.getElementType();
452 Type destElem = mfma.getDestC().getType();
453 if (
auto destType = dyn_cast<VectorType>(destElem))
454 destElem = destType.getElementType();
457 if (mfma.getReducePrecision() && chipset >= kGfx940) {
458 if (m == 32 && n == 32 && k == 4 && b == 1)
459 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
460 if (m == 16 && n == 16 && k == 8 && b == 1)
461 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
463 if (m == 32 && n == 32 && k == 1 && b == 2)
464 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
465 if (m == 16 && n == 16 && k == 1 && b == 4)
466 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
467 if (m == 4 && n == 4 && k == 1 && b == 16)
468 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
469 if (m == 32 && n == 32 && k == 2 && b == 1)
470 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
471 if (m == 16 && n == 16 && k == 4 && b == 1)
472 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
476 if (m == 32 && n == 32 && k == 4 && b == 2)
477 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
478 if (m == 16 && n == 16 && k == 4 && b == 4)
479 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
480 if (m == 4 && n == 4 && k == 4 && b == 16)
481 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
482 if (m == 32 && n == 32 && k == 8 && b == 1)
483 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
484 if (m == 16 && n == 16 && k == 16 && b == 1)
485 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
488 if (sourceElem.
isBF16() && destElem.
isF32() && chipset >= kGfx90a) {
489 if (m == 32 && n == 32 && k == 4 && b == 2)
490 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
491 if (m == 16 && n == 16 && k == 4 && b == 4)
492 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
493 if (m == 4 && n == 4 && k == 4 && b == 16)
494 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
495 if (m == 32 && n == 32 && k == 8 && b == 1)
496 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
497 if (m == 16 && n == 16 && k == 16 && b == 1)
498 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
502 if (m == 32 && n == 32 && k == 2 && b == 2)
503 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
504 if (m == 16 && n == 16 && k == 2 && b == 4)
505 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
506 if (m == 4 && n == 4 && k == 2 && b == 16)
507 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
508 if (m == 32 && n == 32 && k == 4 && b == 1)
509 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
510 if (m == 16 && n == 16 && k == 8 && b == 1)
511 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
514 if (isa<IntegerType>(sourceElem) && destElem.
isInteger(32)) {
515 if (m == 32 && n == 32 && k == 4 && b == 2)
516 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
517 if (m == 16 && n == 16 && k == 4 && b == 4)
518 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
519 if (m == 4 && n == 4 && k == 4 && b == 16)
520 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
521 if (m == 32 && n == 32 && k == 8 && b == 1)
522 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
523 if (m == 16 && n == 16 && k == 16 && b == 1)
524 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
525 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
526 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
527 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
528 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
531 if (sourceElem.
isF64() && destElem.
isF64() && chipset >= kGfx90a) {
532 if (m == 16 && n == 16 && k == 4 && b == 1)
533 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
534 if (m == 4 && n == 4 && k == 4 && b == 4)
535 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
542 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
543 if (m == 16 && n == 16 && k == 32 && b == 1) {
545 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
547 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
549 if (m == 32 && n == 32 && k == 16 && b == 1) {
551 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
553 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
559 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
560 if (m == 16 && n == 16 && k == 32 && b == 1) {
562 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
564 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
566 if (m == 32 && n == 32 && k == 16 && b == 1) {
568 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
570 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
582 auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
583 auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
584 auto elemSourceType = sourceVectorType.getElementType();
585 auto elemDestType = destVectorType.getElementType();
587 if (elemSourceType.isF16() && elemDestType.isF32())
588 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
589 if (elemSourceType.isBF16() && elemDestType.isF32())
590 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
591 if (elemSourceType.isF16() && elemDestType.isF16())
592 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
593 if (elemSourceType.isBF16() && elemDestType.isBF16())
594 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
595 if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
596 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
597 if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
598 return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
599 if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
600 return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
612 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
615 Type outType = typeConverter->convertType(op.getDestD().getType());
616 Type intrinsicOutType = outType;
617 if (
auto outVecType = dyn_cast<VectorType>(outType))
618 if (outVecType.getElementType().isBF16())
619 intrinsicOutType = outVecType.clone(rewriter.
getI16Type());
622 return op->emitOpError(
"MFMA only supported on gfx908+");
623 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
624 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
625 if (chipset < kGfx940)
626 return op.emitOpError(
"negation unsupported on older than gfx940");
628 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
631 if (!maybeIntrinsic.has_value())
632 return op.emitOpError(
"no intrinsic matching MFMA size on given chipset");
634 loweredOp.addTypes(intrinsicOutType);
635 loweredOp.addOperands(
642 if (outType != intrinsicOutType)
643 lowered = rewriter.
create<LLVM::BitcastOp>(loc, outType, lowered);
656 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
660 typeConverter->convertType<VectorType>(op.getDestD().getType());
665 return op->emitOpError(
"WMMA only supported on gfx11 and gfx12");
669 VectorType rawOutType = outType;
670 if (outType.getElementType().
isBF16())
671 rawOutType = outType.clone(rewriter.
getI16Type());
675 if (!maybeIntrinsic.has_value())
676 return op.emitOpError(
"no intrinsic matching WMMA on the given chipset");
679 loweredOp.addTypes(rawOutType);
683 adaptor.getSourceA(), op.getSourceA(), operands);
685 adaptor.getSourceB(), op.getSourceB(), operands);
687 op.getSubwordOffset(), op.getClamp(), operands);
689 loweredOp.addOperands(operands);
693 if (rawOutType != outType)
703 struct ExtPackedFp8OpLowering final
711 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
715 struct PackedTrunc2xFp8OpLowering final
724 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
728 struct PackedStochRoundFp8OpLowering final
737 matchAndRewrite(PackedStochRoundFp8Op op,
738 PackedStochRoundFp8OpAdaptor adaptor,
743 LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
744 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
749 loc,
"Fp8 conversion instructions are not available on target "
750 "architecture and their emulation is not implemented");
753 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
754 Type f32 = getTypeConverter()->convertType(op.getResult().getType());
756 Value source = adaptor.getSource();
757 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
760 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
761 Value longVec = rewriter.
create<LLVM::UndefOp>(loc, v4i8);
762 if (!sourceVecType) {
763 longVec = rewriter.
create<LLVM::InsertElementOp>(
766 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
768 Value elem = rewriter.
create<LLVM::ExtractElementOp>(loc, source, idx);
770 rewriter.
create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
775 Value i32Source = rewriter.
create<LLVM::BitcastOp>(loc, i32, source);
787 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
788 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
793 loc,
"Fp8 conversion instructions are not available on target "
794 "architecture and their emulation is not implemented");
795 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
797 Type resultType = op.getResult().getType();
800 Value sourceA = adaptor.getSourceA();
801 Value sourceB = adaptor.getSourceB();
803 sourceB = rewriter.
create<LLVM::UndefOp>(loc, sourceA.
getType());
804 Value existing = adaptor.getExisting();
806 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
808 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
813 result = rewriter.
create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
816 result = rewriter.
create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
820 op, getTypeConverter()->convertType(resultType), result);
824 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
825 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
830 loc,
"Fp8 conversion instructions are not available on target "
831 "architecture and their emulation is not implemented");
832 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
834 Type resultType = op.getResult().getType();
837 Value source = adaptor.getSource();
838 Value stoch = adaptor.getStochiasticParam();
839 Value existing = adaptor.getExisting();
841 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
843 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
848 result = rewriter.
create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
851 result = rewriter.
create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
855 op, getTypeConverter()->convertType(resultType), result);
867 matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
872 Value src = adaptor.getSrc();
873 Value old = adaptor.getOld();
876 Type llvmType =
nullptr;
879 }
else if (isa<FloatType>(srcType)) {
883 }
else if (isa<IntegerType>(srcType)) {
888 auto llvmSrcIntType = typeConverter->convertType(
892 auto convertOperand = [&](
Value operand,
Type operandType) {
893 if (operandType.getIntOrFloatBitWidth() <= 16) {
894 if (llvm::isa<FloatType>(operandType)) {
896 rewriter.
create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
899 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
900 Value undefVec = rewriter.
create<LLVM::UndefOp>(loc, llvmVecType);
901 operand = rewriter.
create<LLVM::InsertElementOp>(
903 operand = rewriter.
create<LLVM::BitcastOp>(loc, llvmType, operand);
908 src = convertOperand(src, srcType);
909 old = convertOperand(old, oldType);
912 enum DppCtrl :
unsigned {
921 ROW_HALF_MIRROR = 0x141,
926 auto kind = DppOp.getKind();
927 auto permArgument = DppOp.getPermArgument();
928 uint32_t DppCtrl = 0;
932 case DPPPerm::quad_perm:
933 if (
auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
935 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
936 uint32_t num = elem.getInt();
937 DppCtrl |= num << (i * 2);
942 case DPPPerm::row_shl:
943 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
944 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
947 case DPPPerm::row_shr:
948 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
949 DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
952 case DPPPerm::row_ror:
953 if (
auto intAttr = cast<IntegerAttr>(*permArgument)) {
954 DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
957 case DPPPerm::wave_shl:
958 DppCtrl = DppCtrl::WAVE_SHL1;
960 case DPPPerm::wave_shr:
961 DppCtrl = DppCtrl::WAVE_SHR1;
963 case DPPPerm::wave_rol:
964 DppCtrl = DppCtrl::WAVE_ROL1;
966 case DPPPerm::wave_ror:
967 DppCtrl = DppCtrl::WAVE_ROR1;
969 case DPPPerm::row_mirror:
970 DppCtrl = DppCtrl::ROW_MIRROR;
972 case DPPPerm::row_half_mirror:
973 DppCtrl = DppCtrl::ROW_HALF_MIRROR;
975 case DPPPerm::row_bcast_15:
976 DppCtrl = DppCtrl::BCAST15;
978 case DPPPerm::row_bcast_31:
979 DppCtrl = DppCtrl::BCAST31;
985 auto rowMask = DppOp->getAttrOfType<IntegerAttr>(
"row_mask").getInt();
986 auto bankMask = DppOp->getAttrOfType<IntegerAttr>(
"bank_mask").getInt();
987 bool boundCtrl = DppOp->getAttrOfType<
BoolAttr>(
"bound_ctrl").getValue();
990 auto dppMovOp = rewriter.
create<ROCDL::DPPUpdateOp>(
991 loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);
993 Value result = dppMovOp.getRes();
995 result = rewriter.
create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
996 if (!llvm::isa<IntegerType>(srcType)) {
997 result = rewriter.
create<LLVM::BitcastOp>(loc, srcType, result);
1008 struct ConvertAMDGPUToROCDLPass
1009 :
public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
1010 ConvertAMDGPUToROCDLPass() =
default;
1012 void runOnOperation()
override {
1015 if (failed(maybeChipset)) {
1017 return signalPassFailure();
1024 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
1025 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
1026 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
1028 std::move(patterns))))
1029 signalPassFailure();
1038 .
add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
1039 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
1040 RawBufferOpLowering<RawBufferAtomicFaddOp,
1041 ROCDL::RawPtrBufferAtomicFaddOp>,
1042 RawBufferOpLowering<RawBufferAtomicFmaxOp,
1043 ROCDL::RawPtrBufferAtomicFmaxOp>,
1044 RawBufferOpLowering<RawBufferAtomicSmaxOp,
1045 ROCDL::RawPtrBufferAtomicSmaxOp>,
1046 RawBufferOpLowering<RawBufferAtomicUminOp,
1047 ROCDL::RawPtrBufferAtomicUminOp>,
1048 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
1049 ROCDL::RawPtrBufferAtomicCmpSwap>,
1050 AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1051 MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1052 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
1057 return std::make_unique<ConvertAMDGPUToROCDLPass>();
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, Value mlirInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Location loc, Value input)
Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsic...
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static MLIRContext * getContext(OpFoldResult val)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isFloat8E4M3FNUZ() const
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isFloat8E5M2FNUZ() const
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertAMDGPUToROCDLPass()
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateAMDGPUToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: The ROCDL target does not support the LLVM bfloat type at this time and so this function will a...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.