#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// There are always 4 threads per [128|256|512] bit row.
static constexpr int64_t kThreadsPerRow = 4;
static constexpr int64_t kNumRowsPerTile = 8;

static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
  return operandType == MatMulOperandRole::C;
}

/// Returns the number of registers which compose a matrix fragment held by a
/// single thread.
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
  int64_t lineSize = inferTileWidthInBits(type);
  auto shape = type.vectorType.getShape();
  return (shape[0] / kNumRowsPerTile) *
         (shape[1] * type.vectorType.getElementType().getIntOrFloatBitWidth()) /
         lineSize;
}
/// Returns the number of 8 x [128|256|512] bit tiles that compose the given
/// operand shape.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits) {
  return {operandShape[0] / kNumRowsPerTile,
          (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
              lineSizeBits};
}
/// Returns the first user of the op that is vector.contract.
FailureOr<vector::ContractionOp> nvgpu::getUserContract(Operation *op) {
  for (Operation *user : op->getUsers()) {
    if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
      return contractOp;
  }
  return failure();
}
/// If op is a vector.transfer_write, return the WarpMatrixInfo for the vector
/// operand; for other supported ops, return the WarpMatrixInfo for the result.
FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
  WarpMatrixInfo info;
  // Determine the vector type at warp level.
  if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    info.vectorType = writeOp.getVectorType();
  } else if (isa<vector::TransferReadOp, vector::ContractionOp,
                 vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
    info.vectorType = cast<VectorType>(op->getResult(0).getType());
  } else {
    return op->emitError()
           << "unhandled operation type in nvgpu.mma.sync conversion path";
  }
  // The operand is an accumulator/result unless it feeds a vector.contract as
  // the LHS or RHS.
  info.operandRole = MatMulOperandRole::C;
  FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
  if (failed(contractOp))
    return info;
  if ((*contractOp).getLhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::A;
  else if ((*contractOp).getRhs() == op->getResult(0))
    info.operandRole = MatMulOperandRole::B;
  return info;
}
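// Example (editor's sketch; the IR below is an assumption, not from this
// file): for
//   %a = vector.transfer_read ... : memref<16x16xf16>, vector<16x16xf16>
//   %d = vector.contract ... %a, %b, %c ...
// getWarpMatrixInfo on the defining op of %a reports
// vectorType = vector<16x16xf16> and operandRole = MatMulOperandRole::A,
// because the read feeds the contraction as its LHS.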
/// Returns the number of bits in a single tile row.
int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
  bool isAcc = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (isAcc && elType.getIntOrFloatBitWidth() == 32)
    return 256;
  if (elType.getIntOrFloatBitWidth() == 64)
    return isAcc ? 512 : 256;
  return 128;
}
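// Worked example (editor's sketch, not from the original source): for a
// non-accumulator vector<16x16xf16> operand,
//   inferTileWidthInBits             -> 128 bits per tile row
//   getTileShape({16, 16}, f16, 128) -> {16/8, (16*16)/128} = {2, 2} tiles
//   inferNumRegistersPerMatrixFragment
//       -> (16/8) * (16*16)/128 = 4 registers of 32 bits per thread,
// i.e. 256 f16 elements spread across 32 threads, 2 elements per register.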
/// Returns a FragmentElementInfo struct describing the register types for the
/// given matrix fragment type.
FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  MLIRContext *ctx = type.vectorType.getContext();
  const bool isAccum = isAccumulatorOrResult(type.operandRole);
  Type elType = type.vectorType.getElementType();
  if (elType.isF16()) {
    return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2,
                               32, inferNumRegistersPerMatrixFragment(type)};
  }
  if (elType.isF64()) {
    Type f64Ty = Float64Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f64Ty, 1, 64,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  // Integer operands (i8/i4/i32) are handled analogously.
  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum
               ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
                                     inferNumRegistersPerMatrixFragment(type)}
               : FragmentElementInfo{f32Ty, 1, 32,
                                     inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}
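// For reference (sketch): the f16 branch above, applied to the 16x16xf16 "A"
// operand from the earlier example, yields registerLLVMType = vector<2xf16>
// (one 32-bit register holds two halves), elementsPerRegister = 2,
// registerWidthBits = 32, and numRegistersPerFragment = 4.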
static AffineMap getRegisterIndexToTileOffsetMap(int64_t lineSize,
                                                 Type elementType,
                                                 ArrayRef<int64_t> operandShape,
                                                 bool isAccumulator,
                                                 int64_t elementsPerRegister,
                                                 AffineExpr logicalValueId) {
  const int64_t elementsPerLine =
      lineSize / elementType.getIntOrFloatBitWidth();
  const std::array<int64_t, 2> num8x128bTiles =
      getTileShape(operandShape, elementType, lineSize);
  AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
  return AffineMap::get(
      2, 0,
      {(registerIdx % num8x128bTiles[0]) * 8,
       (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
      elementType.getContext());
}
/// Returns an AffineMap which maps two dimensions representing
/// (laneId, logicalValueId) to two results representing the (row, col)
/// coordinate of the matrix element accessed by that lane and value.
FailureOr<AffineMap>
nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
                                         const WarpMatrixInfo &fragmentType) {
  Type elementType = fragmentType.vectorType.getElementType();
  ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      getMmaSyncRegisterType(fragmentType);
  if (failed(regInfo))
    return failure();

  const int64_t elementBitWidth = elementType.getIntOrFloatBitWidth();
  const int64_t elementsPerRegister =
      regInfo->registerWidthBits / elementBitWidth;
  const int64_t lineSize = inferTileWidthInBits(fragmentType);

  AffineExpr laneId, logicalValueIdDim;
  bindDims(builder.getContext(), laneId, logicalValueIdDim);

  // Determine which tile the register belongs to, then add the offset of the
  // element within that tile.
  AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
      lineSize, elementType, operandShape,
      isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
      logicalValueIdDim);

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(2, 0, dimExprs, builder.getContext());
  };

  auto tileRow = registerIndexToTileCoord.getResult(0);
  auto tileCol = registerIndexToTileCoord.getResult(1);
  return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                  tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                      (logicalValueIdDim % elementsPerRegister)});
}
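// Worked example (editor's sketch): for the 16x16xf16 "A" operand,
// elementsPerRegister = 2, elementsPerLine = 8, and num8x128bTiles = {2, 2},
// so the returned map is
//   (laneId, valueId) -> (((valueId floordiv 2) mod 2) * 8 + laneId floordiv 4,
//                         (valueId floordiv 4) * 8 + (laneId mod 4) * 2
//                             + valueId mod 2)
// which reproduces the PTX mma.sync fragment layout: lane 0 owns elements
// (0,0), (0,1), (8,0), (8,1), (0,8), (0,9), (8,8), (8,9).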
/// Given a type that contains info for a warp-matrix operand and whether or
/// not the load is transposed, return the LdMatrixParams.
FailureOr<nvgpu::LdMatrixParams>
nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
  LdMatrixParams params;
  Type elType = type.vectorType.getElementType();
  params.fragmentType = type.vectorType;
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C)
    params.targetLayout = NVVM::MMALayout::row;
  else
    params.targetLayout = NVVM::MMALayout::col;
  ArrayRef<int64_t> shape = type.vectorType.getShape();
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;
  if (params.contiguousDimType == vector::IteratorType::reduction) {
    params.numTiles = (shape[0] / kNumRowsPerTile) *
                      ((shape[1] * elType.getIntOrFloatBitWidth()) / 128);
  } else {
    params.numTiles = (shape[1] / kNumRowsPerTile) *
                      ((shape[0] * elType.getIntOrFloatBitWidth()) / 128);
  }
  if (params.numTiles == 0)
    return failure();
  return params;
}
/// Returns an AffineMap which maps a single dimension representing the laneId
/// to two results representing the (row, col) coordinate the lane points to
/// for nvgpu.ldmatrix.
FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128b row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  ArrayRef<int64_t> operandShape = params.fragmentType.getShape();
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());

  auto makeMap = [&](ArrayRef<AffineExpr> dimExprs) -> AffineMap {
    return AffineMap::get(1, 0, dimExprs, builder.getContext());
  };

  // `idx` selects the strided dimension of the source memref.
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * (kElementsPer128b);

  // Memory layout matches the mma.sync register operand layout.
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});
  // Transposed case: memory layout does not match the register layout.
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});
  return failure();
}
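// Worked example (editor's sketch): for a non-transposed 16x16xf16 load,
// kElementsPer128b = 8 and idx = 0, so the map is
//   (laneId) -> (laneId mod 16, (laneId floordiv 16) * 8)
// i.e. lanes 0-15 supply the addresses of rows 0-15 of the left 16x8 half and
// lanes 16-31 those of the right half, matching ldmatrix .x4 semantics.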
/// Returns whether the vector.transfer_read can be interpreted as a
/// warp-level cooperative matrix load.
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim())
    return false;
  VectorType type = op.getType();
  // The result type must be a statically shaped 2-D vector.
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // Reads from tensors are unsupported: stride information is required.
  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
  if (!sourceType)
    return false;
  // The innermost dimension of the source must be contiguous.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
/// Returns whether the vector.transfer_write can be interpreted as a
/// warp-level cooperative matrix store.
bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
  if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
    return false;
  VectorType type = op.getVectorType();
  if (!type.hasStaticShape() || type.getRank() != 2)
    return false;
  // Transposed writes are not supported; they would require scalarized stores.
  if (!op.getPermutationMap().isMinorIdentity())
    return false;
  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
  if (!sourceType)
    return false;
  // The innermost dimension of the target memref must be contiguous.
  auto [strides, offset] = sourceType.getStridesAndOffset();
  return strides.back() == 1;
}
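// Example (editor's sketch; the IR below is an assumption, not from this
// file): a read such as
//   %v = vector.transfer_read %mem[%c0, %c0], %pad
//        {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
// passes these checks (unmasked, in-bounds, static 2-D shape, memref source
// with unit innermost stride), whereas a masked read or a read from a tensor
// is rejected.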