MLIR  16.0.0git
NvGpuSupport.h
//===- NvGpuSupport.h - MLIR Vector to GPU lowering support ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides utilities to assist in the lowering of Vector operations
// to GPU dialect MMA operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H

#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"

namespace mlir {
namespace nvgpu {

enum class MatMulOperandRole : int32_t { A = 0, B, C };

/// Collects information about a warp-level matrix operand represented by a
/// VectorType.
struct WarpMatrixInfo {
  VectorType vectorType;
  MatMulOperandRole operandRole;
};

/// Given an op that operates on a VectorType representing a warp-level matrix
/// operand, the function returns a struct containing relevant type information.
FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);

/// Returns the number of bits in a single tile row. It is either 128, 256, or
/// 512 bits depending on the data type and whether the operand is an
/// accumulator/result operand.
int64_t inferTileWidthInBits(const WarpMatrixInfo &type);

/// Specifies information about the registers which compose a matrix fragment
/// according to the PTX documentation.
};

/// Returns a FragmentElementInfo struct describing the register types for the
/// given matrix fragment type.
FailureOr<FragmentElementInfo>
getMmaSyncRegisterType(const WarpMatrixInfo &type);

/// Returns an AffineMap which maps two dimensions representing (laneId,
/// logicalValueId) to two results representing offsets within a matrix
/// operand. The offsets point to the values the thread is responsible for
/// (AKA the matrix fragment values) during a warp-collective matrix
/// operation. For a visual reference of this LaneId -> (row, col) mapping,
/// please see NVIDIA's PTX documentation:
/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
FailureOr<AffineMap>
getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                  const WarpMatrixInfo &fragmentType);

struct LdMatrixParams {
  VectorType fragmentType;
  bool isAccum;
  int64_t numTiles;
  IteratorType contiguousDimType;
  NVVM::MMALayout targetLayout;
};

FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                            bool transpose);
/// Returns an AffineMap which maps a single dimension representing the laneId
/// to two results representing the offsets within the matrix operand that the
/// thread should pass to the ldmatrix instruction as its pointer location.
FailureOr<AffineMap>
getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                               const LdMatrixParams &params);

// Transform contract into (m, k)x(n, k)x(m, n) form so that it can be converted
// to MMA matmul.
    : public OpRewritePattern<vector::ContractionOp> {

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override;
};

} // namespace nvgpu
} // namespace mlir

#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
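
The tile-width rule documented for inferTileWidthInBits can be illustrated without depending on MLIR. The following is a minimal sketch, not the library's implementation. OperandRole, SimpleOperand, and tileWidthInBitsSketch are hypothetical stand-ins for MatMulOperandRole, WarpMatrixInfo, and the real function. The concrete mapping (32-bit accumulators give 256 bits, 64-bit elements give 256 or 512 bits, everything else gives 128 bits) is an assumption that is consistent with the three documented values.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for nvgpu::MatMulOperandRole.
enum class OperandRole : int32_t { A = 0, B, C };

// Hypothetical stand-in for nvgpu::WarpMatrixInfo. Only the element bit width
// of the vector type is kept, since that is all this sketch needs.
struct SimpleOperand {
  int64_t elementBitWidth;
  OperandRole role;
};

// Assumed rule: the width of one tile row depends on the element width and on
// whether the operand is the accumulator/result (role C).
int64_t tileWidthInBitsSketch(const SimpleOperand &operand) {
  const bool isAccumulator = operand.role == OperandRole::C;
  // 32-bit accumulators are assumed to use 256-bit tile rows.
  if (isAccumulator && operand.elementBitWidth == 32)
    return 256;
  // 64-bit elements are assumed to double the row width.
  if (operand.elementBitWidth == 64)
    return isAccumulator ? 512 : 256;
  // Every other combination falls back to 128 bits.
  return 128;
}
```

Under these assumptions, an f16 A operand maps to 128 bits and an f32 accumulator maps to 256 bits.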
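
The (laneId, logicalValueId) -> (row, col) mapping described for getLaneIdAndValueIdToOperandCoord can be made concrete with one case from the PTX documentation: the 16x8 accumulator fragment of an m16n8 mma with 32-bit elements, where each of the 32 lanes holds four values. The sketch below uses plain integers instead of an AffineMap, and accumulatorCoordSketch is a hypothetical helper, not part of the nvgpu API. It covers only this one fragment shape; other operands and data types use different layouts.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical helper: returns the (row, col) inside a 16x8 accumulator tile
// that lane `laneId` (0..31) holds as its fragment value `valueId` (0..3).
std::pair<int64_t, int64_t> accumulatorCoordSketch(int64_t laneId,
                                                   int64_t valueId) {
  assert(laneId >= 0 && laneId < 32 && "a warp has 32 lanes");
  assert(valueId >= 0 && valueId < 4 && "each lane holds 4 values");
  // Lanes are organized as 8 groups of 4 threads.
  const int64_t groupId = laneId / 4;
  const int64_t threadInGroup = laneId % 4;
  // Values 0 and 1 sit in the upper 8 rows, values 2 and 3 in the lower 8.
  const int64_t row = groupId + (valueId >= 2 ? 8 : 0);
  // Each thread owns two adjacent columns.
  const int64_t col = threadInGroup * 2 + (valueId % 2);
  return {row, col};
}
```

For example, lane 5 belongs to group 1 and is thread 1 within that group, so its value 3 sits at row 9, column 3 of the tile.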