MLIR  22.0.0git
MMAUtils.cpp
Go to the documentation of this file.
1 //===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 
13 
14 using namespace mlir;
15 using namespace mlir::nvgpu;
16 
/// Number of threads that share one [128|256|512] bit tile row; this is
/// always 4.
static constexpr int64_t kThreadsPerRow{4};
/// Number of rows in one tile (tiles are 8 rows tall).
static constexpr int64_t kNumRowsPerTile{8};
20 
21 static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
22  return operandType == MatMulOperandRole::C;
23 }
24 
25 /// Returns the number of registers which compose a matrix fragment held by a
26 /// single thread.
28  int64_t lineSize = inferTileWidthInBits(type);
29  auto shape = type.vectorType.getShape();
30  return (shape[0] / kNumRowsPerTile) *
31  (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
32  lineSize;
33 }
34 
35 /// Returns the number of 8 x [128|256|512] bit tiles that compose the given
36 /// operand shape.
37 static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
38  Type elementType,
39  int64_t lineSizeBits) {
40  // For each 8x128bit square, a thread is responsible for one 32bit register.
41  return {operandShape[0] / kNumRowsPerTile,
42  (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
43  lineSizeBits};
44 }
45 
46 /// Returns the first user of the `op` that is vector.contract. If no
47 /// vector.contract user exists, return failure.
48 FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
49  for (Operation *user : op->getUsers()) {
50  if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
51  return contractOp;
52  }
53  return failure();
54 }
55 
56 FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
57  WarpMatrixInfo info;
58 
59  // Determine the vector type at warp-level.
60  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
61  info.vectorType = writeOp.getVectorType();
62  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
63  vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
64  info.vectorType = cast<VectorType>(op->getResult(0).getType());
65  } else {
66  return op->emitError()
67  << "unhandled operation type in nvgpu.mma.sync conversion path";
68  }
69 
70  // Determine the operand role. We assume it is an accumulator/result unless it
71  // is directly consumed by a `vector.contract` op.
73  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
74  if (failed(contractOp))
75  return info;
76 
77  if ((*contractOp).getLhs() == op->getResult(0))
79  else if ((*contractOp).getRhs() == op->getResult(0))
81 
82  return info;
83 }
84 
86  bool isAcc = isAccumulatorOrResult(type.operandRole);
87  Type elType = type.vectorType.getElementType();
88  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
89  return 256;
90  }
91  if (elType.getIntOrFloatBitWidth() == 64) {
92  return isAcc ? 512 : 256;
93  }
94  return 128;
95 }
96 
97 FailureOr<FragmentElementInfo>
99  MLIRContext *ctx = type.vectorType.getContext();
100  const bool isAccum = isAccumulatorOrResult(type.operandRole);
101 
102  Type elType = type.vectorType.getElementType();
103  if (elType.isF16()) {
106  }
107 
108  // f64 operand
109  Type f64Ty = Float64Type::get(ctx);
110  if (elType.isF64()) {
111  return isAccum
112  ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
114  : FragmentElementInfo{f64Ty, 1, 64,
116  }
117 
118  // int8 operand
119  if (elType.isInteger(8)) {
122  }
123 
124  // int4 operand
125  if (elType.isInteger(4)) {
128  }
129 
130  // Integer 32bit acc operands
131  if (elType.isInteger(32)) {
134  }
135 
136  // Floating point 32bit operands
137  if (elType.isF32()) {
138  Type f32Ty = Float32Type::get(ctx);
139  return isAccum
140  ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
142  : FragmentElementInfo{f32Ty, 1, 32,
144  }
145  return failure();
146 }
147 
149  Type elementType,
150  ArrayRef<int64_t> operandShape,
151  bool isAccumulator,
152  int64_t elementsPerRegister,
153  AffineExpr logicalValueId) {
154  const int64_t elementsPerLine =
155  lineSize / elementType.getIntOrFloatBitWidth();
156  const std::array<int64_t, 2> num8x128bTiles =
157  getTileShape(operandShape, elementType, lineSize);
158  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
159  return AffineMap::get(
160  2, 0,
161  {(registerIdx % num8x128bTiles[0]) * 8,
162  (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
163  elementType.getContext());
164 }
165 
166 FailureOr<AffineMap>
168  const WarpMatrixInfo &fragmentType) {
169  Type elementType = fragmentType.vectorType.getElementType();
170  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
171  FailureOr<nvgpu::FragmentElementInfo> regInfo =
172  getMmaSyncRegisterType(fragmentType);
173  if (failed(regInfo))
174  return failure();
175 
176  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
177  const int64_t elementsPerRegister =
178  regInfo->registerWidthBits / elementBitWidth;
179  const int64_t lineSize = inferTileWidthInBits(fragmentType);
180 
181  AffineExpr laneId, logicalValueIdDim;
182  bindDims(builder.getContext(), laneId, logicalValueIdDim);
183 
184  // Determine what register logicalValueId corresponds to. Use that as a
185  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
186  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
187  lineSize, elementType, operandShape,
188  isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
189  logicalValueIdDim);
190 
191  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
192  return AffineMap::get(2, 0, dimExprs, builder.getContext());
193  };
194 
195  auto tileRow = registerIndexToTileCoord.getResult(0);
196  auto tileCol = registerIndexToTileCoord.getResult(1);
197  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
198  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
199  (logicalValueIdDim % elementsPerRegister)});
200 }
201 
202 FailureOr<nvgpu::LdMatrixParams>
203 nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
204  LdMatrixParams params;
205  Type elType = type.vectorType.getElementType();
206  params.fragmentType = type.vectorType;
207  if (type.operandRole == MatMulOperandRole::A ||
209  params.targetLayout = NVVM::MMALayout::row;
210  } else {
211  params.targetLayout = NVVM::MMALayout::col;
212  }
213  ArrayRef<int64_t> shape = type.vectorType.getShape();
214  params.contiguousDimType = transpose ? vector::IteratorType::parallel
215  : vector::IteratorType::reduction;
216 
217  if (params.contiguousDimType == vector::IteratorType::reduction) {
218  params.numTiles = (shape[0] / kNumRowsPerTile) *
219  ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
220  } else {
221  params.numTiles = (shape[1] / kNumRowsPerTile) *
222  ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
223  }
224 
225  if (params.numTiles == 0)
226  return failure();
227 
228  return params;
229 }
230 
231 FailureOr<AffineMap>
233  const LdMatrixParams &params) {
234  // One thread per 128b row.
235  const int bitsPerElement = static_cast<int>(
236  params.fragmentType.getElementType().getIntOrFloatBitWidth());
237  const int kElementsPer128b = (128 / bitsPerElement);
238  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
239  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());
240 
241  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
242  return AffineMap::get(1, 0, dimExprs, builder.getContext());
243  };
244 
245  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
246  // the `srcMemref` memory of the LdMatrixOp.
247  int idx =
248  (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;
249 
250  // Affine expr in strided and contiguous dimension encodes the coordinate
251  // mapping for the element a thread points to for warp-wide LdMatrixOp.
252  AffineExpr strided = d0 % (operandShape[idx]);
253  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);
254 
255  // This case corresponds to row-major matrixA or col-major matrixB or
256  // row-major matrixC. This is when the memory layout in `srcMemref`
257  // match mma.sync hardware vector register operand layout.
258  if (params.contiguousDimType == vector::IteratorType::reduction)
259  return makeMap({strided, contiguous});
260 
261  // This case corresponds to col-major matrixA or row-major matrixB or
262  // col-major matrixC. This is when the memory layout in `srcMemref` does not
263  // match mma.sync hardware vector register operand layout.
264  if (params.contiguousDimType == vector::IteratorType::parallel)
265  return makeMap({contiguous, strided});
266 
267  return failure();
268 }
269 
270 bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
271  if (op.getMask() || op.hasOutOfBoundsDim())
272  return false;
273  VectorType type = op.getType();
274  // The result type should be 2D. Note that it is possible to expand support so
275  // that we are robust to extra unit dimensions that failed to fold, but that
276  // would significantly increase downstream code complexity in the conversion
277  // step. For now, we rely on other patterns to ensure canonical 2D form is
278  // used when targeting the `nvgpu.mma.sync` lowering path.
279  if (!type.hasStaticShape() || type.getRank() != 2)
280  return false;
281 
282  // Currently we can't support reads on tensor types because we need stride
283  // information to ensure correctness of downstream assumptions. It is possible
284  // to enable this if caller can assert that tensor will be lowered in a
285  // particular manner.
286  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
287  if (!sourceType)
288  return false;
289 
290  // Check that the last dimension of the read is contiguous. Note that it is
291  // possible to expand support for this by scalarizing all the loads during
292  // conversion.
293  auto [strides, offset] = sourceType.getStridesAndOffset();
294  return strides.back() == 1;
295 }
296 
297 bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
298  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
299  return false;
300  VectorType type = op.getVectorType();
301  if (!type.hasStaticShape() || type.getRank() != 2)
302  return false;
303  // TODO: Currently we rely on lowering to a `vector.store` operation. We could
304  // support the transposed write case by lowering to scalarized `memref.store`
305  // operations.
306  if (!op.getPermutationMap().isMinorIdentity())
307  return false;
308  // Currently we can't support reads on tensor types because we need stride
309  // information to ensure correctness of downstream assumptions.
310  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
311  if (!sourceType)
312  return false;
313 
314  // Check that the last dimension of the target memref is contiguous. Note that
315  // it is possible to expand support for this by scalarizing all the stores
316  // during conversion.
317  auto [strides, offset] = sourceType.getStridesAndOffset();
318  return strides.back() == 1;
319 }
static constexpr int64_t kNumRowsPerTile
Definition: MMAUtils.cpp:19
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize, Type elementType, ArrayRef< int64_t > operandShape, bool isAccumulator, int64_t elementsPerRegister, AffineExpr logicalValueId)
Definition: MMAUtils.cpp:148
static constexpr int64_t kThreadsPerRow
There are always 4 threads per [128|256|512] bit row.
Definition: MMAUtils.cpp:18
static bool isAccumulatorOrResult(MatMulOperandRole operandType)
Definition: MMAUtils.cpp:21
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type)
Returns the number of registers which compose a matrix fragment held by a single thread.
Definition: MMAUtils.cpp:27
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
Definition: MMAUtils.cpp:37
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:959
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
Type getType() const
Return the type of this value.
Definition: Value.h:105
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:85
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:48
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps a two dimensions representing (laneId, logicalValueId) and returns tw...
Definition: MMAUtils.cpp:167
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:56
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representin...
Definition: MMAUtils.cpp:232
MatMulOperandRole
Represents the role of an operand in an MMA instruction: result := matmul(A, B) + C
Definition: MMAUtils.h:26
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed l...
Definition: MMAUtils.cpp:203
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type...
Definition: MMAUtils.cpp:98
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative m...
Definition: MMAUtils.cpp:270
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
Specifies information about the registers which compose a matrix fragment according to the PTX docume...
Definition: MMAUtils.h:52
Encapsulates the parameters needed to lower a nvgpu.ldmatrix operation to nvvm.ldmatrix.
Definition: MMAUtils.h:77
NVVM::MMALayout targetLayout
Definition: MMAUtils.h:82
vector::IteratorType contiguousDimType
Definition: MMAUtils.h:81
Collects information about a warp-level matrix operand represented by a VectorType.
Definition: MMAUtils.h:34
MatMulOperandRole operandRole
Definition: MMAUtils.h:36