MLIR 23.0.0git
NVVMDialect.cpp File Reference
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>
#include <string>
#include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"

Go to the source code of this file.

Classes

struct  ConvertFsubToFnegFadd

Macros

#define CP_ASYNC_ID_IMPL(mod, size, suffix)
#define GET_CP_ASYNC_ID(mod, size, has_cpsize)
#define _none
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)
#define GET_CVT_F2TF32_ID(rnd, relu, sf)
#define GET_F32x2_TO_F6x2_ID(type, has_relu)
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)
#define GET_F32x2_TO_F8X2_S_ID(type, has_relu)
#define GET_F16x2_TO_F8X2_ID(type, has_relu)
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)
#define GET_TCGEN05_COMMIT_ID(cta_group, is_shared, has_mc)
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)
#define TCGEN05LDRED(SHAPE, NUM, TYPE)
#define GET_OP_LIST
#define GET_ATTRDEF_LIST
#define GET_OP_CLASSES
#define GET_ATTRDEF_CLASSES

Functions

static bool isPtrInAddrSpace (mlir::Value ptr, NVVMMemorySpace targetAS)
static bool isPtrInGenericSpace (mlir::Value ptr)
static bool isPtrInSharedCTASpace (mlir::Value ptr)
static bool isPtrInSharedClusterSpace (mlir::Value ptr)
static llvm::Value * castPtrToAddrSpace (llvm::IRBuilderBase &builder, llvm::Value *ptr, NVVMMemorySpace targetAS)
static llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind (NVVM::CTAGroupKind ctaGroup)
static LogicalResult cpAsyncBulkTensorCommonVerifier (size_t tensorDims, bool isIm2Col, size_t numIm2ColOffsets, Location loc)
static LogicalResult verifyTMALoadParams (size_t tensorDims, size_t numIm2colOff, TMALoadMode mode, Location loc)
static LogicalResult verifyMBarrierArriveLikeOp (Operation *op, Value addr, NVVM::MemScopeKind scope, Value retVal=nullptr)
static LogicalResult inferMBarrierArriveResultTypes (MLIRContext *context, Value addr, SmallVectorImpl< Type > &inferredReturnTypes)
 Only shared_cluster (ptr<7>) produces zero results; all other address spaces (including generic) return i64.
static bool isCompatibleReturnTypesOptionalResult (TypeRange inferred, TypeRange actual)
 For ops with optional results, allow the user to omit the result even when inference would produce one.
static LogicalResult verifyConvertF32x2ToFP16x2Op (Twine dstType, FPRoundingMode rnd, bool hasRandomBits, Operation *op)
static bool isInt4PtxType (MMATypes type)
static bool isInt8PtxType (MMATypes type)
static bool isIntegerPtxType (MMATypes type)
static void printOperandList (OpAsmPrinter &p, StringRef name, ArrayRef< Value > operands)
static LogicalResult parseMmaOperand (OpAsmParser &parser, StringRef operandName, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &regs)
template<typename Op>
static void processOperandFragments (Op &op, std::array< MMAOperandFragment, 3 > &frags, SmallVectorImpl< Type > &regTypes, SmallVectorImpl< StringRef > &ignoreAttrNames)
static LogicalResult parseMmaTypeSignature (OpAsmParser &parser, SmallVectorImpl< Type > &operandTypes)
static void inferAndSetMultiplicandTypes (MLIRContext *ctx, NamedAttrList &attrs, const SmallVectorImpl< Type > &operandTypes)
template<typename OpType>
static void addBlockScaleProperties (OpBuilder &builder, OperationState &result, ArrayRef< int64_t > shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind)
static void addInferredMultiplicandTypes (MLIRContext *ctx, OperationState &result, ValueRange operandA, ValueRange operandB, std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes)
template<typename OpTy>
static MMATypes inferPtxTypeFromResult (OpTy op)
static std::pair< mlir::Type, unsigned > inferMMATypeFromMNK (NVVM::MMATypes type, NVVM::MMAFrag frag, int m, int n, int k, MLIRContext *context)
static FailureOr< int > getAllowedSizeK (NVVM::WGMMATypes typeA)
static LogicalResult isAllowedWGMMADataType (NVVM::WGMMATypes typeD, NVVM::WGMMATypes typeA, NVVM::WGMMATypes typeB)
static LogicalResult isAllowedSizeN (int sizeN, NVVM::WGMMATypes typeA)
template<typename OpType>
static LogicalResult verifyAddSubFOp (OpType op)
static llvm::Value * packValInto64Bits (llvm::IRBuilderBase &builder, llvm::Value *result, llvm::Value *field, unsigned sizeInBits, unsigned start)
 Packs the given field into the result.
static unsigned isValidVectorLength (NVVM::Tcgen05LdStShape shape, unsigned vecLen)
static void nvvmInferResultRanges (Operation *op, Value result, ArrayRef<::mlir::ConstantIntRanges > argRanges, SetIntRangeFn setResultRanges)
 Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.
static LogicalResult verifyConstantRangeAttr (Operation *op, std::optional< LLVM::ConstantRangeAttr > rangeAttr)
 Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRangeableRegisterOp.
static llvm::Value * getAsPackedI32 (llvm::Value *arg, llvm::IRBuilderBase &builder)
static llvm::Value * getParamCastedAddr (llvm::Value *addr, llvm::IRBuilderBase &builder)
static LogicalResult verifyTcgen05MMAOp (bool isATensor, mlir::Value disableOutputLane, NVVM::CTAGroupKind ctaGroup, bool hasAShift, NVVM::Tcgen05MMACollectorOp collectorOp, Location loc)
static LogicalResult verifyTcgen05MMABlockScaleOp (NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind, NVVM::Tcgen05MMABlockScale blockScale, Location loc)

Variables

static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic

Macro Definition Documentation

◆ _none

#define _none

Definition at line 4408 of file NVVMDialect.cpp.

◆ CP_ASYNC_ID_IMPL

#define CP_ASYNC_ID_IMPL ( mod,
size,
suffix )
Value:
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

Definition at line 3916 of file NVVMDialect.cpp.

◆ CVT_F2TF32_ID_IMPL

#define CVT_F2TF32_ID_IMPL ( rnd,
relu,
sf )
Value:
hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
: llvm::Intrinsic::nvvm_f2tf32_##rnd##sf

Definition at line 4410 of file NVVMDialect.cpp.

◆ GET_ATTRDEF_CLASSES

#define GET_ATTRDEF_CLASSES

Definition at line 6293 of file NVVMDialect.cpp.

◆ GET_ATTRDEF_LIST

#define GET_ATTRDEF_LIST

◆ GET_CP_ASYNC_ID

#define GET_CP_ASYNC_ID ( mod,
size,
has_cpsize )
Value:
has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
#define CP_ASYNC_ID_IMPL(mod, size, suffix)

Definition at line 3919 of file NVVMDialect.cpp.

◆ GET_CVT_F2TF32_ID

#define GET_CVT_F2TF32_ID ( rnd,
relu,
sf )
Value:
hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
: CVT_F2TF32_ID_IMPL(rnd, relu, )
#define CVT_F2TF32_ID_IMPL(rnd, relu, sf)

Definition at line 4414 of file NVVMDialect.cpp.

◆ GET_F16x2_TO_F8X2_ID

#define GET_F16x2_TO_F8X2_ID ( type,
has_relu )
Value:
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn

Definition at line 4581 of file NVVMDialect.cpp.

◆ GET_F32x2_TO_F6x2_ID

#define GET_F32x2_TO_F6x2_ID ( type,
has_relu )
Value:
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite

Definition at line 4452 of file NVVMDialect.cpp.

◆ GET_F32x2_TO_F8X2_S_ID

#define GET_F32x2_TO_F8X2_S_ID ( type,
has_relu )
Value:
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn

Definition at line 4549 of file NVVMDialect.cpp.

◆ GET_F32x2_TO_F8X2_US_ID

#define GET_F32x2_TO_F8X2_US_ID ( rnd,
has_satf )
Value:
has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
: llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd

Definition at line 4545 of file NVVMDialect.cpp.

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

Definition at line 6290 of file NVVMDialect.cpp.

◆ GET_OP_LIST

#define GET_OP_LIST

◆ GET_TCGEN05_COMMIT_ID

#define GET_TCGEN05_COMMIT_ID ( cta_group,
is_shared,
has_mc )
Value:
has_mc ? TCGEN05_COMMIT_IMPL(cta_group, is_shared, _mc) \
: TCGEN05_COMMIT_IMPL(cta_group, is_shared, )
#define TCGEN05_COMMIT_IMPL(cg, is_shared, mc)

Definition at line 4840 of file NVVMDialect.cpp.

◆ GET_TCGEN05_CP_ID

#define GET_TCGEN05_CP_ID ( shape_mc,
src_fmt,
is_2cta )
Value:
[&]() -> auto { \
if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32) \
return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64) \
return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)

Definition at line 4874 of file NVVMDialect.cpp.

◆ TCGEN05_COMMIT_IMPL

#define TCGEN05_COMMIT_IMPL ( cg,
is_shared,
mc )
Value:
is_shared ? llvm::Intrinsic::nvvm_tcgen05_commit##mc##_shared##_##cg \
: llvm::Intrinsic::nvvm_tcgen05_commit##mc##_##cg

Definition at line 4836 of file NVVMDialect.cpp.

◆ TCGEN05_CP_2CTA

#define TCGEN05_CP_2CTA ( shape_mc,
src_fmt,
is_2cta )
Value:
is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
: TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)

Definition at line 4870 of file NVVMDialect.cpp.

◆ TCGEN05_CP_IMPL

#define TCGEN05_CP_IMPL ( shape_mc,
src_fmt,
cg )
Value:
llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

Definition at line 4867 of file NVVMDialect.cpp.

◆ TCGEN05LDRED

#define TCGEN05LDRED ( SHAPE,
NUM,
TYPE )
Value:
llvm::Intrinsic::nvvm_tcgen05_ld_red_##SHAPE##_##NUM##_##TYPE

Definition at line 5986 of file NVVMDialect.cpp.

Function Documentation

◆ addBlockScaleProperties()

template<typename OpType>
void addBlockScaleProperties ( OpBuilder & builder,
OperationState & result,
ArrayRef< int64_t > shape,
ScaleVecSize scaleVecSize,
BlockScaleFormat blockScaleFormat,
MMABlockScaleKind kind )
static

Definition at line 1775 of file NVVMDialect.cpp.

References mlir::Builder::getAttr(), mlir::Builder::getContext(), and result.

◆ addInferredMultiplicandTypes()

void addInferredMultiplicandTypes ( MLIRContext * ctx,
OperationState & result,
ValueRange operandA,
ValueRange operandB,
std::optional< std::array< MMATypes, 2 > > multiplicandPtxTypes )
static

Definition at line 1791 of file NVVMDialect.cpp.

References mlir::getType(), and result.

◆ castPtrToAddrSpace()

llvm::Value * castPtrToAddrSpace ( llvm::IRBuilderBase & builder,
llvm::Value * ptr,
NVVMMemorySpace targetAS )
static

Definition at line 72 of file NVVMDialect.cpp.

◆ cpAsyncBulkTensorCommonVerifier()

LogicalResult cpAsyncBulkTensorCommonVerifier ( size_t tensorDims,
bool isIm2Col,
size_t numIm2ColOffsets,
Location loc )
static

Definition at line 99 of file NVVMDialect.cpp.

References mlir::emitError(), and success().

◆ getAllowedSizeK()

FailureOr< int > getAllowedSizeK ( NVVM::WGMMATypes typeA)
static

Definition at line 2621 of file NVVMDialect.cpp.

◆ getAsPackedI32()

llvm::Value * getAsPackedI32 ( llvm::Value * arg,
llvm::IRBuilderBase & builder )
static

Definition at line 5142 of file NVVMDialect.cpp.

◆ getNVVMCtaGroupKind()

llvm::nvvm::CTAGroupKind getNVVMCtaGroupKind ( NVVM::CTAGroupKind ctaGroup)
static

Definition at line 82 of file NVVMDialect.cpp.

◆ getParamCastedAddr()

llvm::Value * getParamCastedAddr ( llvm::Value * addr,
llvm::IRBuilderBase & builder )
static

Definition at line 5191 of file NVVMDialect.cpp.

◆ inferAndSetMultiplicandTypes()

void inferAndSetMultiplicandTypes ( MLIRContext * ctx,
NamedAttrList & attrs,
const SmallVectorImpl< Type > & operandTypes )
static

Definition at line 1757 of file NVVMDialect.cpp.

References mlir::NamedAttrList::get(), and mlir::NamedAttrList::set().

◆ inferMBarrierArriveResultTypes()

LogicalResult inferMBarrierArriveResultTypes ( MLIRContext * context,
Value addr,
SmallVectorImpl< Type > & inferredReturnTypes )
static

Only shared_cluster (ptr<7>) produces zero results; all other address spaces (including generic) return i64.

Definition at line 312 of file NVVMDialect.cpp.

References isPtrInSharedClusterSpace(), and success().

◆ inferMMATypeFromMNK()

std::pair< mlir::Type, unsigned > inferMMATypeFromMNK ( NVVM::MMATypes type,
NVVM::MMAFrag frag,
int m,
int n,
int k,
MLIRContext * context )
static

Definition at line 2414 of file NVVMDialect.cpp.

References mlir::NVVM::inferMMAType().

◆ inferPtxTypeFromResult()

template<typename OpTy>
MMATypes inferPtxTypeFromResult ( OpTy op)
static

Definition at line 1810 of file NVVMDialect.cpp.

◆ isAllowedSizeN()

LogicalResult isAllowedSizeN ( int sizeN,
NVVM::WGMMATypes typeA )
static

Definition at line 2677 of file NVVMDialect.cpp.

References success().

◆ isAllowedWGMMADataType()

LogicalResult isAllowedWGMMADataType ( NVVM::WGMMATypes typeD,
NVVM::WGMMATypes typeA,
NVVM::WGMMATypes typeB )
static

Definition at line 2635 of file NVVMDialect.cpp.

References success().

◆ isCompatibleReturnTypesOptionalResult()

bool isCompatibleReturnTypesOptionalResult ( TypeRange inferred,
TypeRange actual )
static

For ops with optional results, allow the user to omit the result even when inference would produce one.

This preserves backward compatibility: the result can be silently discarded (e.g., for fire-and-forget arrive ops).

Definition at line 359 of file NVVMDialect.cpp.

◆ isInt4PtxType()

bool isInt4PtxType ( MMATypes type)
static

Definition at line 747 of file NVVMDialect.cpp.

Referenced by isIntegerPtxType().

◆ isInt8PtxType()

bool isInt8PtxType ( MMATypes type)
static

Definition at line 751 of file NVVMDialect.cpp.

Referenced by isIntegerPtxType().

◆ isIntegerPtxType()

bool isIntegerPtxType ( MMATypes type)
static

Definition at line 755 of file NVVMDialect.cpp.

References isInt4PtxType(), and isInt8PtxType().

◆ isPtrInAddrSpace()

bool isPtrInAddrSpace ( mlir::Value ptr,
NVVMMemorySpace targetAS )
static

◆ isPtrInGenericSpace()

bool isPtrInGenericSpace ( mlir::Value ptr)
static

Definition at line 60 of file NVVMDialect.cpp.

References isPtrInAddrSpace().

◆ isPtrInSharedClusterSpace()

bool isPtrInSharedClusterSpace ( mlir::Value ptr)
static

Definition at line 68 of file NVVMDialect.cpp.

References isPtrInAddrSpace().

Referenced by inferMBarrierArriveResultTypes(), and verifyMBarrierArriveLikeOp().

◆ isPtrInSharedCTASpace()

bool isPtrInSharedCTASpace ( mlir::Value ptr)
static

Definition at line 64 of file NVVMDialect.cpp.

References isPtrInAddrSpace().

◆ isValidVectorLength()

unsigned isValidVectorLength ( NVVM::Tcgen05LdStShape shape,
unsigned vecLen )
static

Definition at line 5057 of file NVVMDialect.cpp.

◆ nvvmInferResultRanges()

void nvvmInferResultRanges ( Operation * op,
Value result,
ArrayRef<::mlir::ConstantIntRanges > argRanges,
SetIntRangeFn setResultRanges )
static

Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might have ConstantRangeAttr.

Definition at line 5105 of file NVVMDialect.cpp.

References mlir::Operation::getAttrOfType(), mlir::IntegerValueRange::getMaxRange(), and result.

◆ packValInto64Bits()

llvm::Value * packValInto64Bits ( llvm::IRBuilderBase & builder,
llvm::Value * result,
llvm::Value * field,
unsigned sizeInBits,
unsigned start )
static

Packs the given field into the result.

The result is 64 bits wide, and each field can be 32 bits or narrower.

Definition at line 3345 of file NVVMDialect.cpp.

References result.

◆ parseMmaOperand()

LogicalResult parseMmaOperand ( OpAsmParser & parser,
StringRef operandName,
SmallVectorImpl< OpAsmParser::UnresolvedOperand > & regs )
static

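◆ parseMmaTypeSignature()

LogicalResult parseMmaTypeSignature ( OpAsmParser & parser,
SmallVectorImpl< Type > & operandTypes )
static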

◆ printOperandList()

void printOperandList ( OpAsmPrinter & p,
StringRef name,
ArrayRef< Value > operands )
static

Definition at line 1683 of file NVVMDialect.cpp.

References mlir::OpAsmPrinter::printOperands().

◆ processOperandFragments()

template<typename Op>
void processOperandFragments ( Op & op,
std::array< MMAOperandFragment, 3 > & frags,
SmallVectorImpl< Type > & regTypes,
SmallVectorImpl< StringRef > & ignoreAttrNames )
static

Definition at line 1706 of file NVVMDialect.cpp.

◆ verifyAddSubFOp()

template<typename OpType>
LogicalResult verifyAddSubFOp ( OpType op)
static

◆ verifyConstantRangeAttr()

LogicalResult verifyConstantRangeAttr ( Operation * op,
std::optional< LLVM::ConstantRangeAttr > rangeAttr )
static

Verify the range attribute satisfies LLVM ConstantRange constructor requirements for NVVM SpecialRangeableRegisterOp.

Definition at line 5119 of file NVVMDialect.cpp.

References mlir::Operation::emitOpError(), and success().

◆ verifyConvertF32x2ToFP16x2Op()

LogicalResult verifyConvertF32x2ToFP16x2Op ( Twine dstType,
FPRoundingMode rnd,
bool hasRandomBits,
Operation * op )
static

Definition at line 618 of file NVVMDialect.cpp.

References mlir::Operation::emitOpError(), and success().

◆ verifyMBarrierArriveLikeOp()

LogicalResult verifyMBarrierArriveLikeOp ( Operation * op,
Value addr,
NVVM::MemScopeKind scope,
Value retVal = nullptr )
static

◆ verifyTcgen05MMABlockScaleOp()

LogicalResult verifyTcgen05MMABlockScaleOp ( NVVM::Tcgen05MMACollectorOp collectorOp,
NVVM::Tcgen05MMAKind kind,
NVVM::Tcgen05MMABlockScale blockScale,
Location loc )
static

Definition at line 5798 of file NVVMDialect.cpp.

References mlir::emitError(), and success().

◆ verifyTcgen05MMAOp()

LogicalResult verifyTcgen05MMAOp ( bool isATensor,
mlir::Value disableOutputLane,
NVVM::CTAGroupKind ctaGroup,
bool hasAShift,
NVVM::Tcgen05MMACollectorOp collectorOp,
Location loc )
static

Definition at line 5533 of file NVVMDialect.cpp.

References mlir::emitError(), mlir::Value::getType(), and success().

◆ verifyTMALoadParams()

LogicalResult verifyTMALoadParams ( size_t tensorDims,
size_t numIm2colOff,
TMALoadMode mode,
Location loc )
static

Definition at line 158 of file NVVMDialect.cpp.

References mlir::emitError(), and success().

Variable Documentation

◆ notIntrinsic

unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic
static constexpr

Definition at line 49 of file NVVMDialect.cpp.