MLIR 22.0.0git
MMAUtils.cpp File Reference

Functions

static bool isAccumulatorOrResult (MatMulOperandRole operandType)
static int64_t inferNumRegistersPerMatrixFragment (const WarpMatrixInfo &type)
 Returns the number of registers which compose a matrix fragment held by a single thread.
static std::array< int64_t, 2 > getTileShape (ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
 Returns the number of 8-row x [128|256|512]-bit tiles that compose the given operand shape.
static AffineMap getRegisterIndexToTileOffsetMap (int64_t lineSize, Type elementType, ArrayRef< int64_t > operandShape, bool isAccumulator, int64_t elementsPerRegister, AffineExpr logicalValueId)

Variables

static constexpr int64_t kThreadsPerRow = 4
 There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kNumRowsPerTile = 8

Function Documentation

◆ getRegisterIndexToTileOffsetMap()

AffineMap getRegisterIndexToTileOffsetMap ( int64_t lineSize,
Type elementType,
ArrayRef< int64_t > operandShape,
bool isAccumulator,
int64_t elementsPerRegister,
AffineExpr logicalValueId )
static

◆ getTileShape()

std::array< int64_t, 2 > getTileShape ( ArrayRef< int64_t > operandShape,
Type elementType,
int64_t lineSizeBits )
static

Returns the number of 8-row x [128|256|512]-bit tiles that compose the given operand shape.

Definition at line 37 of file MMAUtils.cpp.

References mlir::Type::getIntOrFloatBitWidth() and kNumRowsPerTile.

Referenced by getRegisterIndexToTileOffsetMap().
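
The tile count can be illustrated with a standalone sketch. It assumes the function divides the row count by kNumRowsPerTile and the row width in bits by lineSizeBits, which is consistent with the references listed above but is not confirmed by this page. The name tileShapeSketch, the plain elementBitWidth parameter (standing in for Type::getIntOrFloatBitWidth()), and the divisibility asserts are hypothetical and not part of MLIR.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Mirrors the constant documented on this page.
static constexpr int64_t kNumRowsPerTile = 8;

// Hypothetical standalone analogue of getTileShape(). The element type is
// reduced to its bit width so the sketch compiles without MLIR headers.
static std::array<int64_t, 2> tileShapeSketch(std::array<int64_t, 2> operandShape,
                                              int64_t elementBitWidth,
                                              int64_t lineSizeBits) {
  // Assumption: the operand is an exact multiple of the tile in both dimensions.
  assert(operandShape[0] % kNumRowsPerTile == 0);
  assert((operandShape[1] * elementBitWidth) % lineSizeBits == 0);
  return {operandShape[0] / kNumRowsPerTile,                     // tiles along rows
          (operandShape[1] * elementBitWidth) / lineSizeBits};   // tiles along columns
}
```

Under these assumptions, a 16 x 16 operand of 16-bit elements with a 128-bit line yields a 2 x 2 grid of tiles, because 16 / 8 = 2 and (16 * 16) / 128 = 2.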

◆ inferNumRegistersPerMatrixFragment()

int64_t inferNumRegistersPerMatrixFragment ( const WarpMatrixInfo & type)
static

Returns the number of registers which compose a matrix fragment held by a single thread.

Definition at line 27 of file MMAUtils.cpp.

References kNumRowsPerTile and mlir::nvgpu::WarpMatrixInfo::vectorType.
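
A rough way to reason about the register count is that each thread holds one register per 8-row tile, so the count equals the number of tiles in the fragment. The sketch below is written under that assumption. The name numRegistersSketch is hypothetical, the shape and bit width are passed directly instead of through WarpMatrixInfo::vectorType, and the line size is taken as a parameter because this page does not show how it is derived.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors the constant documented on this page.
static constexpr int64_t kNumRowsPerTile = 8;

// Hypothetical standalone analogue of inferNumRegistersPerMatrixFragment().
// rows/cols describe the 2-D fragment shape; elementBitWidth and lineSizeBits
// stand in for values the real function reads from its WarpMatrixInfo argument.
static int64_t numRegistersSketch(int64_t rows, int64_t cols,
                                  int64_t elementBitWidth,
                                  int64_t lineSizeBits) {
  // Assumption: the fragment is an exact multiple of the tile size.
  assert(rows % kNumRowsPerTile == 0);
  assert((cols * elementBitWidth) % lineSizeBits == 0);
  int64_t tilesAlongRows = rows / kNumRowsPerTile;
  int64_t tilesAlongCols = (cols * elementBitWidth) / lineSizeBits;
  // One register per tile per thread.
  return tilesAlongRows * tilesAlongCols;
}
```

For example, a 16 x 16 fragment of 16-bit elements with a 128-bit line gives 2 * 2 = 4 registers per thread, and a 16 x 8 fragment of 16-bit elements gives 2 * 1 = 2.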

◆ isAccumulatorOrResult()

bool isAccumulatorOrResult ( MatMulOperandRole operandType)
static

Definition at line 21 of file MMAUtils.cpp.

References mlir::nvgpu::C.

Variable Documentation

◆ kNumRowsPerTile

int64_t kNumRowsPerTile = 8
static constexpr

Definition at line 19 of file MMAUtils.cpp.

Referenced by getTileShape() and inferNumRegistersPerMatrixFragment().

◆ kThreadsPerRow

int64_t kThreadsPerRow = 4
static constexpr

There are always 4 threads per [128|256|512] bit row.

Definition at line 18 of file MMAUtils.cpp.
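
The two constants together account for a full 32-thread warp: 8 rows per tile times 4 threads per row. The sketch below works through the arithmetic. The lane-to-row mapping (laneId / 4 for the row, laneId % 4 for the position within the row) is an assumption based on the commonly documented warp-level MMA fragment layout, not something stated on this page, and the names LaneCoord, laneToTileCoord, and bitsPerThreadPerRow are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Mirror the constants documented on this page.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

// Position of a thread inside one tile: which row it reads and which of the
// four slots within that row it owns.
struct LaneCoord {
  int64_t row;
  int64_t slot;
};

// Hypothetical helper: maps a lane id in [0, 32) to its tile coordinate,
// assuming consecutive groups of four lanes share a row.
static LaneCoord laneToTileCoord(int64_t laneId) {
  assert(laneId >= 0 && laneId < kThreadsPerRow * kNumRowsPerTile);
  return {laneId / kThreadsPerRow, laneId % kThreadsPerRow};
}

// Hypothetical helper: bits of a row owned by each thread, i.e. 32, 64, or
// 128 bits for 128-, 256-, or 512-bit rows respectively.
static int64_t bitsPerThreadPerRow(int64_t lineSizeBits) {
  assert(lineSizeBits % kThreadsPerRow == 0);
  return lineSizeBits / kThreadsPerRow;
}
```

With a 128-bit row, each of the four threads covers 32 bits, which matches a single 32-bit register per thread per tile.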