#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}

/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  // For each 8x128b square, a thread is responsible for one 32-bit register.
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}
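To make the arithmetic concrete, here is a standalone sketch (illustrative only; the 16x16 f16 shape and the hard-coded constants mirror kNumRowsPerTile = 8 and a 128-bit tile row, and are assumptions rather than part of this file) that reproduces the two computations above for an m16n8k16 f16 A operand:

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  // A 16x16 vector of f16 feeding the LHS of a vector.contract.
  const int64_t rows = 16, cols = 16, elemBits = 16;
  // Non-accumulator, non-64-bit operand, so the tile row is 128 bits wide.
  const int64_t lineSizeBits = 128;
  // getTileShape: {16/8, (16*16)/128} == {2, 2} tiles of 8x128b.
  const std::array<int64_t, 2> tiles = {rows / 8,
                                        (cols * elemBits) / lineSizeBits};
  // inferNumRegistersPerMatrixFragment: (16/8) * (16*16)/128 == 4 registers,
  // i.e. 8 f16 values per thread, packed 2 per 32-bit register.
  const int64_t numRegisters = (rows / 8) * (cols * elemBits) / lineSizeBits;
  std::cout << tiles[0] << "x" << tiles[1] << " tiles, " << numRegisters
            << " registers per thread\n";
}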
/// Returns the first user of the `op` that is vector.contract.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}
/// If `op` is a vector.transfer_write, return the WarpMatrixInfo for the
/// vector operand; otherwise classify the op's vector result.
FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;

  // Determine the vector type at warp level.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }

  // Default to the accumulator/result role unless the value directly feeds a
  // vector.contract as its LHS or RHS.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;

  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;

  return info;
}
/// Returns the number of bits in a single tile row.
int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32)
    return 256;
  if (elType.getIntOrFloatBitWidth() == 64)
    return isAcc ? 512 : 256;
  return 128;
}
/// Returns a FragmentElementInfo struct describing the register types for
/// the given matrix fragment type.
FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);

  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    // Each register holds a vector<2xf16>: 2 elements in a 32-bit register.
    return FragmentElementInfo{
        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
        inferNumRegistersPerMatrixFragment(type)};
  }

  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }

  // ... analogous cases for i8, i4, and i32 accumulator operands elided ...

  Type f32Ty = Float32Type::get(ctx);
  if (elType.isF32()) {
    return isAccum
               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}
/// Returns an AffineMap which maps two dimensions representing
/// (laneId, logicalValueId) to two results representing offsets within a
/// matrix operand.
FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine which register `logicalValueId` lands in, then map that
  // register index to a (tile row, tile col) offset.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
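The composed map can be sanity-checked with plain integer arithmetic. A minimal sketch, assuming a 16x16 f16 A operand (elementsPerRegister = 2, a 2x2 grid of 8x128b tiles, 8 elements per 128-bit line; all constants hard-coded for illustration): it prints the coordinates held by lane 0, which match the PTX m16n8k16 fragment layout, namely (0,0), (0,1), (8,0), (8,1), (0,8), (0,9), (8,8), (8,9).

#include <cstdint>
#include <iostream>

int main() {
  const int64_t elementsPerRegister = 2, tilesPerCol = 2, elementsPerLine = 8;
  const int64_t kThreadsPerRow = 4;
  const int64_t laneId = 0;
  for (int64_t valueId = 0; valueId < 8; ++valueId) {
    const int64_t registerIdx = valueId / elementsPerRegister;
    const int64_t tileRow = (registerIdx % tilesPerCol) * 8;
    const int64_t tileCol = (registerIdx / tilesPerCol) * elementsPerLine;
    const int64_t row = tileRow + laneId / kThreadsPerRow;
    const int64_t col = tileCol +
                        (laneId % kThreadsPerRow) * elementsPerRegister +
                        valueId % elementsPerRegister;
    std::cout << "value " << valueId << " -> (" << row << ", " << col << ")\n";
  }
}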
/// Given a type that contains info for a warp-matrix operand and whether or
/// not the load is a transposed load, return the LdMatrixParams.
FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C)
    params.targetLayout = NVVM::MMALayout::row;
  else
    params.targetLayout = NVVM::MMALayout::col;
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;

  // Each 8x128b tile is one matrix in the ldmatrix sense.
  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }

  if (params.numTiles == 0)
    return failure();
  return params;
}
/// Returns an AffineMap which maps a single dimension representing the laneId
/// to two results representing offsets within the source matrix for the
/// warp-wide ldmatrix load.
FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128b row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // Index `idx` in `operandShape` is the strided dimension of the source
  // memref of the ldmatrix op.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;

  // Affine exprs in the strided and contiguous dimensions encode the
  // coordinate a thread points at for the warp-wide ldmatrix.
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

  // Source layout matches the mma.sync register layout (e.g. row-major A).
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});

  // Source layout is transposed w.r.t. the mma.sync register layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});

  return failure();
}
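A similar hedged check for the non-transposed mapping, assuming a 16x16 f16 operand (strided extent 16, 8 f16 elements per 128-bit line; constants are illustrative): lanes 0-15 address rows 0-15 at column 0, and lanes 16-31 address the same rows at column 8, one thread per 128-bit row across the four 8x128b tiles.

#include <cstdint>
#include <iostream>

int main() {
  const int64_t stridedExtent = 16, kElementsPer128b = 8;
  for (int64_t laneId = 0; laneId < 32; ++laneId) {
    const int64_t row = laneId % stridedExtent;                    // strided
    const int64_t col = (laneId / stridedExtent) * kElementsPer128b; // contiguous
    std::cout << "lane " << laneId << " -> (" << row << ", " << col << ")\n";
  }
}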
/// Returns whether the vector.transfer_read instruction can be interpreted
/// as a warp-level cooperative matrix load operation.
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result must be a static 2-D vector.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;

  // Reads from tensor types are unsupported: stride information is required.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // The innermost dimension of the source must be contiguous.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
/// Returns whether the vector.transfer_write instruction can be interpreted
/// as a warp-level cooperative matrix store operation.
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // The lowering relies on a minor-identity permutation map.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;

  // Writes to tensor types are unsupported: stride information is required.
  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
  if (!sourceType)
    return false;

  // The innermost dimension of the destination must be contiguous.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
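A minimal usage sketch for the two predicates above: the helper name opTakesMmaSyncPath is hypothetical, while the nvgpu:: calls and headers are the ones declared in this file.

#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;

// Hypothetical predicate: true if `op` is a transfer the nvgpu.mma.sync
// lowering can handle as a warp-level matrix load or store.
static bool opTakesMmaSyncPath(Operation *op) {
  if (auto read = dyn_cast<vector::TransferReadOp>(op))
    return nvgpu::canLowerToWarpMatrixOperation(read);
  if (auto write = dyn_cast<vector::TransferWriteOp>(op))
    return nvgpu::canLowerToWarpMatrixOperation(write);
  return false;
}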
Referenced declarations:

static constexpr int64_t kNumRowsPerTile
static constexpr int64_t kThreadsPerRow
There are always 4 threads per [128|256|512] bit row.
static bool isAccumulatorOrResult(MatMulOperandRole operandType)
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type)
Returns the number of registers which compose a matrix fragment held by a single thread.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape, Type elementType, int64_t lineSizeBits)
Returns the number of 8 x [128|256|512] bit tiles that compose the given operand shape.
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize, Type elementType, ArrayRef<int64_t> operandShape, bool isAccumulator, int64_t elementsPerRegister, AffineExpr logicalValueId)

AffineExpr
Base type for affine expression.
AffineExpr AffineExpr::floorDiv(uint64_t v) const
AffineMap
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap AffineMap::get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr AffineMap::getResult(unsigned idx) const
MLIRContext *AffineMap::getContext() const
Location
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext
MLIRContext is the top-level object for a collection of MLIR operations.
OpBuilder
This class helps build Operations.
Operation
Operation is the basic unit of execution within MLIR.
OpResult Operation::getResult(unsigned idx)
Get the 'idx'th result of this operation.
InFlightDiagnostic Operation::emitError(const Twine &message = {})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
user_range Operation::getUsers()
Returns a range of all users.
Type
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext *Type::getContext() const
Return the MLIRContext in which this type was uniqued.
bool Type::isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
unsigned Type::getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type; asserts on other types.
Type Value::getType() const
Return the type of this value.

int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
FailureOr<AffineMap> nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) and returns two results representing offsets within a matrix operand.
FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
FailureOr<AffineMap> nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representing offsets within the matrix operand.
MatMulOperandRole
Represents the role of an operand in an MMA instruction: result := matmul(A, B) + C.
FailureOr<LdMatrixParams> nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given a type that contains info for a warp-matrix operand and whether or not the load is a transposed load, return the LdMatrixParams.
FailureOr<FragmentElementInfo> nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type.
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative matrix load operation.

void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. sizeof...(exprs)].
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.

FragmentElementInfo
Specifies information about the registers which compose a matrix fragment according to the PTX documentation.
LdMatrixParams
Encapsulates the parameters needed to lower an nvgpu.ldmatrix operation to nvvm.ldmatrix.
NVVM::MMALayout LdMatrixParams::targetLayout
vector::IteratorType LdMatrixParams::contiguousDimType
WarpMatrixInfo
Collects information about a warp-level matrix operand represented by a VectorType.
MatMulOperandRole WarpMatrixInfo::operandRole
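Taken together, these utilities compose into a small query pipeline. A minimal sketch, assuming the MMAUtils.h header and a valid OpBuilder; the wrapper name computeOperandCoordMap is hypothetical, while the nvgpu:: calls are the declarations listed above.

#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

using namespace mlir;

// Hypothetical wrapper: classify `op`, check that its fragment has a known
// per-thread register layout, then build the (laneId, valueId) -> (row, col)
// coordinate map for the operand.
static FailureOr<AffineMap> computeOperandCoordMap(OpBuilder &builder,
                                                   Operation *op) {
  FailureOr<nvgpu::WarpMatrixInfo> info = nvgpu::getWarpMatrixInfo(op);
  if (failed(info))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*info);
  if (failed(regInfo))
    return failure();
  return nvgpu::getLaneIdAndValueIdToOperandCoord(builder, op->getLoc(),
                                                  *info);
}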