MLIR  17.0.0git
MMAUtils.cpp File Reference
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

Functions

static bool isAccumulatorOrResult (MatMulOperandRole operandType)
 
static int64_t inferNumRegistersPerMatrixFragment (const WarpMatrixInfo &type)
 Returns the number of registers which compose a matrix fragment held by a single thread.
 
static std::array< int64_t, 2 > getTileShape (ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
 Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
 
static AffineMap getRegisterIndexToTileOffsetMap (int64_t lineSize, Type elementType, ArrayRef< int64_t > operandShape, bool isAccumulator, int64_t elementsPerRegister, AffineExpr logicalValueId)
 

Variables

static constexpr int64_t kThreadsPerRow = 4
 There are always 4 threads per [128|256|512] bit row.
 
static constexpr int64_t kNumRowsPerTile = 8
 

Function Documentation

◆ getRegisterIndexToTileOffsetMap()

static AffineMap getRegisterIndexToTileOffsetMap ( int64_t  lineSize,
Type  elementType,
ArrayRef< int64_t >  operandShape,
bool  isAccumulator,
int64_t  elementsPerRegister,
AffineExpr  logicalValueId 
)

◆ getTileShape()

static std::array<int64_t, 2> getTileShape ( ArrayRef< int64_t >  operandShape,
Type  elementType,
int64_t  lineSizeBits 
)

Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.

Definition at line 39 of file MMAUtils.cpp.

References mlir::Type::getIntOrFloatBitWidth(), and kNumRowsPerTile.

Referenced by getRegisterIndexToTileOffsetMap().
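
The tile count described above can be illustrated without any MLIR dependency. The following is a minimal sketch, not the code from MMAUtils.cpp: `tileShapeSketch` and `kRowsPerTileSketch` are hypothetical names, the `Type` argument is replaced by the bit width that `getIntOrFloatBitWidth()` would report, and the formula (rows divided by 8, row width in bits divided by the line size) is an assumption inferred from the brief description and the listed references.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Stand-in for kNumRowsPerTile (8 rows per tile, as documented on this page).
constexpr int64_t kRowsPerTileSketch = 8;

// Hypothetical stand-in for getTileShape(). `elementBits` plays the role of
// elementType.getIntOrFloatBitWidth(); `lineSizeBits` is the width of one
// tile row in bits.
inline std::array<int64_t, 2> tileShapeSketch(int64_t rows, int64_t cols,
                                              int64_t elementBits,
                                              int64_t lineSizeBits) {
  assert(lineSizeBits == 128 || lineSizeBits == 256 || lineSizeBits == 512);
  // Assumed formula: tiles along the rows are 8 rows tall, and tiles along
  // the columns are lineSizeBits wide.
  return {rows / kRowsPerTileSketch, (cols * elementBits) / lineSizeBits};
}
```

Under these assumptions, a 16 x 16 operand of 16-bit elements with 128-bit lines is covered by a 2 x 2 grid of tiles, and a 16 x 8 operand of 32-bit elements with 256-bit lines by a 2 x 1 grid.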

◆ inferNumRegistersPerMatrixFragment()

static int64_t inferNumRegistersPerMatrixFragment ( const WarpMatrixInfo &  type)

Returns the number of registers which compose a matrix fragment held by a single thread.

Definition at line 29 of file MMAUtils.cpp.

References mlir::nvgpu::inferTileWidthInBits(), kNumRowsPerTile, and mlir::nvgpu::WarpMatrixInfo::vectorType.

Referenced by mlir::nvgpu::getMmaSyncRegisterType().
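
As a rough illustration of the register count, the sketch below assumes that each thread holds one register per 8 x [128|256|512] bit tile, so the count equals the number of tiles covering the fragment. This is an assumption based on the description above, not a copy of the implementation. `numRegistersSketch` is a hypothetical name, and the line size is passed in directly instead of being derived via `mlir::nvgpu::inferTileWidthInBits()`.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for kNumRowsPerTile (8 rows per tile, as documented on this page).
constexpr int64_t kRowsPerTileSketch = 8;

// Hypothetical stand-in for inferNumRegistersPerMatrixFragment(). The real
// function takes a WarpMatrixInfo; here the fragment shape, element bit width
// and tile line size are plain integers.
inline int64_t numRegistersSketch(int64_t rows, int64_t cols,
                                  int64_t elementBits, int64_t lineSizeBits) {
  assert(lineSizeBits > 0 && rows % kRowsPerTileSketch == 0);
  // Assumed formula: (number of tile rows) * (row width in bits) / line size.
  return (rows / kRowsPerTileSketch) * (cols * elementBits) / lineSizeBits;
}
```

With these assumptions, a 16 x 16 fragment of 16-bit elements and 128-bit lines yields 4 registers per thread, and a 16 x 8 fragment of 32-bit elements with 256-bit lines yields 2.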

◆ isAccumulatorOrResult()

static bool isAccumulatorOrResult ( MatMulOperandRole  operandType)

Variable Documentation

◆ kNumRowsPerTile

static constexpr int64_t kNumRowsPerTile = 8

◆ kThreadsPerRow

static constexpr int64_t kThreadsPerRow = 4

There are always 4 threads per [128|256|512] bit row.

Definition at line 20 of file MMAUtils.cpp.

Referenced by mlir::nvgpu::getLaneIdAndValueIdToOperandCoord().
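
The two constants together account for a full warp: 8 rows times 4 threads per row gives 32 lanes. The sketch below shows how a lane id could be split into a row and a position within that row, assuming the usual `mma.sync` fragment layout in which each consecutive group of four lanes shares one row. All names (`LanePositionSketch`, `lanePositionSketch`, `bitsPerThreadPerRowSketch`) are hypothetical, and this is not the logic of `getLaneIdAndValueIdToOperandCoord()` itself.

```cpp
#include <cassert>
#include <cstdint>

// Stand-ins for the constants documented on this page.
constexpr int64_t kThreadsPerRowSketch = 4;
constexpr int64_t kNumRowsPerTileSketch = 8;

struct LanePositionSketch {
  int64_t row;         // Row of the tile that the lane contributes to.
  int64_t threadInRow; // Position of the lane among the 4 threads of that row.
};

// Assumed layout: lanes 0..3 cover row 0, lanes 4..7 cover row 1, and so on.
inline LanePositionSketch lanePositionSketch(int64_t laneId) {
  assert(laneId >= 0 &&
         laneId < kThreadsPerRowSketch * kNumRowsPerTileSketch);
  return {laneId / kThreadsPerRowSketch, laneId % kThreadsPerRowSketch};
}

// Each of the 4 threads in a row owns an equal share of the line:
// 32, 64 or 128 bits for 128-, 256- or 512-bit lines.
inline int64_t bitsPerThreadPerRowSketch(int64_t lineSizeBits) {
  return lineSizeBits / kThreadsPerRowSketch;
}
```

For example, under this layout lane 13 maps to row 3 and is the second thread of that row, and with 128-bit lines each thread owns 32 bits of a row, which is one 32-bit register.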