NvGpuSupport.cpp
//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to NvGPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//

#include "NvGpuSupport.h"
// Note: the remaining includes were elided in this listing; the headers below
// are an assumed reconstruction covering NVVM::MMALayout, the NVGPU dialect,
// and the vector ops used in this file.
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir {
namespace nvgpu {
namespace {

/// There are always 4 threads per [128|256|512] bit row.
constexpr int64_t kThreadsPerRow = 4;

constexpr int64_t kNumRowsPerTile = 8;

bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}
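
// Worked example (not part of the original source): for a 16x16 vector<f16>
// A/B operand, inferTileWidthInBits returns 128, so a thread owns
// (16 / 8) * (16 * 16) / 128 = 4 registers. With 32-bit registers holding two
// f16 values each, that is 8 elements per thread, and 32 threads * 8 elements
// covers all 256 elements of the fragment.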

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                    Type elementType, int64_t lineSizeBits) {
  // For each 8x128bit square, a thread is responsible for one 32bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}
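
// Worked example (illustrative): getTileShape({16, 16}, f16,
// /*lineSizeBits=*/128) yields {16 / 8, (16 * 16) / 128} = {2, 2}, i.e. the
// operand decomposes into a 2x2 grid of 8-row by 128-bit tiles.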

} // namespace

FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 arith::ConstantOp>(op)) {
    info.vectorType = op->getResult(0).getType().cast<VectorType>();
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Determine the operand role. We assume it is an accumulator/result unless
  // it is directly consumed by a `vector.contract` op.
  info.operandRole = MatMulOperandRole::C;
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.getLhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::A;
      break;
    }
    if (contract.getRhs() == op->getResult(0)) {
      info.operandRole = MatMulOperandRole::B;
      break;
    }
  }
  return info;
}
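
// Illustrative use (hypothetical IR): given
//   %a = vector.transfer_read ... : memref<...>, vector<16x16xf16>
//   %d = vector.contract ... %a, %b, %c ...
// calling getWarpMatrixInfo on the transfer_read feeding the contraction's
// LHS yields {vectorType = vector<16x16xf16>, operandRole = A}; an op whose
// result is not consumed by a vector.contract is classified as C.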

int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
    return 256;
  }
  if (elType.getIntOrFloatBitWidth() == 64) {
    return isAcc ? 512 : 256;
  }
  return 128;
}
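
// Summary (not in the original source): f16, i8, i4, and non-accumulator f32
// operands use 128-bit tile rows; 32-bit (f32/i32) accumulators use 256-bit
// rows; f64 uses 256-bit rows for A/B and 512-bit rows for the accumulator.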

FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // f64 operand
  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // int8 operand
  if (elType.isInteger(8)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // int4 operand
  if (elType.isInteger(4)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Integer 32bit acc operands
  if (elType.isInteger(32)) {
    return FragmentElementInfo{
        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
        inferNumRegistersPerMatrixFragment(type)};
  }

  // Floating point 32bit operands
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}
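
// Example (illustrative): for an f16 operand fragment this returns a
// FragmentElementInfo with register type vector<2xf16>, 2 elements per
// register, 32-bit registers, and inferNumRegistersPerMatrixFragment(type)
// registers per thread, i.e. each 32-bit register packs two f16 values.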

static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}
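
// Worked example (illustrative): for a 16x16 f16 operand with 128-bit lines,
// elementsPerLine = 8 and num8x128bTiles = {2, 2}. With 2 elements per
// register, logicalValueId = 3 gives registerIdx = 1, so the map produces the
// tile offset (8, 0): odd registers start on the second 8-row tile.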

FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                  const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine what register logicalValueId corresponds to. Use that as a
  // linear index into the coordinate mapping `index -> (tile row, tile col)`.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
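
// Worked example (illustrative): for a 16x16 f16 A fragment, lineSize = 128
// and elementsPerRegister = 2. For (laneId, logicalValueId) = (5, 3), the
// register index is 1, giving tile offset (8, 0); the final coordinate is
// (8 + 5 floordiv 4, 0 + (5 mod 4) * 2 + 3 mod 2) = (9, 3).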

FailureOr<nvgpu::LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                                   bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C) {
    params.targetLayout = NVVM::MMALayout::row;
  } else {
    params.targetLayout = NVVM::MMALayout::col;
  }
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  // Note: the lines assigning and testing `contiguousDimType` were elided in
  // the listing; the reconstruction here and in getLaneIdToLdMatrixMatrixCoord
  // assumes an IteratorType enum with `parallel` and `reduction` values.
  params.contiguousDimType =
      transpose ? IteratorType::parallel : IteratorType::reduction;

  if (params.contiguousDimType == IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();

  return params;
}
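
// Example (illustrative): a non-transposed 16x16 f16 A operand yields
// targetLayout = row, a reduction-contiguous fragment, and
// numTiles = (16 / 8) * ((16 * 16) / 128) = 4, i.e. four 8 x 128-bit tiles
// to load per ldmatrix.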

FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                               const LdMatrixParams &params) {
  // One thread per 128b row.
  const int64_t kNumThreadsPerTile = kNumRowsPerTile;
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // This case corresponds to row-major A|C or col-major B operands.
  if (params.contiguousDimType == IteratorType::reduction) {
    AffineExpr row = d0 % (operandShape[0]);
    AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b);
    return makeMap({row, col});
  }

  // This case corresponds to col-major A|C or row-major B operands. The
  // operandShape given is already pre-transposed (e.g. 8x16 = KxN).
  if (params.contiguousDimType == IteratorType::parallel) {
    const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128;
    // Threads are assigned in groups of 8 first across columns, then to
    // rows. This is the transpose of what `ldmatrix` expects, but when
    // `ldmatrix` gets the `.trans` qualifier, the final effect will be to
    // transpose just the blocks.
    auto groupIdx = d0.floorDiv(kNumThreadsPerTile);
    auto tileCol = (groupIdx % num8x128bCols);
    auto tileRow = groupIdx.floorDiv(num8x128bCols);
    return makeMap({tileCol * kElementsPer128b,
                    tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)});
  }
  return failure();
}
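
// Worked example (illustrative): for a reduction-contiguous 16x16 f16
// fragment, kElementsPer128b = 8, so lane 17 maps to
// (17 mod 16, (17 floordiv 16) * 8) = (1, 8): the second group of 16 lanes
// covers the next 128-bit chunk of each row.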

LogicalResult
PrepareContractToGPUMMA::matchAndRewrite(vector::ContractionOp op,
                                         PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value lhs = op.getLhs();
  Value rhs = op.getRhs();
  Value res = op.getAcc();

  // Set up the parallel/reduction structure in the right form.
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m;
  AffineExpr n;
  AffineExpr k;
  bindDims(rewriter.getContext(), m, n, k);
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  auto iteratorTypes = op.getIteratorTypes().getValue();
  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
  if (iteratorTypes.size() != 3)
    return failure();
  if (!(isParallelIterator(iteratorTypes[0]) &&
        isParallelIterator(iteratorTypes[1]) &&
        isReductionIterator(iteratorTypes[2])))
    return failure();

  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
  if (maps == canonicalForm) {
    return failure();
  }
  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
    std::swap(rhs, lhs);
    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
    std::swap(lhs, rhs);
    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
    std::swap(lhs, rhs);
  } else {
    return failure();
  }
  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
      op.getIteratorTypes());
  return success();
}
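
// Illustrative rewrite (hypothetical IR): a contraction with indexing maps
//   [(m, k), (k, n), (m, n)]   // A row-major, B row-major, C row-major
// is matched by the first case above: a vector.transpose of the RHS is
// created and the vector.contract is rebuilt with the canonical "TNT" maps
//   [(m, k), (n, k), (m, n)].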

} // namespace nvgpu
} // namespace mlir