#include "llvm/ADT/STLExtras.h"

#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
  auto valTy = cast<IntegerType>(val.getType());
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
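// In getLinearIndexI32() below, each index is scaled by its stride (read from
// the memref descriptor when the stride is dynamic, emitted as an i32 constant
// when it is static) and the scaled terms are accumulated into a single linear
// index.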
        ShapedType::isDynamic(stride)
            ? convertUnsignedToI32(rewriter, loc,
                                   memRefDescriptor.stride(rewriter, loc, i))
            : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
    increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
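// RawBufferOpLowering (below) rewrites the amdgpu raw buffer ops onto the
// ROCDL raw.ptr.buffer intrinsics: it packs the optional store/compare data,
// builds a buffer resource descriptor from the memref, and appends the
// voffset/soffset indexing operands the intrinsics expect.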
template <typename GpuOp, typename Intrinsic>

  static constexpr uint32_t maxVectorOpWidth = 128;

  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref)
      wantedDataType = storeData.getType();
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Type llvmBufferValType = llvmWantedDataType;
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
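    // Sub-32-bit element vectors are thus carried through the intrinsic as a
    // plain integer or a vector of i32 words and bitcast back afterwards; the
    // packed 2xf16 atomic fadd is the one case that keeps its native type.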
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }

      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    if (failed(memrefType.getStridesAndOffset(strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    if (memrefType.hasStaticShape() &&
        !llvm::any_of(strides, ShapedType::isDynamic)) {
      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
        size = std::max(shape[i] * strides[i], size);
      size = size * elementByteWidth;
      assert(size < std::numeric_limits<uint32_t>::max() &&
             "the memref buffer is too large");
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex =
            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
                     : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::MulOp>(
          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
    uint32_t flags = (7 << 12) | (4 << 15);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);
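    // In the resource flag word, (7 << 12) | (4 << 15) fills the data-format
    // fields, which the raw-buffer intrinsics ignore but which must be
    // nonzero; on RDNA chips bits 28-29 select the out-of-bounds behaviour,
    // with 3 enabling the bounds check requested by the op and 2 disabling it.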
    Value resource = rewriter.create<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);

    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);
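    // Both the per-lane voffset and the wave-uniform sgprOffset are computed
    // in units of elements and scaled by byteWidthConst here, since the buffer
    // intrinsics take byte offsets.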
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
          asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

        ldsOnlyBits = ldsOnlyBitsGfx11;
        ldsOnlyBits = ldsOnlyBitsGfx10;
        ldsOnlyBits = ldsOnlyBitsGfx6789;
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
      rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);

      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
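      // Summary of the three paths: chipsets below gfx90a and the gfx11
      // family use the inline-asm "s_waitcnt lgkmcnt(0); s_barrier" sequence,
      // the remaining pre-gfx12 targets emit an s_waitcnt that waits only on
      // LDS (lgkmcnt) traffic followed by s_barrier, and gfx12 uses the split
      // wait/signal/wait barrier ops.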
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8)) {
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
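  // The two bitcasts above exist because the ROCDL intrinsics do not take
  // bf16 or i8 vectors directly: bf16 vectors are passed as vectors of i16,
  // and i8 vectors as a single integer of the combined bit width.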
                                 bool isUnsigned, Value llvmInput,

  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
    operands.push_back(llvmInput);

  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
    bool localIsUnsigned = isUnsigned;
      localIsUnsigned = true;
      localIsUnsigned = false;
    operands.push_back(sign);

  int64_t numBytes = vectorType.getNumElements();
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);
  operands.push_back(result);
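// For 8-bit element inputs, wmmaPushInputOperand also pushes an i1
// signedness flag (taken from the element type when it is explicitly signed
// or unsigned) and repacks the bytes into a vector of i32 words, the layout
// the WMMA intrinsics expect.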
                                  Value output, int32_t subwordOffset,

  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();
  if (mfma.getReducePrecision() && chipset >= kGfx940) {
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();

    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();

    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();

    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  if (isa<Float8E5M2FNUZType>(sourceElem) && destElem.isF32() &&
      chipset >= kGfx940) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();

  if (isa<Float8E4M3FNUZType>(sourceElem) && destElem.isF32() &&
      chipset >= kGfx940) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (isa<Float8E5M2FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (isa<Float8E4M3FNUZType>(sourceBElem))
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
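// mfmaOpToIntrinsic is effectively a lookup table: it keys on the MFMA shape
// (m, n, k, blocks), the source and destination element types, and the
// chipset, and returns the name of the matching ROCDL intrinsic, or
// std::nullopt when that combination has no hardware instruction.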
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (isa<Float8E4M3FNType>(elemSourceType) && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
  if (isa<Float8E5M2Type>(elemSourceType) && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx940)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
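    // bf16 results travel through the intrinsic as vectors of i16
    // (intrinsicOutType), so the raw result is bitcast back to the bf16
    // vector type that the amdgpu.mfma op advertises.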
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());

      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");
    loweredOp.addTypes(rawOutType);

    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);
    loweredOp.addOperands(operands);
    if (rawOutType != outType)
struct ExtPackedFp8OpLowering final

  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,

struct PackedTrunc2xFp8OpLowering final

  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,

struct PackedStochRoundFp8OpLowering final

  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,

        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());
  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());

  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);

  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  if (isa<Float8E5M2FNUZType>(sourceElemType)) {
  } else if (isa<Float8E4M3FNUZType>(sourceElemType)) {
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,

        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type resultType = op.getResult().getType();

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);

  if (isa<Float8E5M2FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
  else if (isa<Float8E4M3FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,

        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");

  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type resultType = op.getResult().getType();

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);

  if (isa<Float8E5M2FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
  else if (isa<Float8E4M3FNUZType>(resultElemType))
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,

  rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
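// All three fp8 lowerings work on packed 32-bit words: operands are bitcast
// to i32, converted with the gfx940+ ROCDL cvt_pk_* / cvt_sr_* ops, which
// update part of an existing i32 word, and the packed word is bitcast back
// to the op's fp8 vector result type.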
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,

    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();

    Type llvmType = nullptr;
    } else if (isa<FloatType>(srcType)) {
    } else if (isa<IntegerType>(srcType)) {

    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);
    enum DppCtrl : unsigned {
      ROW_HALF_MIRROR = 0x141,

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;
    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
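    // DppCtrl mirrors the AMDGPU backend's dpp_ctrl encoding: quad_perm packs
    // four 2-bit lane selectors into the low byte, the row shift/rotate cases
    // add the shift amount to the ROW_SHL0/ROW_SHR0/ROW_ROR0 bases, and the
    // remaining permutations use fixed control codes from the enum above.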
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    if (failed(maybeChipset)) {
      return signalPassFailure();

    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
      signalPassFailure();
      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
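// Example driver invocation (illustrative; the exact option spelling should
// be checked against the pass's tablegen definition):
//   mlir-opt input.mlir --convert-amdgpu-to-rocdl=chipset=gfx90a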