// inferNumRegistersPerMatrixFragment:
return (shape[0] / kNumRowsPerTile) *
       (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
       lineSize;
// getTileShape:
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType, int64_t lineSizeBits) {
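As a sanity check on the arithmetic, here is a standalone sketch (not from the source) that evaluates both formulas for a hypothetical 16x16 f16 operand with 128-bit tile rows:

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  constexpr int64_t kNumRowsPerTile = 8;
  const int64_t shape[2] = {16, 16}; // hypothetical 16x16 f16 operand
  const int64_t elementBits = 16, lineSizeBits = 128;

  // getTileShape: 8x128b tiles covering the operand -> {2, 2}.
  std::array<int64_t, 2> tiles = {shape[0] / kNumRowsPerTile,
                                  (shape[1] * elementBits) / lineSizeBits};

  // inferNumRegistersPerMatrixFragment: registers per thread -> 4.
  int64_t numRegisters =
      (shape[0] / kNumRowsPerTile) * (shape[1] * elementBits) / lineSizeBits;
  assert(tiles[0] == 2 && tiles[1] == 2 && numRegisters == 4);
  return 0;
}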
// getUserContract: scan the users for a vector.contract.
if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
  return contractOp;
// getWarpMatrixInfo: determine the vector type at warp level.
if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
  info.vectorType = writeOp.getVectorType();
} else if (isa<vector::TransferReadOp, vector::ContractionOp,
               vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
  info.vectorType = cast<VectorType>(op->getResult(0).getType());
} else {
  return op->emitError()
         << "unhandled operation type in nvgpu.mma.sync conversion path";
}
// getWarpMatrixInfo: infer the operand role from the consuming contraction.
FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
if (failed(contractOp))
  return info;
if ((*contractOp).getLhs() == op->getResult(0))
  info.operandRole = MatMulOperandRole::A;
else if ((*contractOp).getRhs() == op->getResult(0))
  info.operandRole = MatMulOperandRole::B;
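A minimal sketch, assuming `op` is in scope, of how a caller might consume the role computed above when classifying a value for the nvgpu.mma.sync path:

FailureOr<nvgpu::WarpMatrixInfo> info = nvgpu::getWarpMatrixInfo(op);
if (failed(info))
  return failure();
switch (info->operandRole) {
case nvgpu::MatMulOperandRole::A: /* lhs fragment */ break;
case nvgpu::MatMulOperandRole::B: /* rhs fragment */ break;
case nvgpu::MatMulOperandRole::C: /* accumulator / result fragment */ break;
}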
// inferTileWidthInBits, 64-bit element case: accumulators use 512-bit rows.
return isAcc ? 512 : 256;
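Restated as a standalone sketch; the branch values mirror the upstream implementation, so treat the non-excerpted branches as assumptions here:

#include <cstdint>

int64_t tileWidthBits(int64_t elementBits, bool isAcc) {
  if (isAcc && elementBits == 32)
    return 256; // 32-bit accumulator tiles
  if (elementBits == 64)
    return isAcc ? 512 : 256; // the branch excerpted above
  return 128; // all other operands
}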
// getMmaSyncRegisterType: dispatch on the fragment element type.
FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    // ... two f16 packed per 32-bit register ...
  }
  if (elType.isF64()) {
    // ... one f64 per 64-bit register ...
  }
  if (elType.isF32()) {
    // ... one f32 per 32-bit register ...
  }
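For the f16 case, the values one would expect are sketched below with a plain-C++ stand-in; the field names follow the FragmentElementInfo struct documented in the index at the end, and the concrete values are assumptions based on the 16x16 example above:

#include <cstdint>

struct FragmentElementInfoSketch {
  const char *registerLLVMType; // e.g. "vector<2xf16>" via getFixedVectorType
  int64_t elementsPerRegister;
  int64_t registerWidthBits;
  int64_t numRegistersPerFragment;
};

int main() {
  FragmentElementInfoSketch f16Fragment = {"vector<2xf16>", 2, 32, 4};
  (void)f16Fragment;
  return 0;
}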
// getRegisterIndexToTileOffsetMap (full signature in the index below):
const int64_t elementsPerLine = lineSize / elementType.getIntOrFloatBitWidth();
const std::array<int64_t, 2> num8x128bTiles =
    getTileShape(operandShape, elementType, lineSize);
AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
return AffineMap::get(
    2, 0,
    {(registerIdx % num8x128bTiles[0]) * 8,
     (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
    elementType.getContext());
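Evaluating the two result expressions by hand for the running 16x16 f16 example (elementsPerLine = 128/16 = 8, num8x128bTiles = {2, 2}, elementsPerRegister = 2):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t elementsPerLine = 8, tilesDim0 = 2, elementsPerRegister = 2;
  const int64_t expectedRow[4] = {0, 8, 0, 8};
  const int64_t expectedCol[4] = {0, 0, 8, 8};
  // logicalValueId 0..7 covers the fragment's 4 registers of 2 elements each.
  for (int64_t logicalValueId = 0; logicalValueId < 8; ++logicalValueId) {
    int64_t registerIdx = logicalValueId / elementsPerRegister;
    int64_t rowOffset = (registerIdx % tilesDim0) * 8;
    int64_t colOffset = (registerIdx / tilesDim0) * elementsPerLine;
    assert(rowOffset == expectedRow[registerIdx] &&
           colOffset == expectedCol[registerIdx]);
  }
  return 0;
}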
// getLaneIdAndValueIdToOperandCoord:
FailureOr<nvgpu::FragmentElementInfo> regInfo =
    getMmaSyncRegisterType(fragmentType);
if (failed(regInfo))
  return failure();
const int64_t elementsPerRegister =
    regInfo->registerWidthBits / elementBitWidth;
AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
    lineSize, elementType, operandShape, isAccum, elementsPerRegister,
    logicalValueIdDim);
auto tileRow = registerIndexToTileCoord.getResult(0);
auto tileCol = registerIndexToTileCoord.getResult(1);
return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                    (logicalValueIdDim % elementsPerRegister)});
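Putting the pieces together for the same f16 case (elementsPerRegister = 2, kThreadsPerRow = 4, tile offsets as in the previous sketch), a self-contained check of the final (laneId, logicalValueId) -> (row, col) map:

#include <cassert>
#include <cstdint>
#include <utility>

int main() {
  auto coord = [](int64_t laneId, int64_t valueId) {
    const int64_t elementsPerRegister = 2, kThreadsPerRow = 4;
    int64_t registerIdx = valueId / elementsPerRegister;
    int64_t tileRow = (registerIdx % 2) * 8;
    int64_t tileCol = (registerIdx / 2) * 8;
    int64_t row = tileRow + laneId / kThreadsPerRow;
    int64_t col = tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                  valueId % elementsPerRegister;
    return std::pair<int64_t, int64_t>{row, col};
  };
  assert(coord(0, 0) == (std::pair<int64_t, int64_t>{0, 0}));
  assert(coord(0, 1) == (std::pair<int64_t, int64_t>{0, 1}));
  assert(coord(1, 0) == (std::pair<int64_t, int64_t>{0, 2})); // lane 1 -> col 2
  assert(coord(4, 0) == (std::pair<int64_t, int64_t>{1, 0})); // lane 4 -> row 1
  return 0;
}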
// getLdMatrixParams(const WarpMatrixInfo &type, bool transpose):
params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                     : vector::IteratorType::reduction;
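The effect of the transpose flag, as a plain-C++ stand-in (IteratorKind is hypothetical, mirroring vector::IteratorType): a transposed load makes the parallel (m/n) dimension contiguous in the source memref, otherwise the reduction (k) dimension stays contiguous.

enum class IteratorKind { parallel, reduction };

IteratorKind contiguousDimFor(bool transpose) {
  return transpose ? IteratorKind::parallel : IteratorKind::reduction;
}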
// getLaneIdToLdMatrixMatrixCoord: one thread per 128-bit row.
const int bitsPerElement = static_cast<int>(
    params.fragmentType.getElementType().getIntOrFloatBitWidth());
const int kElementsPer128b = (128 / bitsPerElement);
// d0 is the laneId; operandShape[idx] is the strided dim of the source.
AffineExpr strided = d0 % (operandShape[idx]);
AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * kElementsPer128b;
if (params.contiguousDimType == vector::IteratorType::reduction)
  return makeMap({strided, contiguous});
if (params.contiguousDimType == vector::IteratorType::parallel)
  return makeMap({contiguous, strided});
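Hand-evaluating strided/contiguous for an f16 operand with operandShape[idx] = 16 (so kElementsPer128b = 128/16 = 8):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t dimSize = 16, kElementsPer128b = 8; // f16: 128 / 16 = 8
  for (int64_t laneId = 0; laneId < 32; ++laneId) {
    int64_t strided = laneId % dimSize;
    int64_t contiguous = (laneId / dimSize) * kElementsPer128b;
    // Lanes 0-15 address rows 0-15 at column 0; lanes 16-31 the same rows
    // at column 8 (the next 128-bit chunk).
    assert(contiguous == (laneId < 16 ? 0 : 8) && strided == laneId % 16);
  }
  return 0;
}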
// canLowerToWarpMatrixOperation(vector::TransferReadOp):
if (op.getMask() || op.hasOutOfBoundsDim())
  return false;
VectorType type = op.getType();
if (!type.hasStaticShape() || type.getRank() != 2)
  return false;
auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
if (!sourceType)
  return false;
return strides.back() == 1; // strides of sourceType; last dim must be unit
// canLowerToWarpMatrixOperation(vector::TransferWriteOp):
if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
  return false;
VectorType type = op.getVectorType();
if (!type.hasStaticShape() || type.getRank() != 2)
  return false;
if (!op.getPermutationMap().isMinorIdentity())
  return false;
auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
if (!sourceType)
  return false;
return strides.back() == 1; // strides of sourceType; last dim must be unit
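The contiguity test at the end of both overloads reduces to checking the innermost stride; a sketch with hypothetical stride vectors:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> rowMajorStrides = {16, 1}; // 16x16 row-major memref
  std::vector<int64_t> colMajorStrides = {1, 16}; // column-major layout
  assert(rowMajorStrides.back() == 1);  // accepted: innermost dim contiguous
  assert(colMajorStrides.back() != 1);  // rejected by both overloads
  return 0;
}

The declarations and briefs below index the symbols referenced in these excerpts.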
static constexpr int64_t kNumRowsPerTile
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize, Type elementType, ArrayRef< int64_t > operandShape, bool isAccumulator, int64_t elementsPerRegister, AffineExpr logicalValueId)
static constexpr int64_t kThreadsPerRow
There are always 4 threads per [128|256|512] bit row.
static bool isAccumulatorOrResult(MatMulOperandRole operandType)
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type)
Returns the number of registers which compose a matrix fragment held by a single thread.
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
user_range getUsers()
Returns a range of all users.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Type getType() const
Return the type of this value.
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) and returns two results representing offsets within a matrix operand. The offsets point to the values the thread is responsible for (the matrix fragment values) during a warp-collective matrix operation.
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand. If op is a vector.transfer_read, vector.contraction, or arith.constant, return the WarpMatrixInfo for the result.
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representing the (strided, contiguous) offsets within the matrix operand that the thread points to for a warp-wide ldmatrix load.
MatMulOperandRole
Represents the role of an operand in an MMA instruction: result := matmul(A, B) + C
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed load, return the LdMatrixParams.
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type.
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative matrix load operation.
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. N-1].
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Specifies information about the registers which compose a matrix fragment according to the PTX documentation.
Encapsulates the parameters needed to lower a nvgpu.ldmatrix operation to nvvm.ldmatrix.
NVVM::MMALayout targetLayout
vector::IteratorType contiguousDimType
Collects information about a warp-level matrix operand represented by a VectorType.
MatMulOperandRole operandRole
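Finally, a minimal sketch (not from the source; op, builder, and loc are assumed to be in scope inside a rewrite pattern) of how the entry points documented above chain together on the mma.sync lowering path:

FailureOr<nvgpu::WarpMatrixInfo> warpInfo = nvgpu::getWarpMatrixInfo(op);
if (failed(warpInfo))
  return failure();
FailureOr<nvgpu::FragmentElementInfo> regInfo =
    nvgpu::getMmaSyncRegisterType(*warpInfo);
FailureOr<AffineMap> coordMap =
    nvgpu::getLaneIdAndValueIdToOperandCoord(builder, loc, *warpInfo);
if (failed(regInfo) || failed(coordMap))
  return failure();
// coordMap now maps (laneId, logicalValueId) -> (row, col) in the operand.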