// In inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type):
int64_t lineSize = inferTileWidthInBits(type);
// In getUserContract(Operation *op), scanning op->getUsers() for the first
// vector.contract user:
if (auto contractOp = dyn_cast<vector::ContractionOp>(user))
  return contractOp;
// In getWarpMatrixInfo(Operation *op): `info` is the WarpMatrixInfo being
// populated. Classify the op and record the vector type involved:
if (vector::TransferWriteOp writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
  info.vectorType = writeOp.getVectorType();
} else if (isa<vector::TransferReadOp, vector::ContractionOp,
               vector::ExtractStridedSliceOp, arith::ConstantOp>(op)) {
  info.vectorType = cast<VectorType>(op->getResult(0).getType());
} else {
  return op->emitError()
         << "unhandled operation type in nvgpu.mma.sync conversion path";
}

// Assume the accumulator/result role unless a vector.contract user consumes
// the result as its lhs (A) or rhs (B):
info.operandRole = MatMulOperandRole::C;
FailureOr<vector::ContractionOp> contractOp = getUserContract(op);
if (failed(contractOp))
  return info;
if ((*contractOp).getLhs() == op->getResult(0))
  info.operandRole = MatMulOperandRole::A;
else if ((*contractOp).getRhs() == op->getResult(0))
  info.operandRole = MatMulOperandRole::B;
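For illustration, this is the kind of IR the classification above is aimed at; the snippet and its names and shapes are ours, not from the original file:

//   %a = vector.transfer_read %lhs[%c0, %c0], %cst
//          : memref<16x16xf16>, vector<16x16xf16>
//   %d = vector.contract {...} %a, %b, %acc
//          : vector<16x16xf16>, vector<16x8xf16> into vector<16x8xf16>
//   vector.transfer_write %d, %out[%c0, %c0]
//          : vector<16x8xf16>, memref<16x8xf16>
// The reads, the contract, and constant accumulators take the isa<...>
// branch; the write takes the transfer_write branch. %a and %b then receive
// roles A and B through the getUserContract lookup, while %acc keeps the
// default C role.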
// In inferTileWidthInBits: f64 rows are 512 bits wide for the accumulator,
// 256 bits for the A/B operands:
return isAcc ? 512 : 256;
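To make the tile geometry concrete, a worked example of ours, assuming the 8-row tiles described in the helper notes below:

//   f64 accumulator (role C): 8 rows x 512 bits = 8 x 8 f64 elements per tile
//   f64 operand (role A/B):   8 rows x 256 bits = 8 x 4 f64 elements per tile
//   f16 operand:              8 rows x 128 bits = 8 x 8 f16 elements per tile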
FailureOr<FragmentElementInfo>
nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
  // ... (locals ctx, isAccum, and elType are derived from `type`) ...
  if (elType.isF16())
    return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2,
                               32, inferNumRegistersPerMatrixFragment(type)};

  Type f64Ty = Float64Type::get(ctx);
  if (elType.isF64())
    return isAccum ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
                                         inferNumRegistersPerMatrixFragment(type)}
                   : FragmentElementInfo{f64Ty, 1, 64,
                                         inferNumRegistersPerMatrixFragment(type)};

  if (elType.isInteger(8))
    return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
                               32, inferNumRegistersPerMatrixFragment(type)};
  if (elType.isInteger(4))
    return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
                               32, inferNumRegistersPerMatrixFragment(type)};
  if (elType.isInteger(32))
    return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
                               64, inferNumRegistersPerMatrixFragment(type)};

  if (elType.isF32()) {
    Type f32Ty = Float32Type::get(ctx);
    return isAccum ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
                                         inferNumRegistersPerMatrixFragment(type)}
                   : FragmentElementInfo{f32Ty, 1, 32,
                                         inferNumRegistersPerMatrixFragment(type)};
  }
  return failure();
}
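Read off the constructor arguments above, the cases amount to the following per-thread register shapes; the table is our summary, not part of the original:

// element type | register type | elems/reg | register width (bits)
// f16          | vector<2xf16> |     2     |  32
// f64 (acc)    | vector<2xf64> |     2     | 128
// f64 (A/B)    | f64           |     1     |  64
// i8           | vector<4xi8>  |     4     |  32
// i4           | vector<8xi4>  |     8     |  32
// i32 (acc)    | vector<2xi32> |     2     |  64
// f32 (acc)    | vector<2xf32> |     2     |  64
// f32 (A/B)    | f32           |     1     |  32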
// In getRegisterIndexToTileOffsetMap: register index -> (row, col) tile offset.
const int64_t elementsPerLine =
    lineSize / elementType.getIntOrFloatBitWidth();
const std::array<int64_t, 2> num8x128bTiles =
    getTileShape(operandShape, elementType, lineSize);
AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister);
return AffineMap::get(
    2, 0,
    {(registerIdx % num8x128bTiles[0]) * 8,
     (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine},
    elementType.getContext());
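A worked instance of this map with numbers of our choosing: a 16x16 f16 A operand, 128-bit lines, and 2 elements per 32-bit register:

// elementsPerLine = 128/16 = 8; num8x128bTiles = {16/8, (16*16)/128} = {2, 2}
// registerIdx = valueId floordiv 2, so the map sends:
//   valueId 0,1 -> register 0 -> tile offset (0, 0)
//   valueId 2,3 -> register 1 -> tile offset (8, 0)
//   valueId 4,5 -> register 2 -> tile offset (0, 8)
//   valueId 6,7 -> register 3 -> tile offset (8, 8)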
// In getLaneIdAndValueIdToOperandCoord, build the (laneId, logicalValueId)
// -> (row, col) map for the fragment:
FailureOr<FragmentElementInfo> regInfo = getMmaSyncRegisterType(fragmentType);
if (failed(regInfo))
  return failure();

const int64_t elementsPerRegister =
    regInfo->registerWidthBits / elementBitWidth;
const int64_t lineSize = inferTileWidthInBits(fragmentType);

AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap(
    lineSize, elementType, operandShape,
    isAccumulatorOrResult(fragmentType.operandRole), elementsPerRegister,
    logicalValueIdDim);

auto tileRow = registerIndexToTileCoord.getResult(0);
auto tileCol = registerIndexToTileCoord.getResult(1);
return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow),
                tileCol + (laneId % kThreadsPerRow) * elementsPerRegister +
                    (logicalValueIdDim % elementsPerRegister)});
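Continuing the same 16x16 f16 example (ours), with kThreadsPerRow == 4:

// lane 5, valueId 3: registerIdx = 3 floordiv 2 = 1 -> tile offset (8, 0)
//   row = 8 + (5 floordiv 4)            = 9
//   col = 0 + (5 mod 4) * 2 + (3 mod 2) = 3
// so lane 5's second register holds elements (9, 2) and (9, 3).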
FailureOr<LdMatrixParams> nvgpu::getLdMatrixParams(const WarpMatrixInfo &type,
                                                   bool transpose) {
  LdMatrixParams params;
  params.fragmentType = type.vectorType;
  // A and C operands target a row-major fragment layout, B col-major:
  if (type.operandRole == MatMulOperandRole::A ||
      type.operandRole == MatMulOperandRole::C)
    params.targetLayout = NVVM::MMALayout::row;
  else
    params.targetLayout = NVVM::MMALayout::col;
  params.contiguousDimType = transpose ? vector::IteratorType::parallel
                                       : vector::IteratorType::reduction;
  if (params.contiguousDimType == vector::IteratorType::reduction) {
    // ... numTiles computed from the fragment shape (elided) ...
  }
  if (params.numTiles == 0)
    return failure();
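A usage sketch of ours; the numTiles arithmetic itself is elided above, so the value shown assumes the usual 8-row x 128-bit tile decomposition:

// 16x16 f16 A operand, transpose == false:
//   contiguousDimType = reduction (k is the contiguous dimension in memory)
//   targetLayout      = row-major
//   numTiles          = (16/8) * ((16*16)/128) = 4, an ldmatrix x4-style load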
FailureOr<AffineMap>
nvgpu::getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc,
                                      const LdMatrixParams &params) {
  // One thread per 128b row.
  const int bitsPerElement = static_cast<int>(
      params.fragmentType.getElementType().getIntOrFloatBitWidth());
  const int kElementsPer128b = (128 / bitsPerElement);
  // `idx` selects the strided dimension of `operandShape`:
  int idx =
      (params.contiguousDimType == vector::IteratorType::reduction) ? 0 : 1;
  AffineExpr d0 = getAffineDimExpr(0, builder.getContext());
  AffineExpr strided = d0 % (operandShape[idx]);
  AffineExpr contiguous = d0.floorDiv(operandShape[idx]) * kElementsPer128b;
  // When the contiguous dim is the reduction dim, memory already matches the
  // mma.sync register layout; otherwise swap the coordinates:
  if (params.contiguousDimType == vector::IteratorType::reduction)
    return makeMap({strided, contiguous});
  if (params.contiguousDimType == vector::IteratorType::parallel)
    return makeMap({contiguous, strided});
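Worked numbers (ours) for a reduction-contiguous 16x16 f16 fragment:

// kElementsPer128b = 128/16 = 8, idx = 0, operandShape[idx] = 16:
//   strided    = laneId mod 16
//   contiguous = (laneId floordiv 16) * 8
// Lanes 0-15 address rows 0-15 at column 0; lanes 16-31 address rows 0-15 at
// column 8. Each lane supplies one 128-bit row pointer to the ldmatrix op.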
// canLowerToWarpMatrixOperation for vector.transfer_read: reject masked or
// out-of-bounds reads, non-static/non-2D vectors, and non-contiguous sources:
if (op.getMask() || op.hasOutOfBoundsDim())
  return false;
VectorType type = op.getType();
if (!type.hasStaticShape() || type.getRank() != 2)
  return false;
auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
if (!sourceType)
  return false;
auto [strides, offset] = sourceType.getStridesAndOffset();
return strides.back() == 1;
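Two stride patterns make the final check concrete (our examples):

// memref<16x16xf16>, strides [16, 1]      -> accepted (innermost stride is 1)
// a transposed view with strides [1, 16]  -> rejected, although it is still
//                                            static and 2-D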
// The vector.transfer_write overload also rejects zero-rank transfers and
// transposed writes (the permutation map must be a minor identity):
if (op.getMask() || op.hasOutOfBoundsDim() || op.getTransferRank() == 0)
  return false;
VectorType type = op.getVectorType();
if (!type.hasStaticShape() || type.getRank() != 2)
  return false;
if (!op.getPermutationMap().isMinorIdentity())
  return false;
auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
if (!sourceType)
  return false;
auto [strides, offset] = sourceType.getStridesAndOffset();
return strides.back() == 1;
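A minimal sketch of how these predicates typically gate a lowering pattern; `LowerWriteToMmaStore` is a hypothetical name of ours, not an upstream pattern:

struct LowerWriteToMmaStore : OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                PatternRewriter &rewriter) const override {
    // Bail out early on anything the mma.sync path cannot distribute.
    if (!nvgpu::canLowerToWarpMatrixOperation(op))
      return rewriter.notifyMatchFailure(op, "not distributable to mma.sync");
    // ... emit the distributed store here (elided) ...
    return success();
  }
};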
File-local helpers referenced in the excerpts above:

static constexpr int64_t kThreadsPerRow
    There are always 4 threads per [128|256|512]-bit row.
static constexpr int64_t kNumRowsPerTile
static bool isAccumulatorOrResult(MatMulOperandRole operandType)
static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type)
    Returns the number of registers which compose a matrix fragment held by a
    single thread.
static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
                                           Type elementType,
                                           int64_t lineSizeBits)
    Returns the number of 8 x [128|256|512]-bit tiles that compose the given
    operand shape.
static AffineMap getRegisterIndexToTileOffsetMap(
    int64_t lineSize, Type elementType, ArrayRef<int64_t> operandShape,
    bool isAccumulator, int64_t elementsPerRegister,
    AffineExpr logicalValueId)
Public nvgpu helpers referenced in the excerpts above:

FailureOr<vector::ContractionOp> getUserContract(Operation *op)
    Returns the first user of the op that is vector.contract.
FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op)
    If op is a vector.transfer_write, return the WarpMatrixInfo for the vector
    operand.
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
    Returns the number of bits in a single tile row.
bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op)
MatMulOperandRole
    Represents the role of an operand in an MMA instruction:
    result := matmul(A, B) + C.
WarpMatrixInfo
    Collects information about a warp-level matrix operand represented by a
    VectorType; its MatMulOperandRole operandRole field records that role.