21 #include "llvm/ADT/STLExtras.h"
25 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
26 #include "mlir/Conversion/Passes.h.inc"
35 return rewriter.
create<LLVM::ConstantOp>(loc, llvmI32, value);
41 return rewriter.
create<LLVM::ConstantOp>(loc, llvmI1, value);
46 template <
typename GpuOp,
typename Intrinsic>
52 static constexpr uint32_t maxVectorOpWidth = 128;
55 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
58 Value memref = adaptor.getMemref();
59 Value unconvertedMemref = gpuOp.getMemref();
60 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
63 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
65 Value storeData = adaptor.getODSOperands(0)[0];
66 if (storeData == memref)
70 wantedDataType = storeData.
getType();
72 wantedDataType = gpuOp.getODSResults(0)[0].getType();
77 Value maybeCmpData = adaptor.getODSOperands(1)[0];
78 if (maybeCmpData != memref)
79 atomicCmpData = maybeCmpData;
82 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
85 Type llvmI32 = this->typeConverter->convertType(i32);
86 Type llvmI16 = this->typeConverter->convertType(rewriter.
getI16Type());
88 int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
98 Type llvmBufferValType = llvmWantedDataType;
99 if (wantedDataType.
isBF16())
101 if (
auto wantedVecType = dyn_cast<VectorType>(wantedDataType))
102 if (wantedVecType.getElementType().isBF16())
103 llvmBufferValType = wantedVecType.clone(rewriter.
getI16Type());
105 if (isa<VectorType>(wantedDataType))
106 return gpuOp.emitOpError(
"vector compare-and-swap does not exist");
107 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
108 llvmBufferValType = this->getTypeConverter()->convertType(
111 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
112 uint32_t elemBits = dataVector.getElementTypeBitWidth();
113 uint32_t totalBits = elemBits * dataVector.getNumElements();
114 if (totalBits > maxVectorOpWidth)
115 return gpuOp.emitOpError(
116 "Total width of loads or stores must be no more than " +
117 Twine(maxVectorOpWidth) +
" bits, but we call for " +
119 " bits. This should've been caught in validation");
121 if (totalBits > 32) {
122 if (totalBits % 32 != 0)
123 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
124 "doesn't fit into words. Can't happen\n");
125 llvmBufferValType = this->typeConverter->convertType(
128 llvmBufferValType = this->typeConverter->convertType(
136 if (llvmBufferValType != llvmWantedDataType) {
138 rewriter.
create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
139 args.push_back(castForStore);
141 args.push_back(storeData);
146 if (llvmBufferValType != llvmWantedDataType) {
147 Value castForCmp = rewriter.
create<LLVM::BitcastOp>(
148 loc, llvmBufferValType, atomicCmpData);
149 args.push_back(castForCmp);
151 args.push_back(atomicCmpData);
159 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
163 Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
169 if (memrefType.hasStaticShape()) {
172 static_cast<int32_t
>(memrefType.getNumElements() * elementByteWidth));
175 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
176 Value size = memrefDescriptor.size(rewriter, loc, i);
177 Value stride = memrefDescriptor.stride(rewriter, loc, i);
178 stride = rewriter.
create<LLVM::MulOp>(loc, stride, byteWidthConst);
179 Value maxThisDim = rewriter.
create<LLVM::MulOp>(loc, size, stride);
180 maxIndex = maxIndex ? rewriter.
create<LLVM::MaximumOp>(loc, maxIndex,
184 numRecords = rewriter.
create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
201 uint32_t flags = (7 << 12) | (4 << 15);
204 uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
205 flags |= (oob << 28);
210 loc, rsrcType, ptr, stride, numRecords, flagsConst);
211 args.push_back(resource);
216 size_t i = pair.index();
217 Value index = pair.value();
219 if (ShapedType::isDynamic(strides[i])) {
220 strideOp = rewriter.
create<LLVM::MulOp>(
221 loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
226 index = rewriter.
create<LLVM::MulOp>(loc, index, strideOp);
227 voffset = rewriter.
create<LLVM::AddOp>(loc, voffset, index);
229 if (adaptor.getIndexOffset()) {
230 int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
233 voffset ? rewriter.
create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
236 args.push_back(voffset);
238 Value sgprOffset = adaptor.getSgprOffset();
241 if (ShapedType::isDynamic(offset))
242 sgprOffset = rewriter.
create<LLVM::AddOp>(
243 loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
245 sgprOffset = rewriter.
create<LLVM::AddOp>(
247 args.push_back(sgprOffset);
256 Operation *lowered = rewriter.
create<Intrinsic>(loc, resultTypes, args,
260 if (llvmBufferValType != llvmWantedDataType) {
261 replacement = rewriter.
create<LLVM::BitcastOp>(loc, llvmWantedDataType,
279 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
281 bool requiresInlineAsm =
286 if (requiresInlineAsm) {
288 LLVM::AsmDialect::AD_ATT);
290 ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
291 const char *constraints =
"";
295 asmStr, constraints,
true,
296 false, asmDialectAttr,
300 constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
301 constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
304 constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
308 ldsOnlyBits = ldsOnlyBitsGfx11;
310 ldsOnlyBits = ldsOnlyBitsGfx10;
312 ldsOnlyBits = ldsOnlyBitsGfx6789;
315 "don't know how to lower this for chipset major version")
319 rewriter.
create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
337 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
338 if (vectorType.getElementType().isBF16())
339 return rewriter.
create<LLVM::BitcastOp>(
340 loc, vectorType.clone(rewriter.
getI16Type()), input);
342 if (!vectorType.getElementType().isInteger(8))
344 int64_t numBytes = vectorType.getNumElements();
348 for (int64_t i = 0; i < numBytes; ++i) {
351 rewriter.
create<LLVM::ExtractElementOp>(loc, input, idxConst);
352 Value extended = rewriter.
create<LLVM::ZExtOp>(loc, destType, element);
353 Value shiftConst = rewriter.
create<LLVM::ConstantOp>(
355 Value shifted = rewriter.
create<LLVM::ShlOp>(loc, extended, shiftConst);
356 result = rewriter.
create<LLVM::OrOp>(loc, result, shifted);
371 bool isUnsigned,
Value llvmInput,
374 auto vectorType = dyn_cast<VectorType>(inputType);
375 Type elemType = vectorType.getElementType();
378 llvmInput = rewriter.
create<LLVM::BitcastOp>(
379 loc, vectorType.clone(rewriter.
getI16Type()), llvmInput);
381 operands.push_back(llvmInput);
385 int64_t numBytes = vectorType.getNumElements();
388 auto llvmVectorType32bits = typeConverter->
convertType(vectorType32bits);
391 loc, llvmVectorType32bits, llvmInput);
394 bool localIsUnsigned = isUnsigned;
396 localIsUnsigned =
true;
398 localIsUnsigned =
false;
401 operands.push_back(sign);
402 operands.push_back(result);
414 Value output, int32_t subwordOffset,
417 auto vectorType = dyn_cast<VectorType>(inputType);
418 Type elemType = vectorType.getElementType();
420 output = rewriter.
create<LLVM::BitcastOp>(
421 loc, vectorType.clone(rewriter.
getI16Type()), output);
422 operands.push_back(output);
435 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
436 b = mfma.getBlocks();
437 Type sourceElem = mfma.getSourceA().getType();
438 if (
auto sourceType = dyn_cast<VectorType>(sourceElem))
439 sourceElem = sourceType.getElementType();
440 Type destElem = mfma.getDestC().getType();
441 if (
auto destType = dyn_cast<VectorType>(destElem))
442 destElem = destType.getElementType();
445 if (mfma.getReducePrecision() && chipset.
minorVersion >= 0x40) {
446 if (m == 32 && n == 32 && k == 4 && b == 1)
447 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
448 if (m == 16 && n == 16 && k == 8 && b == 1)
449 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
451 if (m == 32 && n == 32 && k == 1 && b == 2)
452 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
453 if (m == 16 && n == 16 && k == 1 && b == 4)
454 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
455 if (m == 4 && n == 4 && k == 1 && b == 16)
456 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
457 if (m == 32 && n == 32 && k == 2 && b == 1)
458 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
459 if (m == 16 && n == 16 && k == 4 && b == 1)
460 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
464 if (m == 32 && n == 32 && k == 4 && b == 2)
465 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
466 if (m == 16 && n == 16 && k == 4 && b == 4)
467 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
468 if (m == 4 && n == 4 && k == 4 && b == 16)
469 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
470 if (m == 32 && n == 32 && k == 8 && b == 1)
471 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
472 if (m == 16 && n == 16 && k == 16 && b == 1)
473 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
477 if (m == 32 && n == 32 && k == 4 && b == 2)
478 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
479 if (m == 16 && n == 16 && k == 4 && b == 4)
480 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
481 if (m == 4 && n == 4 && k == 4 && b == 16)
482 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
483 if (m == 32 && n == 32 && k == 8 && b == 1)
484 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
485 if (m == 16 && n == 16 && k == 16 && b == 1)
486 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
490 if (m == 32 && n == 32 && k == 2 && b == 2)
491 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
492 if (m == 16 && n == 16 && k == 2 && b == 4)
493 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
494 if (m == 4 && n == 4 && k == 2 && b == 16)
495 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
496 if (m == 32 && n == 32 && k == 4 && b == 1)
497 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
498 if (m == 16 && n == 16 && k == 8 && b == 1)
499 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
502 if (isa<IntegerType>(sourceElem) && destElem.
isInteger(32)) {
503 if (m == 32 && n == 32 && k == 4 && b == 2)
504 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
505 if (m == 16 && n == 16 && k == 4 && b == 4)
506 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
507 if (m == 4 && n == 4 && k == 4 && b == 16)
508 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
509 if (m == 32 && n == 32 && k == 8 && b == 1)
510 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
511 if (m == 16 && n == 16 && k == 16 && b == 1)
512 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
513 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.
minorVersion >= 0x40)
514 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
515 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.
minorVersion >= 0x40)
516 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
520 if (m == 16 && n == 16 && k == 4 && b == 1)
521 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
522 if (m == 4 && n == 4 && k == 4 && b == 4)
523 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
531 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
532 if (m == 16 && n == 16 && k == 32 && b == 1) {
534 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
536 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
538 if (m == 32 && n == 32 && k == 16 && b == 1) {
540 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
542 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
549 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
550 if (m == 16 && n == 16 && k == 32 && b == 1) {
552 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
554 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
556 if (m == 32 && n == 32 && k == 16 && b == 1) {
558 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
560 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
572 auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
573 auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
574 auto elemSourceType = sourceVectorType.getElementType();
575 auto elemDestType = destVectorType.getElementType();
577 if (elemSourceType.isF16() && elemDestType.isF32()) {
578 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
580 if (elemSourceType.isBF16() && elemDestType.isF32()) {
581 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
582 }
else if (elemSourceType.isF16() && elemDestType.isF16()) {
583 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
584 }
else if (elemSourceType.isBF16() && elemDestType.isBF16()) {
585 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
586 }
else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
587 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
600 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
603 Type outType = typeConverter->convertType(op.getDestD().getType());
604 Type intrinsicOutType = outType;
605 if (
auto outVecType = dyn_cast<VectorType>(outType))
606 if (outVecType.getElementType().isBF16())
607 intrinsicOutType = outVecType.clone(rewriter.
getI16Type());
610 return op->
emitOpError(
"MFMA only supported on gfx908+");
611 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
612 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
614 return op.
emitOpError(
"negation unsupported on older than gfx840");
616 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
619 if (!maybeIntrinsic.has_value())
620 return op.
emitOpError(
"no intrinsic matching MFMA size on given chipset");
622 loweredOp.addTypes(intrinsicOutType);
623 loweredOp.addOperands(
630 if (outType != intrinsicOutType)
631 lowered = rewriter.
create<LLVM::BitcastOp>(loc, outType, lowered);
644 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
647 Type outType = typeConverter->convertType(op.getDestD().getType());
650 return op->
emitOpError(
"WMMA only supported on gfx11");
654 if (!maybeIntrinsic.has_value())
655 return op.
emitOpError(
"no intrinsic matching WMMA on the given chipset");
658 loweredOp.addTypes(outType);
662 adaptor.getSourceA(), operands);
664 adaptor.getSourceB(), operands);
666 op.getSubwordOffset(), op.getClamp(), operands);
668 loweredOp.addOperands(operands);
677 struct ExtPackedFp8OpLowering final
685 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
689 struct PackedTrunc2xFp8OpLowering final
697 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
701 struct PackedStochRoundFp8OpLowering final
709 matchAndRewrite(PackedStochRoundFp8Op op,
710 PackedStochRoundFp8OpAdaptor adaptor,
716 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
721 loc,
"Fp8 conversion instructions are not available on target "
722 "architecture and their emulation is not implemented");
725 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
728 Value source = adaptor.getSource();
729 auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
732 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
733 Value longVec = rewriter.
create<LLVM::UndefOp>(loc, v4i8);
734 if (!sourceVecType) {
735 longVec = rewriter.
create<LLVM::InsertElementOp>(
738 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
740 Value elem = rewriter.
create<LLVM::ExtractElementOp>(loc, source, idx);
742 rewriter.
create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
747 Value i32Source = rewriter.
create<LLVM::BitcastOp>(loc, i32, source);
760 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
765 loc,
"Fp8 conversion instructions are not available on target "
766 "architecture and their emulation is not implemented");
767 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
772 Value sourceA = adaptor.getSourceA();
773 Value sourceB = adaptor.getSourceB();
775 sourceB = rewriter.
create<LLVM::UndefOp>(loc, sourceA.
getType());
776 Value existing = adaptor.getExisting();
778 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
780 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
785 result = rewriter.
create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
788 result = rewriter.
create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
792 op, getTypeConverter()->convertType(resultType), result);
796 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
797 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
802 loc,
"Fp8 conversion instructions are not available on target "
803 "architecture and their emulation is not implemented");
804 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
809 Value source = adaptor.getSource();
810 Value stoch = adaptor.getStochiasticParam();
811 Value existing = adaptor.getExisting();
813 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
815 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
820 result = rewriter.
create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
823 result = rewriter.
create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
827 op, getTypeConverter()->convertType(resultType), result);
831 struct ConvertAMDGPUToROCDLPass
832 :
public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
833 ConvertAMDGPUToROCDLPass() =
default;
835 void runOnOperation()
override {
838 if (
failed(maybeChipset)) {
840 return signalPassFailure();
847 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
848 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
849 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
851 std::move(patterns))))
863 converter.
addConversion([&converter](VectorType t) -> std::optional<Type> {
864 if (!t.getElementType().isBF16())
870 .
add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
871 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
872 RawBufferOpLowering<RawBufferAtomicFaddOp,
873 ROCDL::RawPtrBufferAtomicFaddOp>,
874 RawBufferOpLowering<RawBufferAtomicFmaxOp,
875 ROCDL::RawPtrBufferAtomicFmaxOp>,
876 RawBufferOpLowering<RawBufferAtomicSmaxOp,
877 ROCDL::RawPtrBufferAtomicSmaxOp>,
878 RawBufferOpLowering<RawBufferAtomicUminOp,
879 ROCDL::RawPtrBufferAtomicUminOp>,
880 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
881 ROCDL::RawPtrBufferAtomicCmpSwap>,
882 LDSBarrierOpLowering, MFMAOpLowering, WMMAOpLowering,
883 ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
884 PackedStochRoundFp8OpLowering>(converter, chipset);
888 return std::make_unique<ConvertAMDGPUToROCDLPass>();
static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter, Location loc, Value input)
If input is a vector of bytes, concatenate those bytes in little-endian order to form a single integer...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to a MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static MLIRContext * getContext(OpFoldResult val)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
This class provides support for representing a failure result, or a valid value of type T.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isFloat8E4M3FNUZ() const
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isFloat8E5M2FNUZ() const
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
std::unique_ptr< Pass > createConvertAMDGPUToROCDLPass()
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: The ROCDL target does not support the LLVM bfloat type at this time and so this function will a...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static FailureOr< Chipset > parse(StringRef name)