//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}
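
// Worked example (illustrative): for a 16x16 f16 `A` operand,
// inferTileWidthInBits returns 128, so a thread holds
// (16 / 8) * (16 * 16) / 128 = 4 registers, each packing two f16 values.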

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  // For each 8x128bit square, a thread is responsible for one 32bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}
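
// E.g. (illustrative): a 16x16 f16 operand with 128-bit lines decomposes into
// getTileShape({16, 16}, f16, 128) = {2, 2} tiles of 8 rows x 128 bits each.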

/// Returns the first user of the `op` that is vector.contract. If no
/// vector.contract user exists, return failure.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}

FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type at warp-level.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless
  // it is directly consumed by a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;

  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;

  return info;
}
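
// E.g. (illustrative): a vector.transfer_read whose result feeds the LHS of a
// vector.contract is classified as operand A, one feeding the RHS as B, and a
// vector.transfer_write (or any value with no vector.contract user) as C.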

int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}
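
// E.g. (illustrative): a 32-bit (f32/i32) accumulator uses 256-bit tile rows,
// f64 uses 512-bit rows as an accumulator and 256-bit rows as A/B, and all
// remaining cases (f16, f32 A/B, i8, i4, ...) use 128-bit rows.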

FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand
  if (elType.isInteger(4)) {
    return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }

  // Integer 32bit acc operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)),
                               2, 64, inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}
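
// E.g. (illustrative): for a 16x8 f16 accumulator fragment this returns
// {vector<2xf16>, 2 elements per 32-bit register, 2 registers per thread},
// i.e. each of the 32 threads holds 4 of the 16 * 8 = 128 result elements.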

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}
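
// Worked example (illustrative): for a 16x16 f16 operand with 128-bit lines,
// num8x128bTiles = {2, 2} and elementsPerLine = 8, so register r (that is,
// logicalValueId floordiv 2) maps to the tile offset
// ((r mod 2) * 8, (r floordiv 2) * 8) within the fragment.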

FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
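
// Worked example (illustrative): for a 16x16 f16 `A` fragment the composed map
// is (laneId, valueId) ->
//   (8 * ((valueId floordiv 2) mod 2) + laneId floordiv 4,
//    8 * (valueId floordiv 4) + 2 * (laneId mod 4) + valueId mod 2),
// which matches the PTX m16n8k16 .f16 A-operand register layout.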

FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;

  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}
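
// E.g. (illustrative): for a non-transposed 16x16 f16 `A` operand the
// contiguous dimension is the reduction dimension and numTiles =
// (16 / 8) * ((16 * 16) / 128) = 4, corresponding to an `ldmatrix` x4 load
// per warp.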

FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128b row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
  // the `srcMemref` memory of the LdMatrixOp.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

  // The affine expressions in the strided and contiguous dimensions encode the
  // coordinate of the element a thread points to for the warp-wide LdMatrixOp.
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

  // This case corresponds to row-major matrixA, col-major matrixB, or
  // row-major matrixC. This is when the memory layout in `srcMemref`
  // matches the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});

  // This case corresponds to col-major matrixA, row-major matrixB, or
  // col-major matrixC. This is when the memory layout in `srcMemref` does not
  // match the mma.sync hardware vector register operand layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});

  return failure();
}
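
// Worked example (illustrative): for a non-transposed 16x16 f16 fragment
// (reduction-contiguous), kElementsPer128b = 8 and the map is
// d0 -> (d0 mod 16, (d0 floordiv 16) * 8): lanes 0-15 point at the first
// 128-bit chunk of rows 0-15 and lanes 16-31 at the second chunk.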

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type should be 2D. Note that it is possible to expand support
  // so that we are robust to extra unit dimensions that failed to fold, but
  // that would significantly increase downstream code complexity in the
  // conversion step. For now, we rely on other patterns to ensure canonical
  // 2D form is used when targeting the `nvgpu.mma.sync` lowering path.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Currently we can't support reads on tensor types because we need stride
  // information to ensure correctness of downstream assumptions. It is
  // possible to enable this if the caller can assert that the tensor will be
  // lowered in a particular manner.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the read is contiguous. Note that it is
  // possible to expand support for this by scalarizing all the loads during
  // conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
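
// For illustration (hypothetical IR, with assumed value names): a read such as
//   %v = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]}
//        : memref<16x16xf16>, vector<16x16xf16>
// passes these checks (unmasked, in-bounds, static 2-D vector, memref source
// with unit stride in the last dimension); a masked read or a read from a
// tensor does not.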

bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // TODO: Currently we rely on lowering to a `vector.store` operation. We
  // could support the transposed write case by lowering to scalarized
  // `memref.store` operations.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  // Currently we can't support writes to tensor types because we need stride
  // information to ensure correctness of downstream assumptions.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // Check that the last dimension of the target memref is contiguous. Note
  // that it is possible to expand support for this by scalarizing all the
  // stores during conversion.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
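
// For illustration (hypothetical IR, with assumed value names): a write such as
//   vector.transfer_write %v, %dst[%c0, %c0] {in_bounds = [true, true]}
//        : vector<16x16xf16>, memref<16x16xf16>
// qualifies, while a write with a transposing permutation map is rejected
// until a scalarized `memref.store` lowering is implemented (see the TODO
// above).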