MLIR  20.0.0git
MMAUtils.h
Go to the documentation of this file.
1 //===-- MMAUtils.h - MLIR NVGPU dialect utilities for MMA operations-------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file provides utilities to assist in the lowering of other dialects
10 // (e.g. Vector) to `nvgpu.mma.*` dialect operations.
11 //
12 //===----------------------------------------------------------------------===//
13 #ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
14 #define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
15 
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/Types.h"
20 
21 namespace mlir {
22 namespace nvgpu {
23 
24 /// Represents the role of an operand in an MMA instruction:
25 /// `result := matmul(A, B) + C`
26 enum class MatMulOperandRole : int32_t { A = 0, B, C };
27 
28 /// Returns the first user of the `op` that is vector.contract. If no
29 /// vector.contract user exists, return failure.
30 FailureOr<vector::ContractionOp> getUserContract(Operation *op);
31 
32 /// Collects information about a warp-level matrix operand represented by a
33 /// VectorType.
34 struct WarpMatrixInfo {
35  VectorType vectorType;
36  MatMulOperandRole operandRole;
37 };
38 
39 /// If `op` is a `vector.transfer_write`, return the `WarpMatrixInfo` for the
40 /// vector operand. If `op` is a `vector.transfer_read`, `vector.contract`, or
41 /// `arith.constant`, return the `WarpMatrixInfo` corresponding to the result.
42 /// Otherwise, return failure.
43 FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);
44 
45 /// Returns the number of bits in a single tile row. It is either 128, 256, or
46 /// 512 bits depending on the data type and whether the operand is an
47 /// accumulator/result operand.
48 int64_t inferTileWidthInBits(const WarpMatrixInfo &type);
49 
50 /// Specifies information about the registers which compose a matrix fragment
51 /// according to the PTX documentation.
57 };
58 
59 /// Returns a FragmentElementInfo struct describing the register types for the
60 /// given matrix fragment type.
61 FailureOr<FragmentElementInfo>
62 getMmaSyncRegisterType(const WarpMatrixInfo &type);
63
64 /// Returns an AffineMap which maps two dimensions representing (laneId,
65 /// logicalValueId) and returns two results representing offsets within a
66 /// matrix operand. The offsets point to the values the thread is responsible
67 /// for (AKA the matrix fragment values) during a warp-collective matrix
68 /// operation. For a visual reference of this LaneId -> (row, col) mapping,
69 /// please see NVIDIA's PTX documentation:
70 /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
71 FailureOr<AffineMap>
72 getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
73  const WarpMatrixInfo &fragmentType);
74 
75 /// Encapsulates the parameters needed to lower a `nvgpu.ldmatrix` operation to
76 /// `nvvm.ldmatrix`.
77 struct LdMatrixParams {
78  VectorType fragmentType;
79  bool isAccum;
80  int64_t numTiles;
81  vector::IteratorType contiguousDimType;
82  NVVM::MMALayout targetLayout;
83 };
84 
85 /// Given `type` that contains info for a warp-matrix operand and whether or not
86 /// the load is a transposed load, return the LdMatrixParams.
87 FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
88  bool transpose);
89 /// Returns an AffineMap which maps a single dimension representing the laneId
90 /// to two results representing offsets within the matrix operand that should
91 /// be the pointer locations a thread should pass to the ldmatrix instruction.
92 FailureOr<AffineMap>
93 getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
94  const LdMatrixParams &params);
95 
96 /// Returns whether the `vector.transfer_read` instruction can be interpreted
97 /// as a warp-level cooperative matrix load operation. This function is meant to
98 /// be used to establish whether `op` is part of a chain of such warp-level
99 /// operations.
100 bool canLowerToWarpMatrixOperation(vector::TransferReadOp op);
101 
102 /// Returns whether the `vector.transfer_write` instruction can be interpreted
103 /// as a warp-level cooperative matrix store operation. This function is meant
104 /// to be used to establish whether `op` is part of a chain of such warp-level
105 /// operations.
106 bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op);
107 
108 } // namespace nvgpu
109 } // namespace mlir
110 
111 #endif // MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:87
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:50
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) and returns tw...
Definition: MMAUtils.cpp:173
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:58
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representin...
Definition: MMAUtils.cpp:238
MatMulOperandRole
Represents the role of an operand in an MMA instruction: result := matmul(A, B) + C
Definition: MMAUtils.h:26
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed l...
Definition: MMAUtils.cpp:209
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type...
Definition: MMAUtils.cpp:100
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative m...
Definition: MMAUtils.cpp:276
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:22
Include the generated interface declarations.
Specifies information about the registers which compose a matrix fragment according to the PTX docume...
Definition: MMAUtils.h:52
Encapsulates the parameters needed to lower a nvgpu.ldmatrix operation to nvvm.ldmatrix.
Definition: MMAUtils.h:77
NVVM::MMALayout targetLayout
Definition: MMAUtils.h:82
vector::IteratorType contiguousDimType
Definition: MMAUtils.h:81
Collects information about a warp-level matrix operand represented by a VectorType.
Definition: MMAUtils.h:34
MatMulOperandRole operandRole
Definition: MMAUtils.h:36