MLIR 22.0.0git
MMAUtils.h
Go to the documentation of this file.
1//===-- MMAUtils.h - MLIR NVGPU dialect utilities for MMA operations-------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file provides utilities to assist in the lowering of other dialects
10// (e.g. Vector) to `nvgpu.mma.*` dialect operations.
11//
12//===----------------------------------------------------------------------===//
13#ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
14#define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
15
19#include "mlir/IR/Types.h"
20
21namespace mlir {
22namespace nvgpu {
23
/// Identifies which operand of an MMA instruction a value corresponds to, for
/// an instruction computing `result := matmul(A, B) + C`.
enum class MatMulOperandRole : int32_t {
  A = 0, ///< Left-hand multiplicand of the matmul.
  B = 1, ///< Right-hand multiplicand of the matmul.
  C = 2, ///< Addend that the matmul product is added to.
};
27
28/// Returns the first user of the `op` that is vector.contract. If no
29/// vector.contract user exists, return failure.
30FailureOr<vector::ContractionOp> getUserContract(Operation *op);
31
32/// Collects information about a warp-level matrix operand represented by a
33/// VectorType.
38
39/// If `op` is a `vector.transfer_write`, return the `WarpMatrixInfo` for the
40/// vector operand. If op is a `vector.transfer_read`, `vector.contraction`, or
41/// `arith.constant`, return the `WarpMatrixInfo` corresponding to the result.
42/// Otherwise, return failure.
43FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);
44
45/// Returns the number of bits in a single tile row. It is either 128, 256, or
46/// 512 bits depending on the data type and` whether the operand is an
47/// accumulator/result operand
48int64_t inferTileWidthInBits(const WarpMatrixInfo &type);
49
50/// Specifies information about the registers which compose a matrix fragment
51/// according to the PTX documentation.
52struct FragmentElementInfo {
53 Type registerLLVMType;
54 int64_t elementsPerRegister;
55 int64_t registerWidthBits;
56 int64_t numRegistersPerFragment;
57};
58
59/// Returns a FragmentElementInfo struct describing the register types for the
60/// given matrix fragment type.
61FailureOr<FragmentElementInfo>
62getMmaSyncRegisterType(const WarpMatrixInfo &type);
63
64/// Returns an AffineMap which maps a two dimensions representing (laneId,
65/// logicalValueId) and returns two results representing offsets within a
66/// matrix operand. The offsets point to the values the thread is responsible
67/// for (AKA the matrix fragment values) during a warp-collective matrix
68/// operation. For a visual reference of this LaneId -> (row, col) mapping,
69/// please see NVIDIA's PTX documentation:
70/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
71FailureOr<AffineMap>
72getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
73 const WarpMatrixInfo &fragmentType);
74
75/// Encapsulates the parameters needed to lower a `nvgpu.ldmatrix` operation to
76/// `nvvm.ldmatrix`.
77struct LdMatrixParams {
78 VectorType fragmentType;
79 bool isAccum;
80 int64_t numTiles;
81 vector::IteratorType contiguousDimType;
82 NVVM::MMALayout targetLayout;
83};
84
85/// Given `type` that contains info for a warp-matrix operand and whether or not
86/// the load is a transposed load, return the LdMatrixParams.
87FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
88 bool transpose);
89/// Returns an AffineMap which maps a single dimension representing the laneId
90/// to two results representing offsets within the matrix operand that should
91/// be the pointer locations a thread should pass to the ldmatrix instruction.
92FailureOr<AffineMap>
93getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
94 const LdMatrixParams &params);
95
96/// Returns whether the `vector.transfer_read` instruction can be interpreted
97/// as a warp-level cooperative matrix load operation. This function is meant to
98/// be used to establish whether `op` is part of a chain of such warp-level
99/// operations.
100bool canLowerToWarpMatrixOperation(vector::TransferReadOp op);
101
102/// Returns whether the `vector.transfer_write` instruction can be interpreted
103/// as a warp-level cooperative matrix store operation. This function is meant
104/// to be used to establish whether `op` is part of a chain of such warp-level
105/// operations.
106bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op);
107
108} // namespace nvgpu
109} // namespace mlir
110
111#endif // MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition MMAUtils.cpp:48
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition MMAUtils.cpp:56
bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op)
Returns whether the vector.transfer_write instruction can be interpreted as a warp-level cooperative matrix store operation.
Definition MMAUtils.cpp:296
MatMulOperandRole
Represents the role of an operand in an MMA instruction: result := matmul(A, B) + C
Definition MMAUtils.h:26
Include the generated interface declarations.
Collects information about a warp-level matrix operand represented by a VectorType.
Definition MMAUtils.h:34
MatMulOperandRole operandRole
Definition MMAUtils.h:36