33 assert(matrixType.
getShape().size() == 2 &&
"expected matrices to be 2D");
39 return c + leadingDim * (r - 1);
43struct SubgroupMmaLoadMatrixOpImpl final
44 : IndexedAccessOpInterface::ExternalModel<SubgroupMmaLoadMatrixOpImpl,
45 SubgroupMmaLoadMatrixOp> {
47 return cast<SubgroupMmaLoadMatrixOp>(op).getSrcMemref();
51 return cast<SubgroupMmaLoadMatrixOp>(op).getIndices();
56 SmallVector<int64_t> getAccessedShape(Operation *op)
const {
57 auto loadOp = cast<SubgroupMmaLoadMatrixOp>(op);
59 loadOp.getLeadDimension().getZExtValue(),
60 loadOp.getTranspose().value_or(
false))};
63 std::optional<SmallVector<Value>>
64 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
66 auto loadOp = cast<SubgroupMmaLoadMatrixOp>(op);
68 loadOp.getSrcMemrefMutable().assign(newMemref);
69 loadOp.getIndicesMutable().assign(newIndices);
74 bool hasInboundsIndices(Operation *)
const {
return true; }
77struct SubgroupMmaStoreMatrixOpImpl final
78 : IndexedAccessOpInterface::ExternalModel<SubgroupMmaStoreMatrixOpImpl,
79 SubgroupMmaStoreMatrixOp> {
81 return cast<SubgroupMmaStoreMatrixOp>(op).getDstMemref();
85 return cast<SubgroupMmaStoreMatrixOp>(op).getIndices();
90 SmallVector<int64_t> getAccessedShape(Operation *op)
const {
91 auto storeOp = cast<SubgroupMmaStoreMatrixOp>(op);
93 storeOp.getLeadDimension().getZExtValue(),
94 storeOp.getTranspose().value_or(
false))};
97 std::optional<SmallVector<Value>>
98 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
100 auto storeOp = cast<SubgroupMmaStoreMatrixOp>(op);
102 storeOp.getDstMemrefMutable().assign(newMemref);
103 storeOp.getIndicesMutable().assign(newIndices);
108 bool hasInboundsIndices(Operation *)
const {
return true; }
115 SubgroupMmaLoadMatrixOp::attachInterface<SubgroupMmaLoadMatrixOpImpl>(*ctx);
116 SubgroupMmaStoreMatrixOp::attachInterface<SubgroupMmaStoreMatrixOpImpl>(
static int64_t get1DAccessSize(MMAMatrixType matrixType, int64_t leadingDim, bool transpose)
Given a GPU matrix type that will be loaded or stored, the leading dimension of the matrix in memory,...
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
OperandRange operand_range
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
void registerIndexedAccessOpInterfaceExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.