MLIR
20.0.0git
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations. More...
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
Public Types | |
using | Base = StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits... > |
Utility declarations for the concrete attribute class. More... | |
Public Types inherited from mlir::detail::StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits > | |
using | Base = StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits... > |
Utility declarations for the concrete attribute class. More... | |
using | ImplType = StorageT |
using | HasTraitFn = bool(*)(TypeID) |
Public Member Functions | |
unsigned | getNumDims () const |
Get number of dims. More... | |
ArrayRef< int64_t > | getShape () const |
Get shape of the matrix. More... | |
Type | getElementType () const |
Get elementType of a single element. More... | |
StringRef | getOperand () const |
The general form of operation this type supports is given by the equation C += A*B. More... | |
Public Member Functions inherited from mlir::detail::StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits > | |
ImplType * | getImpl () const |
Utility for easy access to the storage instance. More... | |
Static Public Member Functions | |
static MMAMatrixType | get (ArrayRef< int64_t > shape, Type elementType, StringRef operand) |
Get MMAMatrixType and verify construction Invariants. More... | |
static MMAMatrixType | getChecked (function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand) |
Get MMAMatrixType at a particular location and verify construction Invariants. More... | |
static bool | isValidElementType (Type elementType) |
Check if a type is a valid MMAMatrixType elementType. More... | |
static LogicalResult | verifyInvariants (function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand) |
Verify that shape and elementType are actually allowed for the MMAMatrixType. More... | |
Static Public Member Functions inherited from mlir::detail::StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits > | |
static TypeID | getTypeID () |
Return a unique identifier for the concrete type. More... | |
template<typename T > | |
static bool | classof (T val) |
Provide an implementation of 'classof' that compares the type id of the provided value with that of the concrete type. More... | |
static detail::InterfaceMap | getInterfaceMap () |
Returns an interface map for the interfaces registered to this storage user. More... | |
static HasTraitFn | getHasTraitFn () |
Returns the function that returns true if the given Trait ID matches the IDs of any of the traits defined by the storage user. More... | |
static auto | getWalkImmediateSubElementsFn () |
Returns a function that walks immediate sub elements of a given instance of the storage user. More... | |
static auto | getReplaceImmediateSubElementsFn () |
Returns a function that replaces immediate sub elements of a given instance of the storage user. More... | |
template<typename... IfaceModels> | |
static void | attachInterface (MLIRContext &context) |
Attach the given models as implementations of the corresponding interfaces for the concrete storage user class. More... | |
template<typename... Args> | |
static ConcreteT | get (MLIRContext *ctx, Args &&...args) |
Get or create a new ConcreteT instance within the ctx. More... | |
template<typename... Args> | |
static ConcreteT | getChecked (const Location &loc, Args &&...args) |
Get or create a new ConcreteT instance within the ctx, defined at the given, potentially unknown, location. More... | |
template<typename... Args> | |
static ConcreteT | getChecked (function_ref< InFlightDiagnostic()> emitErrorFn, MLIRContext *ctx, Args... args) |
Get or create a new ConcreteT instance within the ctx. More... | |
static ConcreteT | getFromOpaquePointer (const void *ptr) |
Get an instance of the concrete type from a void pointer. More... | |
Static Public Attributes | |
static constexpr StringLiteral | name = "gpu.mma_matrix" |
Additional Inherited Members | |
Protected Member Functions inherited from mlir::detail::StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits > | |
template<typename... Args> | |
LogicalResult | mutate (Args &&...args) |
Mutate the current storage instance. More... | |
Static Protected Member Functions inherited from mlir::detail::StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits > | |
template<typename... Args> | |
static LogicalResult | verifyInvariants (Args... args) |
Default implementation that just returns success. More... | |
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
MMAMatrices are taken as direct operands by these operations and are also produced as results. These matrices are meant to reside in registers. Only a limited number of pointwise operations can be performed on them, i.e., operations that act uniformly on all elements of the matrix and do not change the order of the elements. These restrictions exist because the layout of elements inside the matrix is opaque, i.e., the elements may be stored in any order. The general usage of this type is as follows:
%0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
The MMAMatrixType describes both the shape of the matrix being loaded and which operand it represents. The operand must be specified to aid the lowering of this type to dialects such as NVVM, where each workitem may hold a different number of elements depending on the elementType of the matrix. For example, each workitem holds 4 vector<2xf16>s for the f16 data type and 8 f32s for the f32 data type of MMAMatrix. Some other usage examples are:
%3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
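The three parameters of the type (shape, elementType, operand) map directly onto the textual form used in the examples above. The standalone C++ sketch below models that mapping without depending on MLIR. `MmaMatrixDesc` and `formatMmaMatrix` are hypothetical names introduced only for illustration, and the element type is modelled as a plain string rather than an `mlir::Type`.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for the parameters carried by MMAMatrixType.
// This is NOT the MLIR class; the element type is modelled as a string.
struct MmaMatrixDesc {
  std::vector<int64_t> shape; // e.g. {16, 16}
  std::string elementType;    // e.g. "f16"
  std::string operand;        // "AOp", "BOp" or "COp"
};

// Render the descriptor in the textual form shown in the examples,
// e.g. !gpu.mma_matrix<16x16xf16, "AOp">.
std::string formatMmaMatrix(const MmaMatrixDesc &desc) {
  std::string out = "!gpu.mma_matrix<";
  for (int64_t dim : desc.shape)
    out += std::to_string(dim) + "x";
  out += desc.elementType + ", \"" + desc.operand + "\">";
  return out;
}
```

In code that links against MLIR, the corresponding type would instead be obtained through the static `MMAMatrixType::get(shape, elementType, operand)` documented on this page.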
Definition at line 130 of file GPUDialect.h.
using mlir::detail::StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits >::Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...> |
Utility declarations for the concrete attribute class.
Definition at line 100 of file StorageUniquerSupport.h.
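MMAMatrixType MMAMatrixType::get | ( | ArrayRef< int64_t > shape, Type elementType, StringRef operand | ) | static |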
Get MMAMatrixType and verify construction Invariants.
Definition at line 122 of file GPUDialect.cpp.
References mlir::get(), and mlir::Type::getContext().
Referenced by convertBroadcastOp(), convertConstantOp(), convertElementwiseOp(), and convertTransferReadOp().
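MMAMatrixType MMAMatrixType::getChecked | ( | function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand | ) | static |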
Get MMAMatrixType at a particular location and verify construction Invariants.
Definition at line 128 of file GPUDialect.cpp.
References mlir::emitError(), and mlir::Type::getContext().
Type MMAMatrixType::getElementType | ( | ) | const |
Get elementType of a single element.
Definition at line 141 of file GPUDialect.cpp.
References mlir::detail::StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits >::getImpl().
Referenced by mlir::populateMMAToSPIRVCoopMatrixTypeConversion().
unsigned MMAMatrixType::getNumDims | ( | ) | const |
Get number of dims.
Definition at line 135 of file GPUDialect.cpp.
References mlir::detail::StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits >::getImpl().
StringRef MMAMatrixType::getOperand | ( | ) | const |
The general form of operation this type supports is given by the equation C += A*B.
This function returns which operand in the given equation is held by this type. The returned string is one of "AOp", "BOp", and "COp".
Definition at line 143 of file GPUDialect.cpp.
References mlir::detail::StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits >::getImpl().
Referenced by mlir::convertMMAToLLVMType(), and mlir::populateMMAToSPIRVCoopMatrixTypeConversion().
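Code that consumes the operand string usually needs to branch on it. The following is a minimal standalone sketch of such a dispatch, assuming only the three documented values. The `MmaRole` enum and the `classifyOperand` helper are hypothetical and are not part of MLIR.

```cpp
#include <optional>
#include <string_view>

// Hypothetical role of a matrix in the equation C += A*B.
enum class MmaRole { A, B, C };

// Map the string documented for getOperand() to a role.
// Returns std::nullopt for anything other than "AOp", "BOp" or "COp".
std::optional<MmaRole> classifyOperand(std::string_view operand) {
  if (operand == "AOp")
    return MmaRole::A;
  if (operand == "BOp")
    return MmaRole::B;
  if (operand == "COp")
    return MmaRole::C;
  return std::nullopt;
}
```

Returning an empty optional for unknown strings keeps the caller responsible for deciding how to report the error.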
ArrayRef< int64_t > MMAMatrixType::getShape | ( | ) | const |
Get shape of the matrix.
Definition at line 137 of file GPUDialect.cpp.
References mlir::detail::StorageUserBase< ConcreteT, BaseT, StorageT, UniquerT, Traits >::getImpl().
Referenced by mlir::convertMMAToLLVMType(), and mlir::populateMMAToSPIRVCoopMatrixTypeConversion().
bool MMAMatrixType::isValidElementType | ( | Type elementType | ) | static |
Check if a type is a valid MMAMatrixType elementType.
Definition at line 145 of file GPUDialect.cpp.
References mlir::Type::isF16(), mlir::Type::isF32(), mlir::Type::isInteger(), mlir::Type::isSignedInteger(), and mlir::Type::isUnsignedInteger().
Referenced by verifyInvariants().
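LogicalResult MMAMatrixType::verifyInvariants | ( | function_ref< InFlightDiagnostic()> emitError, ArrayRef< int64_t > shape, Type elementType, StringRef operand | ) | static |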
Verify that shape and elementType are actually allowed for the MMAMatrixType.
Definition at line 152 of file GPUDialect.cpp.
References mlir::emitError(), and isValidElementType().
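As a rough illustration of the kind of checks involved, the standalone sketch below models the verification on plain C++ values. It does not use MLIR: `verifyMmaMatrix` and `isValidElementTypeName` are hypothetical helpers, and element types are strings. The specific rules (one of three operand strings, exactly two dimensions, and the accepted element types `f16`, `f32`, `si8`, `ui8`, `i32`) are assumptions inferred from the member descriptions and references on this page, not a copy of GPUDialect.cpp.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Assumed set of accepted element types; the real check is
// MMAMatrixType::isValidElementType() and may differ.
bool isValidElementTypeName(const std::string &name) {
  return name == "f16" || name == "f32" || name == "si8" ||
         name == "ui8" || name == "i32";
}

// Returns an error message on failure, or std::nullopt on success.
std::optional<std::string> verifyMmaMatrix(const std::vector<int64_t> &shape,
                                           const std::string &elementType,
                                           const std::string &operand) {
  // Assumption: the operand must name a matrix in C += A*B.
  if (operand != "AOp" && operand != "BOp" && operand != "COp")
    return "operand must be one of AOp, BOp or COp";
  // Assumption: only two-dimensional matrices are accepted.
  if (shape.size() != 2)
    return "matrix must have exactly two dimensions";
  if (!isValidElementTypeName(elementType))
    return "unsupported element type: " + elementType;
  return std::nullopt;
}
```

The real function reports failures through the `emitError` callback and returns a `LogicalResult`; the optional string here only stands in for that diagnostic.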
constexpr StringLiteral MMAMatrixType::name = "gpu.mma_matrix" | static constexpr |
Definition at line 135 of file GPUDialect.h.