MLIR 22.0.0git
NVVMDialect.cpp File Reference
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"

Go to the source code of this file.

Macros

#define CP_ASYNC_ID_IMPL(mod, size, suffix)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
#define _none
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
#define GET_OP_LIST
#define GET_ATTRDEF_LIST
#define GET_OP_CLASSES
#define GET_ATTRDEF_CLASSES

Functions

static bool isPtrInAddrSpace (mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isPtrInSharedCTASpace (mlir::Value ptr)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind (NVVM::CTAGroupKind ctaGroup)
static LogicalResult cpAsyncBulkTensorCommonVerifier (size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static LogicalResult verifyTMALoadParams (size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static bool isInt4PtxType (MMATypes type)
static bool isInt8PtxType (MMATypes type)
static bool isIntegerPtxType (MMATypes type)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK (NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static FailureOr< int > getAllowedSizeK (NVVM::WGMMATypes typeA)
static LogicalResult isAllowedWGMMADataType (NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static LogicalResult isAllowedSizeN (int sizeN, NVVM::WGMMATypes typeA)
static llvm::Value * packValInto64Bits (llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
 Packs the given field into the result.
static unsigned isValidVectorLength (NVVM::Tcgen05LdStShape shape, unsigned vecLen)
static void nvvmInferResultRanges (Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
 Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult verifyConstantRangeAttr (Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
 Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRangeableRegisterOp.
static llvm::Value * getAsPackedI32 (llvm::Value *arg, llvm::IRBuilderBase &builder)
static llvm::Value * getParamCastedAddr (llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyTcgen05MMAOp (bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
static LogicalResult verifyTcgen05MMABlockScaleOp (NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMABlockScaleKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)

Variables

static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic

Macro Definition Documentation

◆ _none

#define _none

Definition at line 2427 of file NVVMDialect.cpp.

◆ CP_ASYNC_ID_IMPL

#define CP_ASYNC_ID_IMPL ( mod,
size,
suffix )
Value:
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

Definition at line 1935 of file NVVMDialect.cpp.

◆ CVT_F2TF32_ID_IMPL

#define CVT_F2TF32_ID_IMPL ( rnd,
relu,
sf )
Value:
hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
: llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

Definition at line 2429 of file NVVMDialect.cpp.

◆ GET_ATTRDEF_CLASSES

#define GET_ATTRDEF_CLASSES

Definition at line 3926 of file NVVMDialect.cpp.

◆ GET_ATTRDEF_LIST

#define GET_ATTRDEF_LIST

◆ GET_BF16X2_TO_F8X2_ID

#define GET_BF16X2_TO_F8X2_ID ( rnd,
has_satf )
Value:
has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
: llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd

Definition at line 2545 of file NVVMDialect.cpp.

◆ GET_CP_ASYNC_ID

#define GET_CP_ASYNC_ID ( mod,
size,
has_cpsize )
Value:
has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
#define CP_ASYNC_ID_IMPL(mod, size, suffix)

Definition at line 1938 of file NVVMDialect.cpp.

◆ GET_CVT_F2TF32_ID

#define GET_CVT_F2TF32_ID ( rnd,
relu,
sf )
Value:
hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
: CVT_F2TF32_ID_IMPL(rnd, relu, )
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)

Definition at line 2433 of file NVVMDialect.cpp.

◆ GET_F16x2_TO_F8X2_ID

#define GET_F16x2_TO_F8X2_ID ( type,
has_relu )
Value:
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

Definition at line 2526 of file NVVMDialect.cpp.

◆ GET_F32x2_TO_F6x2_ID

#define GET_F32x2_TO_F6x2_ID ( type,
has_relu )
Value:
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

Definition at line 2471 of file NVVMDialect.cpp.

◆ GET_F32x2_TO_F8X2_S_ID

#define GET_F32x2_TO_F8X2_S_ID ( type,
has_relu )
Value:
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn

Definition at line 2494 of file NVVMDialect.cpp.

◆ GET_F32x2_TO_F8X2_US_ID

#define GET_F32x2_TO_F8X2_US_ID ( rnd,
has_satf )
Value:
has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
: llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

Definition at line 2490 of file NVVMDialect.cpp.

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 3923 of file NVVMDialect.cpp.

◆ GET_OP_LIST

#define GET_OP_LIST

◆ GET_TCGEN05_COMMIT_ID

#define GET_TCGEN05_COMMIT_ID ( cta_group,
is_shared,
has_mc )
Value:
has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
: TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)

Definition at line 2700 of file NVVMDialect.cpp.

◆ GET_TCGEN05_CP_ID

#define GET_TCGEN05_CP_ID ( shape_mc,
src_fmt,
is_2cta )
Value:
[&]() -> auto { \
if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)

Definition at line 2734 of file NVVMDialect.cpp.

◆ TCGEN05_COMMIT_IMPL

#define TCGEN05_COMMIT_IMPL ( cg,
is_shared,
mc )
Value:
is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
: llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

Definition at line 2696 of file NVVMDialect.cpp.

◆ TCGEN05_CP_2CTA

#define TCGEN05_CP_2CTA ( shape_mc,
src_fmt,
is_2cta )
Value:
is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
: TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)

Definition at line 2730 of file NVVMDialect.cpp.

◆ TCGEN05_CP_IMPL

#define TCGEN05_CP_IMPL ( shape_mc,
src_fmt,
cg )
Value:
llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

Definition at line 2727 of file NVVMDialect.cpp.

Function Documentation

◆ cpAsyncBulkTensorCommonVerifier()

