MLIR 22.0.0git
MMAUtils.cpp
//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}
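// Illustrative example: for a 16x16 f16 `A` fragment, inferTileWidthInBits
// returns 128, giving (16 / 8) * (16 * 16) / 128 = 4 registers per thread.
// With 2 f16 elements per 32-bit register, the 32 threads of a warp together
// cover all 16 * 16 = 256 elements.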

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  // For each 8x128bit square, a thread is responsible for one 32bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}
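// Illustrative example: a 16x16 f16 operand with lineSizeBits = 128 yields
// {16 / 8, (16 * 16) / 128} = {2, 2}, i.e. a 2x2 grid of 8-row x 128-bit
// tiles.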

/// Returns the first user of the `op` that is vector.contract. If no
/// vector.contract user exists, return failure.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}

FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type at warp-level.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless
  // it is directly consumed by a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;

  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;

  return info;
}
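// For instance, an op whose result feeds the LHS of a vector.contract is
// classified as MatMulOperandRole::A and the RHS as role B; producers of the
// accumulator operand (or ops with no vector.contract user at all) keep the
// default role C.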

int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}
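// Summary of the above: 32-bit accumulators use 256-bit rows, 64-bit
// accumulators 512-bit rows, 64-bit A/B operands 256-bit rows, and all other
// element types (f16, i8, i4, f32 operands) the default 128-bit row.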

FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32,
                               inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand
  if (elType.isInteger(4)) {
    return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }

  // Integer 32bit acc operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
                               64, inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}
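// Note the accumulator/operand split above: e.g. an f32 accumulator fragment
// is modeled as vector<2xf32> chunks occupying 64 bits of register space,
// while an f32 A/B operand is held as scalar f32 values in single 32-bit
// registers; the register count always comes from
// inferNumRegistersPerMatrixFragment.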

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}
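// Worked example: for a 16x16 f16 operand, lineSize = 128, so
// elementsPerLine = 8, num8x128bTiles = {2, 2}, and elementsPerRegister = 2.
// Logical value ids 0-1 then map to tile offset (0, 0), ids 2-3 to (8, 0),
// ids 4-5 to (0, 8), and ids 6-7 to (8, 8).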

FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<FragmentElementInfo> regInfo = getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
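// Worked example: continuing the 16x16 f16 operand above, lane 5 with logical
// value id 0 maps to row = 0 + 5 floordiv 4 = 1 and
// col = 0 + (5 mod 4) * 2 + 0 = 2, i.e. that lane's first register starts at
// element (1, 2) of the fragment.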

FailureOr<LdMatrixParams> nvgpu::getLdMatrixParams(const WarpMatrixInfo &type,
                                                   bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;

  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}
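// Illustrative example: a non-transposed 16x16 f16 operand gives numTiles =
// (16 / 8) * ((16 * 16) / 128) = 4, i.e. four 8-row x 128-bit tiles to be
// loaded by the ldmatrix op.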

FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128b row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
  // the `srcMemref` memory of the LdMatrixOp.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

  // Affine exprs in the strided and contiguous dimensions encode the
  // coordinate mapping for the element a thread points to for the warp-wide
  // LdMatrixOp.
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

  // This case corresponds to row-major matrixA or col-major matrixB or
  // row-major matrixC. This is when the memory layout in `srcMemref`
  // matches the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});

  // This case corresponds to col-major matrixA or row-major matrixB or
  // col-major matrixC. This is when the memory layout in `srcMemref` does not
  // match the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});

  return failure();
}
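// Worked example: for a 16x16 f16 fragment in the non-transposed (reduction)
// case, kElementsPer128b = 8 and operandShape[0] = 16, so lane 17 maps to
// (17 mod 16, (17 floordiv 16) * 8) = (1, 8), the start of the second 128-bit
// segment of row 1.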

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type should be 2D. Note that it is possible to expand support
  // so that we are robust to extra unit dimensions that failed to fold, but
  // that would significantly increase downstream code complexity in the
  // conversion step. For now, we rely on other patterns to ensure canonical
  // 2D form is used when targeting the `nvgpu.mma.sync` lowering path.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Currently we can't support reads on tensor types because we need stride
  // information to ensure correctness of downstream assumptions. It is
  // possible to enable this if the caller can assert that the tensor will be
  // lowered in a particular manner.
  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the read is contiguous. Note that it is
  // possible to expand support for this by scalarizing all the loads during
  // conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
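// In short, a transfer_read qualifies only if it is unmasked and in-bounds,
// produces a statically shaped 2-D vector, reads from a memref rather than a
// tensor, and that memref's innermost stride is 1.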

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // TODO: Currently we rely on lowering to a `vector.store` operation. We
  // could support the transposed write case by lowering to scalarized
  // `memref.store` operations.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  // Currently we can't support writes to tensor types because we need stride
  // information to ensure correctness of downstream assumptions.
  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the target memref is contiguous. Note
  // that it is possible to expand support for this by scalarizing all the
  // stores during conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}