MLIR  22.0.0git
MMAUtils.cpp
Go to the documentation of this file.
1 //===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 
13 
14 using namespace mlir;
15 using namespace mlir::nvgpu;
16 
/// Every [128|256|512] bit row of a tile is always covered by 4 threads.
static constexpr int64_t kThreadsPerRow{4};
/// Row count of a single tile; tile counts in this file are computed in units
/// of this many rows.
static constexpr int64_t kNumRowsPerTile{8};
20 
21 static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
22  return operandType == MatMulOperandRole::C;
23 }
24 
25 /// Returns the number of registers which compose a matrix fragment held by a
26 /// single thread.
28  int64_t lineSize = inferTileWidthInBits(type);
29  auto shape = type.vectorType.getShape();
30  return (shape[0] / kNumRowsPerTile) *
31  (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
32  lineSize;
33 }
34 
35 /// Returns the number of 8 x [128|256|512] bit tiles that compose the given
36 /// operand shape.
37 static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
38  Type elementType,
39  int64_t lineSizeBits) {
40  // For each 8x128bit square, a thread is responsible for one 32bit register.
41  return {operandShape[0] / kNumRowsPerTile,
42  (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
43  lineSizeBits};
44 }
45 
46 /// Returns the first user of the `op` that is vector.contract. If no
47 /// vector.contract user exists, return failure.
48 FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
49  for (Operation *user : op->getUsers()) {
50  if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
51  return contractOp;
52  }
53  return failure();
54 }
55 
56 FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
57  WarpMatrixInfo info;
58 
59  // Determine the vector type at warp-level.
60  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
61  info.vectorType = writeOp.getVectorType();
62  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
63  vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
64  info.vectorType = cast<VectorType>(op->getResult(0).getType());
65  } else {
66  return op->emitError()
67  << "unhandled operation type in nvgpu.mma.sync conversion path";
68  }
69 
70  // Determine the operand role. We assume it is an accumulator/result unless it
71  // is directly consumed by a `vector.contract` op.
73  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
74  if (failed(contractOp))
75  return info;
76 
77  if ((*contractOp).getLhs() == op->getResult(0))
79  else if ((*contractOp).getRhs() == op->getResult(0))
81 
82  return info;
83 }
84 
86  bool isAcc = isAccumulatorOrResult(type.operandRole);
87  Type elType = type.vectorType.getElementType();
88  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
89  return 256;
90  }
91  if (elType.getIntOrFloatBitWidth() == 64) {
92  return isAcc ? 512 : 256;
93  }
94  return 128;
95 }
96 
97 FailureOr<FragmentElementInfo>
99  MLIRContext *ctx = type.vectorType.getContext();
100  const bool isAccum = isAccumulatorOrResult(type.operandRole);
101 
102  Type elType = type.vectorType.getElementType();
103  if (elType.isF16()) {
106  }
107 
108  // f64 operand
109  Type f64Ty = Float64Type::get(ctx);
110  if (elType.isF64()) {
111  return isAccum
112  ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
114  : FragmentElementInfo{f64Ty, 1, 64,
116  }
117 
118  // int8 operand
119  if (elType.isInteger(8)) {
122  }
123 
124  // int4 operand
125  if (elType.isInteger(4)) {
128  }
129 
130  // Integer 32bit acc operands
131  if (elType.isInteger(32)) {
134  }
135 
136  // Floating point 32bit operands
137  if (elType.isF32()) {
138  Type f32Ty = Float32Type::get(ctx);
139  return isAccum
140  ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
142  : FragmentElementInfo{f32Ty, 1, 32,
144  }
145  return failure();
146 }
147 
149  Type elementType,
150  ArrayRef<int64_t> operandShape,
151  bool isAccumulator,
152  int64_t elementsPerRegister,
153  AffineExpr logicalValueId) {
154  const int64_t elementsPerLine =
155  lineSize / elementType.getIntOrFloatBitWidth();
156  const std::array<int64_t, 2> num8x128bTiles =
157  getTileShape(operandShape, elementType, lineSize);
158  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
159  return AffineMap::get(
160  2, 0,
161  {(registerIdx % num8x128bTiles[0]) * 8,
162  (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
163  elementType.getContext());
164 }
165 
166 FailureOr<AffineMap>
168  const WarpMatrixInfo &fragmentType) {
169  Type elementType = fragmentType.vectorType.getElementType();
170  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
171  FailureOr<FragmentElementInfo> regInfo = getMmaSyncRegisterType(fragmentType);
172  if (failed(regInfo))
173  return failure();
174 
175  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
176  const int64_t elementsPerRegister =
177  regInfo->registerWidthBits / elementBitWidth;
178  const int64_t lineSize = inferTileWidthInBits(fragmentType);
179 
180  AffineExpr laneId, logicalValueIdDim;
181  bindDims(builder.getContext(), laneId, logicalValueIdDim);
182 
183  // Determine what register logicalValueId corresponds to. Use that as a
184  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
185  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
186  lineSize, elementType, operandShape,
187  isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
188  logicalValueIdDim);
189 
190  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
191  return AffineMap::get(2, 0, dimExprs, builder.getContext());
192  };
193 
194  auto tileRow = registerIndexToTileCoord.getResult(0);
195  auto tileCol = registerIndexToTileCoord.getResult(1);
196  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
197  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
198  (logicalValueIdDim % elementsPerRegister)});
199 }
200 
201 FailureOr<LdMatrixParams> nvgpu::getLdMatrixParams(const WarpMatrixInfo &type,
202  bool transpose) {
203  LdMatrixParams params;
204  Type elType = type.vectorType.getElementType();
205  params.fragmentType = type.vectorType;
206  if (type.operandRole == MatMulOperandRole::A ||
208  params.targetLayout = NVVM::MMALayout::row;
209  } else {
210  params.targetLayout = NVVM::MMALayout::col;
211  }
212  ArrayRef<int64_t> shape = type.vectorType.getShape();
213  params.contiguousDimType = transpose ? vector::IteratorType::parallel
214  : vector::IteratorType::reduction;
215 
216  if (params.contiguousDimType == vector::IteratorType::reduction) {
217  params.numTiles = (shape[0] / kNumRowsPerTile) *
218  ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
219  } else {
220  params.numTiles = (shape[1] / kNumRowsPerTile) *
221  ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
222  }
223 
224  if (params.numTiles == 0)
225  return failure();
226 
227  return params;
228 }
229 
230 FailureOr<AffineMap>
232  const LdMatrixParams &params) {
233  // One thread per 128b row.
234  const int bitsPerElement = static_cast<int>(
235  params.fragmentType.getElementType().getIntOrFloatBitWidth());
236  const int kElementsPer128b = (128 / bitsPerElement);
237  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
238  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());
239 
240  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
241  return AffineMap::get(1, 0, dimExprs, builder.getContext());
242  };
243 
244  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
245  // the `srcMemref` memory of the LdMatrixOp.
246  int idx =
247  (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;
248 
249  // Affine expr in strided and contiguous dimension encodes the coordinate
250  // mapping for the element a thread points to for warp-wide LdMatrixOp.
251  AffineExpr strided = d0 % (operandShape[idx]);
252  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);
253 
254  // This case corresponds to row-major matrixA or col-major matrixB or
255  // row-major matrixC. This is when the memory layout in `srcMemref`
256  // match mma.sync hardware vector register operand layout.
257  if (params.contiguousDimType == vector::IteratorType::reduction)
258  return makeMap({strided, contiguous});
259 
260  // This case corresponds to col-major matrixA or row-major matrixB or
261  // col-major matrixC. This is when the memory layout in `srcMemref` does not
262  // match mma.sync hardware vector register operand layout.
263  if (params.contiguousDimType == vector::IteratorType::parallel)
264  return makeMap({contiguous, strided});
265 
266  return failure();
267 }
268 
269 bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
270  if (op.getMask() || op.hasOutOfBoundsDim())
271  return false;
272  VectorType type = op.getType();
273  // The result type should be 2D. Note that it is possible to expand support so
274  // that we are robust to extra unit dimensions that failed to fold, but that
275  // would significantly increase downstream code complexity in the conversion
276  // step. For now, we rely on other patterns to ensure canonical 2D form is
277  // used when targeting the `nvgpu.mma.sync` lowering path.
278  if (!type.hasStaticShape() || type.getRank() != 2)
279  return false;
280 
281  // Currently we can't support reads on tensor types because we need stride
282  // information to ensure correctness of downstream assumptions. It is possible
283  // to enable this if caller can assert that tensor will be lowered in a
284  // particular manner.
285  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
286  if (!sourceType)
287  return false;
288 
289  // Check that the last dimension of the read is contiguous. Note that it is
290  // possible to expand support for this by scalarizing all the loads during
291  // conversion.
292  auto [strides, offset] = sourceType.getStridesAndOffset();
293  return strides.back() == 1;
294 }
295 
296 bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
297  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
298  return false;
299  VectorType type = op.getVectorType();
300  if (!type.hasStaticShape() || type.getRank() != 2)
301  return false;
302  // TODO: Currently we rely on lowering to a `vector.store` operation. We could
303  // support the transposed write case by lowering to scalarized `memref.store`
304  // operations.
305  if (!op.getPermutationMap().isMinorIdentity())
306  return false;
307  // Currently we can't support reads on tensor types because we need stride
308  // information to ensure correctness of downstream assumptions.
309  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
310  if (!sourceType)
311  return false;
312 
313  // Check that the last dimension of the target memref is contiguous. Note that
314  // it is possible to expand support for this by scalarizing all the stores
315  // during conversion.
316  auto [strides, offset] = sourceType.getStridesAndOffset();
317  return strides.back() == 1;
318 }
static constexpr int64_t kNumRowsPerTile
Definition: MMAUtils.cpp:19
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize, Type elementType, ArrayRef< int64_t > operandShape, bool isAccumulator, int64_t elementsPerRegister, AffineExpr logicalValueId)
Definition: MMAUtils.cpp:148
static constexpr int64_t kThreadsPerRow
There are always 4 threads per [128|256|512] bit row.
Definition: MMAUtils.cpp:18
static bool isAccumulatorOrResult(MatMulOperandRole operandType)
Definition: MMAUtils.cpp:21
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type)
Returns the number of registers which compose a matrix fragment held by a single thread.
Definition: MMAUtils.cpp:27
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
Definition: MMAUtils.cpp:37
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:959
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
MLIRContext * getContext() const
Definition: Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:41
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
Type getType() const
Return the type of this value.
Definition: Value.h:105
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:85
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:48
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps a two dimensions representing (laneId, logicalValueId) and returns tw...
Definition: MMAUtils.cpp:167
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed l...
Definition: MMAUtils.cpp:201
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:56
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representin...
Definition: MMAUtils.cpp:231
MatMulOperandRole
Represents the role of an operand in an MMA instruction: result := matmul(A, B) + C
Definition: MMAUtils.h:26
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type...
Definition: MMAUtils.cpp:98
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative m...
Definition: MMAUtils.cpp:269
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
Specifies information about the registers which compose a matrix fragment according to the PTX docume...
Definition: MMAUtils.h:52
Encapsulates the parameters needed to lower a nvgpu.ldmatrix operation to nvvm.ldmatrix.
Definition: MMAUtils.h:77
NVVM::MMALayout targetLayout
Definition: MMAUtils.h:82
vector::IteratorType contiguousDimType
Definition: MMAUtils.h:81
Collects information about a warp-level matrix operand represented by a VectorType.
Definition: MMAUtils.h:34
MatMulOperandRole operandRole
Definition: MMAUtils.h:36