MLIR  20.0.0git
MMAUtils.cpp
Go to the documentation of this file.
1 //===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 
15 
16 using namespace mlir;
17 using namespace mlir::nvgpu;
18 
19 /// There are always 4 threads per [128|256|512] bit row.
20 static constexpr int64_t kThreadsPerRow = 4;
/// Number of rows in one tile. Tiles are 8 rows tall and one
/// [128|256|512] bit line wide; operand row counts are divided by this value
/// to obtain the number of tile rows.
21 static constexpr int64_t kNumRowsPerTile = 8;
22 
23 static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
24  return operandType == MatMulOperandRole::C;
25 }
26 
27 /// Returns the number of registers which compose a matrix fragment held by a
28 /// single thread.
30  int64_t lineSize = inferTileWidthInBits(type);
31  auto shape = type.vectorType.getShape();
32  return (shape[0] / kNumRowsPerTile) *
33  (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
34  lineSize;
35 }
36 
37 /// Returns the number of 8 x [128|256|512] bit tiles that compose the given
38 /// operand shape.
39 static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
40  Type elementType,
41  int64_t lineSizeBits) {
42  // For each 8x128bit square, a thread is responsible for one 32bit register.
43  return {operandShape[0] / kNumRowsPerTile,
44  (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
45  lineSizeBits};
46 }
47 
48 /// Returns the first user of the `op` that is vector.contract. If no
49 /// vector.contract user exists, return failure.
50 FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
51  for (Operation *user : op->getUsers()) {
52  if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
53  return contractOp;
54  }
55  return failure();
56 }
57 
/// Builds the WarpMatrixInfo for `op`.
///
/// The vector type is taken from the written vector for
/// `vector.transfer_write`, and from the type of result 0 for
/// `vector.transfer_read`, `vector.contract`, `vector.extract_strided_slice`
/// and `arith.constant`. Any other operation emits an error on `op` and
/// returns failure.
///
/// NOTE(review): this listing is missing the statements that assign
/// `info.operandRole` (the embedded numbering skips 74, 80 and 82), so the
/// `if` / `else if` near the end have no visible bodies. Restore them from the
/// original source before compiling.
58 FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
59  WarpMatrixInfo info;
60 
61  // Determine the vector type at warp-level.
62  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
63  info.vectorType = writeOp.getVectorType();
64  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
65  vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
66  info.vectorType = cast<VectorType>(op->getResult(0).getType());
67  } else {
68  return op->emitError()
69  << "unhandled operation type in nvgpu.mma.sync conversion path";
70  }
71 
72  // Determine the operand role. We assume it is an accumulator/result unless it
73  // is directly consumed by a `vector.contract` op.
  // NOTE(review): the default role assignment is presumably on the missing
  // line 74 -- confirm against the original source.
75  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  // No `vector.contract` user: return `info` as built so far.
76  if (failed(contractOp))
77  return info;
78 
  // Result 0 of `op` is compared against the contraction's LHS and RHS; the
  // statement executed in each case is not visible in this listing.
79  if ((*contractOp).getLhs() == op->getResult(0))
81  else if ((*contractOp).getRhs() == op->getResult(0))
83 
84  return info;
85 }
86 
88  bool isAcc = isAccumulatorOrResult(type.operandRole);
89  Type elType = type.vectorType.getElementType();
90  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
91  return 256;
92  }
93  if (elType.getIntOrFloatBitWidth() == 64) {
94  return isAcc ? 512 : 256;
95  }
96  return 128;
97 }
98 
/// Selects the FragmentElementInfo describing the registers of a matrix
/// fragment, based on the element type of `type.vectorType` (f16, f64, i8, i4,
/// i32, f32). For f64 and f32 the result also depends on whether the operand
/// is an accumulator/result. Returns failure for any other element type.
///
/// NOTE(review): this listing is missing the line carrying the function name
/// and parameters (embedded line 100) as well as the trailing part of every
/// `FragmentElementInfo{...}` initializer (embedded lines 107-108, 116, 118,
/// 125, 132, 139, 146-147, 149). Restore them from the original source before
/// compiling.
99 FailureOr<FragmentElementInfo>
101  MLIRContext *ctx = type.vectorType.getContext();
102  const bool isAccum = isAccumulatorOrResult(type.operandRole);
103 
104  Type elType = type.vectorType.getElementType();
  // f16 operand (initializer contents not visible in this listing).