LogicalResult cpAsyncBulkTensorCommonVerifier ( size_t tensorDims,
bool isIm2Col,
size_t numIm2ColOffsets,
Location loc )
static

Definition at line 83 of file NVVMDialect.cpp.

References mlir::emitError(), and success().

◆ getAllowedSizeK()

FailureOr< int > getAllowedSizeK ( NVVM::WGMMATypes typeA)
static

Definition at line 1224 of file NVVMDialect.cpp.

◆ getAsPackedI32()

llvm::Value * getAsPackedI32 ( llvm::Value * arg,
llvm::IRBuilderBase & builder )
static

Definition at line 2930 of file NVVMDialect.cpp.

◆ getNVVMCtaGroupKind()

llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind ( NVVM::CTAGroupKind ctaGroup)
static

Definition at line 66 of file NVVMDialect.cpp.

◆ getParamCastedAddr()

llvm::Value * getParamCastedAddr ( llvm::Value * addr,
llvm::IRBuilderBase & builder )
static

Definition at line 2979 of file NVVMDialect.cpp.

◆ inferMMATypeFromMNK()

std::pair< mlir::Type, unsigned > inferMMATypeFromMNK ( NVVM::MMATypes type,
NVVM::MMAFrag frag,
int m,
int n,
int k,
MLIRContext * context )
static

Definition at line 1045 of file NVVMDialect.cpp.

References mlir::NVVM::inferMMAType().

◆ isAllowedSizeN()

LogicalResult isAllowedSizeN ( int sizeN,
NVVM::WGMMATypes typeA )
static

Definition at line 1280 of file NVVMDialect.cpp.

References success().

◆ isAllowedWGMMADataType()

LogicalResult isAllowedWGMMADataType ( NVVM::WGMMATypes typeD,
NVVM::WGMMATypes typeA,
NVVM::WGMMATypes typeB )
static

Definition at line 1238 of file NVVMDialect.cpp.

References success().

◆ isInt4PtxType()

bool isInt4PtxType ( MMATypes type)
static

Definition at line 511 of file NVVMDialect.cpp.

Referenced by isIntegerPtxType().

◆ isInt8PtxType()

bool isInt8PtxType ( MMATypes type)
static

Definition at line 515 of file NVVMDialect.cpp.

Referenced by isIntegerPtxType().

◆ isIntegerPtxType()

bool isIntegerPtxType ( MMATypes type)
static

Definition at line 519 of file NVVMDialect.cpp.

References isInt4PtxType(), and isInt8PtxType().

◆ isPtrInAddrSpace()

bool isPtrInAddrSpace ( mlir::Value ptr,
NVVMMemorySpace targetAS )
static

Definition at line 55 of file NVVMDialect.cpp.

Referenced by isPtrInSharedCTASpace().

◆ isPtrInSharedCTASpace()

bool isPtrInSharedCTASpace ( mlir::Value ptr)
static

Definition at line 60 of file NVVMDialect.cpp.

References isPtrInAddrSpace().

◆ isValidVectorLength()

unsigned isValidVectorLength ( NVVM::Tcgen05LdStShape shape,
unsigned vecLen )
static

Definition at line 2847 of file NVVMDialect.cpp.

◆ nvvmInferResultRanges()

void nvvmInferResultRanges ( Operation * op,
Value result,
ArrayRef<::mlir::ConstantIntRanges > argRanges,
SetIntRangeFn setResultRanges )
static

Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.

Definition at line 2895 of file NVVMDialect.cpp.

References mlir::Operation::getAttrOfType(), and result.

◆ packValInto64Bits()

llvm::Value * packValInto64Bits ( llvm::IRBuilderBase & builder,
llvm::Value * result,
llvm::Value * field,
unsigned sizeInBits,
unsigned start )
static

Packs the given field into the result.

The result is 64-bits and each field can be 32-bits or narrower.

Definition at line 1741 of file NVVMDialect.cpp.

References result.

◆ verifyConstantRangeAttr()

LogicalResult verifyConstantRangeAttr ( Operation * op,
std::optional< LLVM::ConstantRangeAttr > rangeAttr )
static

Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRangeableRegisterOp.

Definition at line 2907 of file NVVMDialect.cpp.

References mlir::Operation::emitOpError(), and success().

◆ verifyTcgen05MMABlockScaleOp()

LogicalResult verifyTcgen05MMABlockScaleOp ( NVVM::Tcgen05MMACollectorOp collectorOp,
NVVM::Tcgen05MMABlockScaleKind kind,
NVVM::Tcgen05MMABlockScale blockScale,
Location loc )
static

Definition at line 3520 of file NVVMDialect.cpp.

References mlir::emitError(), and success().

◆ verifyTcgen05MMAOp()

LogicalResult verifyTcgen05MMAOp ( bool isATensor,
mlir::Value disableOutputLane,
NVVM::CTAGroupKind ctaGroup,
bool hasAShift,
NVVM::Tcgen05MMACollectorOp collectorOp,
Location loc )
static

Definition at line 3254 of file NVVMDialect.cpp.

References mlir::emitError(), mlir::Value::getType(), and success().

◆ verifyTMALoadParams()

LogicalResult verifyTMALoadParams ( size_t tensorDims,
size_t numIm2colOff,
TMALoadMode mode,
Location loc )
static

Definition at line 142 of file NVVMDialect.cpp.

References mlir::emitError(), and success().

Variable Documentation

◆ notIntrinsic

unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic
static constexpr

Definition at line 49 of file NVVMDialect.cpp.