21 #include "llvm/ADT/STLExtras.h"
25 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
26 #include "mlir/Conversion/Passes.h.inc"
35 return rewriter.
create<LLVM::ConstantOp>(loc, llvmI32, value);
41 return rewriter.
createOrFold<LLVM::ConstantOp>(loc, llvmI1, value);
46 template <
typename GpuOp,
typename Intrinsic>
52 static constexpr uint32_t maxVectorOpWidth = 128;
55 matchAndRewrite(GpuOp gpuOp,
typename GpuOp::Adaptor adaptor,
58 Value memref = adaptor.getMemref();
59 Value unconvertedMemref = gpuOp.getMemref();
60 MemRefType memrefType = cast<MemRefType>(unconvertedMemref.
getType());
63 return gpuOp.emitOpError(
"raw buffer ops require GCN or higher");
65 Value storeData = adaptor.getODSOperands(0)[0];
66 if (storeData == memref)
70 wantedDataType = storeData.
getType();
72 wantedDataType = gpuOp.getODSResults(0)[0].getType();
77 Value maybeCmpData = adaptor.getODSOperands(1)[0];
78 if (maybeCmpData != memref)
79 atomicCmpData = maybeCmpData;
82 Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
85 Type llvmI32 = this->typeConverter->convertType(i32);
86 Type llvmI16 = this->typeConverter->convertType(rewriter.
getI16Type());
88 int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
98 Type llvmBufferValType = llvmWantedDataType;
99 if (wantedDataType.
isBF16())
101 if (
auto wantedVecType = dyn_cast<VectorType>(wantedDataType))
102 if (wantedVecType.getElementType().isBF16())
103 llvmBufferValType = wantedVecType.clone(rewriter.
getI16Type());
105 if (isa<VectorType>(wantedDataType))
106 return gpuOp.emitOpError(
"vector compare-and-swap does not exist");
107 if (
auto floatType = dyn_cast<FloatType>(wantedDataType))
108 llvmBufferValType = this->getTypeConverter()->convertType(
111 if (
auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
112 uint32_t elemBits = dataVector.getElementTypeBitWidth();
113 uint32_t totalBits = elemBits * dataVector.getNumElements();
114 if (totalBits > maxVectorOpWidth)
115 return gpuOp.emitOpError(
116 "Total width of loads or stores must be no more than " +
117 Twine(maxVectorOpWidth) +
" bits, but we call for " +
119 " bits. This should've been caught in validation");
121 if (totalBits > 32) {
122 if (totalBits % 32 != 0)
123 return gpuOp.emitOpError(
"Load or store of more than 32-bits that "
124 "doesn't fit into words. Can't happen\n");
125 llvmBufferValType = this->typeConverter->convertType(
128 llvmBufferValType = this->typeConverter->convertType(
136 if (llvmBufferValType != llvmWantedDataType) {
138 rewriter.
create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
139 args.push_back(castForStore);
141 args.push_back(storeData);
146 if (llvmBufferValType != llvmWantedDataType) {
147 Value castForCmp = rewriter.
create<LLVM::BitcastOp>(
148 loc, llvmBufferValType, atomicCmpData);
149 args.push_back(castForCmp);
151 args.push_back(atomicCmpData);
159 return gpuOp.emitOpError(
"Can't lower non-stride-offset memrefs");
163 Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
169 if (memrefType.hasStaticShape()) {
172 static_cast<int32_t
>(memrefType.getNumElements() * elementByteWidth));
175 for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
176 Value size = memrefDescriptor.size(rewriter, loc, i);
177 Value stride = memrefDescriptor.stride(rewriter, loc, i);
178 stride = rewriter.
create<LLVM::MulOp>(loc, stride, byteWidthConst);
179 Value maxThisDim = rewriter.
create<LLVM::MulOp>(loc, size, stride);
180 maxIndex = maxIndex ? rewriter.
create<LLVM::MaximumOp>(loc, maxIndex,
184 numRecords = rewriter.
create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
201 uint32_t flags = (7 << 12) | (4 << 15);
204 uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
205 flags |= (oob << 28);
210 loc, rsrcType, ptr, stride, numRecords, flagsConst);
211 args.push_back(resource);
216 size_t i = pair.index();
217 Value index = pair.value();
219 if (ShapedType::isDynamic(strides[i])) {
220 strideOp = rewriter.
create<LLVM::MulOp>(
221 loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
226 index = rewriter.
create<LLVM::MulOp>(loc, index, strideOp);
227 voffset = rewriter.
create<LLVM::AddOp>(loc, voffset, index);
229 if (adaptor.getIndexOffset()) {
230 int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
233 voffset ? rewriter.
create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
236 args.push_back(voffset);
238 Value sgprOffset = adaptor.getSgprOffset();
241 if (ShapedType::isDynamic(offset))
242 sgprOffset = rewriter.
create<LLVM::AddOp>(
243 loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
245 sgprOffset = rewriter.
create<LLVM::AddOp>(
247 args.push_back(sgprOffset);
256 Operation *lowered = rewriter.
create<Intrinsic>(loc, resultTypes, args,
260 if (llvmBufferValType != llvmWantedDataType) {
261 replacement = rewriter.
create<LLVM::BitcastOp>(loc, llvmWantedDataType,
276 matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
279 LLVM::AsmDialect::AD_ATT);
280 const char *asmStr =
"s_waitcnt lgkmcnt(0)\ns_barrier";
281 const char *constraints =
"";
285 asmStr, constraints,
true,
286 false, asmDialectAttr,
304 if (
auto vectorType = dyn_cast<VectorType>(inputType)) {
305 if (vectorType.getElementType().isBF16())
306 return rewriter.
create<LLVM::BitcastOp>(
307 loc, vectorType.clone(rewriter.
getI16Type()), input);
309 if (!vectorType.getElementType().isInteger(8))
311 int64_t numBytes = vectorType.getNumElements();
315 for (int64_t i = 0; i < numBytes; ++i) {
318 rewriter.
create<LLVM::ExtractElementOp>(loc, input, idxConst);
319 Value extended = rewriter.
create<LLVM::ZExtOp>(loc, destType, element);
320 Value shiftConst = rewriter.
create<LLVM::ConstantOp>(
322 Value shifted = rewriter.
create<LLVM::ShlOp>(loc, extended, shiftConst);
323 result = rewriter.
create<LLVM::OrOp>(loc, result, shifted);
338 bool isUnsigned,
Value llvmInput,
341 auto vectorType = inputType.
dyn_cast<VectorType>();
342 Type elemType = vectorType.getElementType();
344 if (elemType.isBF16())
345 llvmInput = rewriter.
create<LLVM::BitcastOp>(
346 loc, vectorType.clone(rewriter.
getI16Type()), llvmInput);
347 if (!elemType.isInteger(8)) {
348 operands.push_back(llvmInput);
352 int64_t numBytes = vectorType.getNumElements();
355 auto llvmVectorType32bits = typeConverter->
convertType(vectorType32bits);
358 loc, llvmVectorType32bits, llvmInput);
361 bool localIsUnsigned = isUnsigned;
362 if (elemType.isUnsignedInteger(8)) {
363 localIsUnsigned =
true;
364 }
else if (elemType.isSignedInteger(8)) {
365 localIsUnsigned =
false;
368 operands.push_back(sign);
369 operands.push_back(result);
381 Value output, int32_t subwordOffset,
384 auto vectorType = inputType.
dyn_cast<VectorType>();
385 Type elemType = vectorType.getElementType();
386 if (elemType.isBF16())
387 output = rewriter.
create<LLVM::BitcastOp>(
388 loc, vectorType.clone(rewriter.
getI16Type()), output);
389 operands.push_back(output);
390 if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
392 }
else if (elemType.isInteger(32)) {
402 uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
403 b = mfma.getBlocks();
404 Type sourceElem = mfma.getSourceA().getType();
405 if (
auto sourceType = dyn_cast<VectorType>(sourceElem))
406 sourceElem = sourceType.getElementType();
407 Type destElem = mfma.getDestC().getType();
408 if (
auto destType = dyn_cast<VectorType>(destElem))
409 destElem = destType.getElementType();
412 if (mfma.getReducePrecision() && chipset.
minorVersion >= 0x40) {
413 if (m == 32 && n == 32 && k == 4 && b == 1)
414 return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
415 if (m == 16 && n == 16 && k == 8 && b == 1)
416 return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
418 if (m == 32 && n == 32 && k == 1 && b == 2)
419 return ROCDL::mfma_f32_32x32x1f32::getOperationName();
420 if (m == 16 && n == 16 && k == 1 && b == 4)
421 return ROCDL::mfma_f32_16x16x1f32::getOperationName();
422 if (m == 4 && n == 4 && k == 1 && b == 16)
423 return ROCDL::mfma_f32_4x4x1f32::getOperationName();
424 if (m == 32 && n == 32 && k == 2 && b == 1)
425 return ROCDL::mfma_f32_32x32x2f32::getOperationName();
426 if (m == 16 && n == 16 && k == 4 && b == 1)
427 return ROCDL::mfma_f32_16x16x4f32::getOperationName();
431 if (m == 32 && n == 32 && k == 4 && b == 2)
432 return ROCDL::mfma_f32_32x32x4f16::getOperationName();
433 if (m == 16 && n == 16 && k == 4 && b == 4)
434 return ROCDL::mfma_f32_16x16x4f16::getOperationName();
435 if (m == 4 && n == 4 && k == 4 && b == 16)
436 return ROCDL::mfma_f32_4x4x4f16::getOperationName();
437 if (m == 32 && n == 32 && k == 8 && b == 1)
438 return ROCDL::mfma_f32_32x32x8f16::getOperationName();
439 if (m == 16 && n == 16 && k == 16 && b == 1)
440 return ROCDL::mfma_f32_16x16x16f16::getOperationName();
444 if (m == 32 && n == 32 && k == 4 && b == 2)
445 return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
446 if (m == 16 && n == 16 && k == 4 && b == 4)
447 return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
448 if (m == 4 && n == 4 && k == 4 && b == 16)
449 return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
450 if (m == 32 && n == 32 && k == 8 && b == 1)
451 return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
452 if (m == 16 && n == 16 && k == 16 && b == 1)
453 return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
457 if (m == 32 && n == 32 && k == 2 && b == 2)
458 return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
459 if (m == 16 && n == 16 && k == 2 && b == 4)
460 return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
461 if (m == 4 && n == 4 && k == 2 && b == 16)
462 return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
463 if (m == 32 && n == 32 && k == 4 && b == 1)
464 return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
465 if (m == 16 && n == 16 && k == 8 && b == 1)
466 return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
469 if (isa<IntegerType>(sourceElem) && destElem.
isInteger(32)) {
470 if (m == 32 && n == 32 && k == 4 && b == 2)
471 return ROCDL::mfma_i32_32x32x4i8::getOperationName();
472 if (m == 16 && n == 16 && k == 4 && b == 4)
473 return ROCDL::mfma_i32_16x16x4i8::getOperationName();
474 if (m == 4 && n == 4 && k == 4 && b == 16)
475 return ROCDL::mfma_i32_4x4x4i8::getOperationName();
476 if (m == 32 && n == 32 && k == 8 && b == 1)
477 return ROCDL::mfma_i32_32x32x8i8::getOperationName();
478 if (m == 16 && n == 16 && k == 16 && b == 1)
479 return ROCDL::mfma_i32_16x16x16i8::getOperationName();
480 if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.
minorVersion >= 0x40)
481 return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
482 if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.
minorVersion >= 0x40)
483 return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
487 if (m == 16 && n == 16 && k == 4 && b == 1)
488 return ROCDL::mfma_f64_16x16x4f64::getOperationName();
489 if (m == 4 && n == 4 && k == 4 && b == 4)
490 return ROCDL::mfma_f64_4x4x4f64::getOperationName();
498 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
499 if (m == 16 && n == 16 && k == 32 && b == 1) {
501 return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
503 return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
505 if (m == 32 && n == 32 && k == 16 && b == 1) {
507 return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
509 return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
516 cast<VectorType>(mfma.getSourceB().getType()).getElementType();
517 if (m == 16 && n == 16 && k == 32 && b == 1) {
519 return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
521 return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
523 if (m == 32 && n == 32 && k == 16 && b == 1) {
525 return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
527 return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
540 auto sourceVectorType = wmma.getSourceA().getType().dyn_cast<VectorType>();
541 auto destVectorType = wmma.getDestC().getType().dyn_cast<VectorType>();
542 auto elemSourceType = sourceVectorType.getElementType();
543 auto elemDestType = destVectorType.getElementType();
545 if (elemSourceType.isF16() && elemDestType.isF32()) {
546 return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
548 if (elemSourceType.isBF16() && elemDestType.isF32()) {
549 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
550 }
else if (elemSourceType.isF16() && elemDestType.isF16()) {
551 return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
552 }
else if (elemSourceType.isBF16() && elemDestType.isBF16()) {
553 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
554 }
else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
555 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
568 matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
571 Type outType = typeConverter->convertType(op.getDestD().getType());
572 Type intrinsicOutType = outType;
573 if (
auto outVecType = dyn_cast<VectorType>(outType))
574 if (outVecType.getElementType().isBF16())
575 intrinsicOutType = outVecType.clone(rewriter.
getI16Type());
578 return op->
emitOpError(
"MFMA only supported on gfx908+");
579 uint32_t getBlgpField =
static_cast<uint32_t
>(op.getBlgp());
580 if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
582 return op.
emitOpError(
"negation unsupported on older than gfx840");
584 op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
587 if (!maybeIntrinsic.has_value())
588 return op.
emitOpError(
"no intrinsic matching MFMA size on given chipset");
590 loweredOp.addTypes(intrinsicOutType);
591 loweredOp.addOperands(
598 if (outType != intrinsicOutType)
599 lowered = rewriter.
create<LLVM::BitcastOp>(loc, outType, lowered);
612 matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
615 Type outType = typeConverter->convertType(op.getDestD().getType());
618 return op->
emitOpError(
"WMMA only supported on gfx11");
622 if (!maybeIntrinsic.has_value())
623 return op.
emitOpError(
"no intrinsic matching WMMA on the given chipset");
626 loweredOp.addTypes(outType);
630 adaptor.getSourceA(), operands);
632 adaptor.getSourceB(), operands);
634 op.getSubwordOffset(), op.getClamp(), operands);
636 loweredOp.addOperands(operands);
645 struct ExtPackedFp8OpLowering final
653 matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
657 struct PackedTrunc2xFp8OpLowering final
665 matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
669 struct PackedStochRoundFp8OpLowering final
677 matchAndRewrite(PackedStochRoundFp8Op op,
678 PackedStochRoundFp8OpAdaptor adaptor,
684 ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
689 loc,
"Fp8 conversion instructions are not available on target "
690 "architecture and their emulation is not implemented");
693 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
696 Value source = adaptor.getSource();
697 auto sourceVecType = op.getSource().getType().dyn_cast<VectorType>();
700 if (!sourceVecType || sourceVecType.getNumElements() < 4) {
701 Value longVec = rewriter.
create<LLVM::UndefOp>(loc, v4i8);
702 if (!sourceVecType) {
703 longVec = rewriter.
create<LLVM::InsertElementOp>(
706 for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
708 Value elem = rewriter.
create<LLVM::ExtractElementOp>(loc, source, idx);
710 rewriter.
create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
715 Value i32Source = rewriter.
create<LLVM::BitcastOp>(loc, i32, source);
717 if (sourceElemType.isFloat8E5M2FNUZ()) {
720 }
else if (sourceElemType.isFloat8E4M3FNUZ()) {
728 PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
733 loc,
"Fp8 conversion instructions are not available on target "
734 "architecture and their emulation is not implemented");
735 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
740 Value sourceA = adaptor.getSourceA();
741 Value sourceB = adaptor.getSourceB();
743 sourceB = rewriter.
create<LLVM::UndefOp>(loc, sourceA.
getType());
744 Value existing = adaptor.getExisting();
746 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
748 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
753 result = rewriter.
create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
756 result = rewriter.
create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
760 op, getTypeConverter()->convertType(resultType), result);
764 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
765 PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
770 loc,
"Fp8 conversion instructions are not available on target "
771 "architecture and their emulation is not implemented");
772 Type i32 = getTypeConverter()->convertType(rewriter.
getI32Type());
777 Value source = adaptor.getSource();
778 Value stoch = adaptor.getStochiasticParam();
779 Value existing = adaptor.getExisting();
781 existing = rewriter.
create<LLVM::BitcastOp>(loc, i32, existing);
783 existing = rewriter.
create<LLVM::UndefOp>(loc, i32);
788 result = rewriter.
create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
791 result = rewriter.
create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
795 op, getTypeConverter()->convertType(resultType), result);
799 struct ConvertAMDGPUToROCDLPass
800 :
public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
801 ConvertAMDGPUToROCDLPass() =
default;
803 void runOnOperation()
override {
806 if (
failed(maybeChipset)) {
808 return signalPassFailure();
815 target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
816 target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
817 target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
819 std::move(patterns))))
831 converter.
addConversion([&converter](VectorType t) -> std::optional<Type> {
832 if (!t.getElementType().isBF16())
837 patterns.
add<LDSBarrierOpLowering>(converter);
839 .
add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
840 RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
841 RawBufferOpLowering<RawBufferAtomicFaddOp,
842 ROCDL::RawPtrBufferAtomicFaddOp>,
843 RawBufferOpLowering<RawBufferAtomicFmaxOp,
844 ROCDL::RawPtrBufferAtomicFmaxOp>,
845 RawBufferOpLowering<RawBufferAtomicSmaxOp,
846 ROCDL::RawPtrBufferAtomicSmaxOp>,
847 RawBufferOpLowering<RawBufferAtomicUminOp,
848 ROCDL::RawPtrBufferAtomicUminOp>,
849 RawBufferOpLowering<RawBufferAtomicCmpswapOp,
850 ROCDL::RawPtrBufferAtomicCmpSwap>,
851 MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
852 PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
857 return std::make_unique<ConvertAMDGPUToROCDLPass>();
static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter, Location loc, Value input)
If input is a vector of bytes, concatenate those bytes in little-endian order to form a single integ...
static std::optional< StringRef > wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset)
Return the rocdl intrinsic corresponding to a WMMA operation wmma if one exists.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput, SmallVector< Value, 4 > &operands)
Push an input operand.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value)
static std::optional< StringRef > mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset)
Return the rocdl intrinsic corresponding to an MFMA operation mfma if one exists.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, Location loc, const TypeConverter *typeConverter, Value output, int32_t subwordOffset, bool clamp, SmallVector< Value, 4 > &operands)
Push the output operand.
static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value)
static MLIRContext * getContext(OpFoldResult val)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerAttr getI16IntegerAttr(int16_t value)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
This class provides support for representing a failure result, or a valid value of type T.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descript...
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
void addConversion(FnT &&callback)
Register a conversion function.
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
bool isFloat8E4M3FNUZ() const
bool isFloat8E5M2FNUZ() const
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
std::unique_ptr< Pass > createConvertAMDGPUToROCDLPass()
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, amdgpu::Chipset chipset)
Note: The ROCDL target does not support the LLVM bfloat type at this time and so this function will a...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static FailureOr< Chipset > parse(StringRef name)