105  if (elType.isF16()) {
106  return FragmentElementInfo{
109  }
110 
111  // f64 operand
112  Type f64Ty = Float64Type::get(ctx);
113  if (elType.isF64()) {
  // Accumulator/result: a 2 x f64 vector with values 2 and 128; otherwise a
  // scalar f64 with values 1 and 64. NOTE(review): the two integers look like
  // (elements per register, register width in bits) -- confirm against the
  // FragmentElementInfo declaration.
114  return isAccum
115  ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
117  : FragmentElementInfo{f64Ty, 1, 64,
119  }
120 
121  // int8 operand
122  if (elType.isInteger(8)) {
123  return FragmentElementInfo{
124  LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
126  }
127 
128  // int4 operand
129  if (elType.isInteger(4)) {
130  return FragmentElementInfo{
131  LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
133  }
134 
135  // Integer 32bit acc operands
136  if (elType.isInteger(32)) {
137  return FragmentElementInfo{
138  LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
140  }
141 
142  // Floating point 32bit operands
143  if (elType.isF32()) {
144  Type f32Ty = Float32Type::get(ctx);
  // NOTE(review): the accumulator/result arm of this conditional is missing
  // from the listing; only the non-accumulator arm is visible.
145  return isAccum
148  : FragmentElementInfo{f32Ty, 1, 32,
150  }
  // Unsupported element type.
151  return failure();
152 }
153 
155  Type elementType,
156  ArrayRef<int64_t> operandShape,
157  bool isAccumulator,
158  int64_t elementsPerRegister,
159  AffineExpr logicalValueId) {
160  const int64_t elementsPerLine =
161  lineSize / elementType.getIntOrFloatBitWidth();
162  const std::array<int64_t, 2> num8x128bTiles =
163  getTileShape(operandShape, elementType, lineSize);
164  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
165  return AffineMap::get(
166  2, 0,
167  {(registerIdx % num8x128bTiles[0]) * 8,
168  (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
169  elementType.getContext());
170 }
171 
172 FailureOr<AffineMap>
174  const WarpMatrixInfo &fragmentType) {
175  Type elementType = fragmentType.vectorType.getElementType();
176  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
177  FailureOr<nvgpu::FragmentElementInfo> regInfo =
178  getMmaSyncRegisterType(fragmentType);
179  if (failed(regInfo))
180  return failure();
181 
182  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
183  const int64_t elementsPerRegister =
184  regInfo->registerWidthBits / elementBitWidth;
185  const int64_t lineSize = inferTileWidthInBits(fragmentType);
186 
187  AffineExpr laneId, logicalValueIdDim;
188  bindDims(builder.getContext(), laneId, logicalValueIdDim);
189 
190  // Determine what register logicalValueId corresponds to. Use that as a
191  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
192  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
193  lineSize, elementType, operandShape,
194  isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
195  logicalValueIdDim);
196 
197  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
198  return AffineMap::get(2, 0, dimExprs, builder.getContext());
199  };
200 
201  auto tileRow = registerIndexToTileCoord.getResult(0);
202  auto tileCol = registerIndexToTileCoord.getResult(1);
203  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
204  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
205  (logicalValueIdDim % elementsPerRegister)});
206 }
207 
/// Fills an LdMatrixParams for the warp-matrix operand `type`: the fragment
/// type, the target layout, the iterator type of the contiguous dimension and
/// the number of 8-row x 128-bit tiles. Returns failure when the computed tile
/// count is zero.
///
/// NOTE(review): this listing is missing the line carrying the function name
/// and parameters (embedded line 209) and the second half of the layout
/// condition (embedded line 214). Restore them from the original source
/// before compiling.
208 FailureOr<nvgpu::LdMatrixParams>
210  LdMatrixParams params;
211  Type elType = type.vectorType.getElementType();
212  params.fragmentType = type.vectorType;
  // Row layout is chosen for operand role A or for whatever the missing
  // second disjunct tests; every other role gets column layout.
213  if (type.operandRole == MatMulOperandRole::A ||
215  params.targetLayout = NVVM::MMALayout::row;
216  } else {
217  params.targetLayout = NVVM::MMALayout::col;
218  }
219  ArrayRef<int64_t> shape = type.vectorType.getShape();
  // A transposed load makes the contiguous dimension a parallel iterator;
  // otherwise it is the reduction iterator.
220  params.contiguousDimType = transpose ? vector::IteratorType::parallel
221  : vector::IteratorType::reduction;
222 
  // Tile count = (strided extent / rows per tile) * (contiguous extent in
  // bits / 128). Which shape dimension is strided depends on
  // `contiguousDimType`.
223  if (params.contiguousDimType == vector::IteratorType::reduction) {
224  params.numTiles = (shape[0] / kNumRowsPerTile) *
225  ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
226  } else {
227  params.numTiles = (shape[1] / kNumRowsPerTile) *
228  ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
229  }
230 
  // The integer divisions above yield zero when the operand is smaller than
  // one tile.
231  if (params.numTiles == 0)
232  return failure();
233 
234  return params;
235 }
236 
237 FailureOr<AffineMap>
239  const LdMatrixParams &params) {
240  // One thread per 128b row.
241  const int bitsPerElement = static_cast<int>(
242  params.fragmentType.getElementType().getIntOrFloatBitWidth());
243  const int kElementsPer128b = (128 / bitsPerElement);
244  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
245  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());
246 
247  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
248  return AffineMap::get(1, 0, dimExprs, builder.getContext());
249  };
250 
251  // Index `idx` in vectorType `operandShape` maps to the strided dimension of
252  // the `srcMemref` memory of the LdMatrixOp.
253  int idx =
254  (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;
255 
256  // Affine expr in strided and contiguous dimension encodes the coordinate
257  // mapping for the element a thread points to for warp-wide LdMatrixOp.
258  AffineExpr strided = d0 % (operandShape[idx]);
259  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);
260 
261  // This case corresponds to row-major matrixA or col-major matrixB or
262  // row-major matrixC. This is when the memory layout in `srcMemref`
263  // match mma.sync hardware vector register operand layout.
264  if (params.contiguousDimType == vector::IteratorType::reduction)
265  return makeMap({strided, contiguous});
266 
267  // This case corresponds to col-major matrixA or row-major matrixB or
268  // col-major matrixC. This is when the memory layout in `srcMemref` does not
269  // match mma.sync hardware vector register operand layout.
270  if (params.contiguousDimType == vector::IteratorType::parallel)
271  return makeMap({contiguous, strided});
272 
273  return failure();
274 }
275 
276 bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
277  if (op.getMask() || op.hasOutOfBoundsDim())
278  return false;
279  VectorType type = op.getType();
280  // The result type should be 2D. Note that it is possible to expand support so
281  // that we are robust to extra unit dimensions that failed to fold, but that
282  // would significantly increase downstream code complexity in the conversion
283  // step. For now, we rely on other patterns to ensure canonical 2D form is
284  // used when targeting the `nvgpu.mma.sync` lowering path.
285  if (!type.hasStaticShape() || type.getRank() != 2)
286  return false;
287 
288  // Currently we can't support reads on tensor types because we need stride
289  // information to ensure correctness of downstream assumptions. It is possible
290  // to enable this if caller can assert that tensor will be lowered in a
291  // particular manner.
292  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
293  if (!sourceType)
294  return false;
295 
296  // Check that the last dimension of the read is contiguous. Note that it is
297  // possible to expand support for this by scalarizing all the loads during
298  // conversion.
299  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
300  return strides.back() == 1;
301 }
302 
303 bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
304  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
305  return false;
306  VectorType type = op.getVectorType();
307  if (!type.hasStaticShape() || type.getRank() != 2)
308  return false;
309  // TODO: Currently we rely on lowering to a `vector.store` operation. We could
310  // support the transposed write case by lowering to scalarized `memref.store`
311  // operations.
312  if (!op.getPermutationMap().isMinorIdentity())
313  return false;
314  // Currently we can't support reads on tensor types because we need stride
315  // information to ensure correctness of downstream assumptions.
316  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
317  if (!sourceType)
318  return false;
319 
320  // Check that the last dimension of the target memref is contiguous. Note that
321  // it is possible to expand support for this by scalarizing all the stores
322  // during conversion.
323  auto [strides, offset] = mlir::getStridesAndOffset(sourceType);
324  return strides.back() == 1;
325 }
static constexpr int64_t kNumRowsPerTile
Definition: MMAUtils.cpp:21
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize, Type elementType, ArrayRef< int64_t > operandShape, bool isAccumulator, int64_t elementsPerRegister, AffineExpr logicalValueId)
Definition: MMAUtils.cpp:154
static constexpr int64_t kThreadsPerRow
There are always 4 threads per [128|256|512] bit row.
Definition: MMAUtils.cpp:20
static bool isAccumulatorOrResult(MatMulOperandRole operandType)
Definition: MMAUtils.cpp:23
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type)
Returns the number of registers which compose a matrix fragment held by a single thread.
Definition: MMAUtils.cpp:29
static std::array< int64_t, 2 > getTileShape(ArrayRef< int64_t > operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
Definition: MMAUtils.cpp:39
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:907
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:213
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:56
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isF32() const
Definition: Types.cpp:55
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:62
bool isF16() const
Definition: Types.cpp:53
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:129
Type getType() const
Return the type of this value.
Definition: Value.h:129
Type getFixedVectorType(Type elementType, unsigned numElements)
Creates an LLVM dialect-compatible type with the given element type and length.
Definition: LLVMTypes.cpp:953
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:87
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:50
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps a two dimensions representing (laneId, logicalValueId) and returns tw...
Definition: MMAUtils.cpp:173
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:58
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representin...
Definition: MMAUtils.cpp:238
MatMulOperandRole
Represents the role of an operand in an MMA instruction: result := matmul(A, B) + C
Definition: MMAUtils.h:26
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed l...
Definition: MMAUtils.cpp:209
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type...
Definition: MMAUtils.cpp:100
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative m...
Definition: MMAUtils.cpp:276
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:21
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:607
Specifies information about the registers which compose a matrix fragment according to the PTX docume...
Definition: MMAUtils.h:52
Encapsulates the parameters needed to lower a nvgpu.ldmatrix operation to nvvm.ldmatrix.
Definition: MMAUtils.h:77
NVVM::MMALayout targetLayout
Definition: MMAUtils.h:82
vector::IteratorType contiguousDimType
Definition: MMAUtils.h:81
Collects information about a warp-level matrix operand represented by a VectorType.
Definition: MMAUtils.h:34
MatMulOperandRole operandRole
Definition: MMAUtils.h:36