MLIR 23.0.0git
IndexedAccessOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- IndexedAccessOpInterfaceImpl.cpp -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// Implement IndexedAccessOpInterface on GPU dialect operations that have
9// %memref[%i0, %i1, ...] arguments to allow them to be manipulated by
10// generic memref-dialect passes.
11//===----------------------------------------------------------------------===//
12
14
17#include "mlir/IR/Dialect.h"
18#include "mlir/IR/Operation.h"
20
21using namespace mlir;
22using namespace mlir::memref;
23using namespace mlir::gpu;
24
25/// Given a GPU matrix type that will be loaded or stored, the leading dimension
26/// of the matrix in memory, and whether or not the matrix is transposed,
27/// compute the size of the linear memory that the load/store spans as
28/// dC + leadingDim * (dR - 1) where dR and dC are the non-contiguous and
29/// contiguous matrix dimensions, respectively (we get to the dX-1th row and
30/// then access the first dY elements of it).
31static int64_t get1DAccessSize(MMAMatrixType matrixType, int64_t leadingDim,
32 bool transpose) {
33 assert(matrixType.getShape().size() == 2 && "expected matrices to be 2D");
34
35 int64_t c = matrixType.getShape()[1];
36 int64_t r = matrixType.getShape()[0];
37 if (transpose)
38 std::swap(c, r);
39 return c + leadingDim * (r - 1);
40}
41
42namespace {
43struct SubgroupMmaLoadMatrixOpImpl final
44 : IndexedAccessOpInterface::ExternalModel<SubgroupMmaLoadMatrixOpImpl,
45 SubgroupMmaLoadMatrixOp> {
46 TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
47 return cast<SubgroupMmaLoadMatrixOp>(op).getSrcMemref();
48 }
49
50 Operation::operand_range getIndices(Operation *op) const {
51 return cast<SubgroupMmaLoadMatrixOp>(op).getIndices();
52 }
53
54 /// This returns a 1-D shape so that it's clear that both linearization and
55 /// folding in expand/collapse_shape operations are allowed.
56 SmallVector<int64_t> getAccessedShape(Operation *op) const {
57 auto loadOp = cast<SubgroupMmaLoadMatrixOp>(op);
58 return {get1DAccessSize(cast<MMAMatrixType>(loadOp.getRes().getType()),
59 loadOp.getLeadDimension().getZExtValue(),
60 loadOp.getTranspose().value_or(false))};
61 }
62
63 std::optional<SmallVector<Value>>
64 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
65 ValueRange newIndices) const {
66 auto loadOp = cast<SubgroupMmaLoadMatrixOp>(op);
67 rewriter.modifyOpInPlace(loadOp, [&]() {
68 loadOp.getSrcMemrefMutable().assign(newMemref);
69 loadOp.getIndicesMutable().assign(newIndices);
70 });
71 return std::nullopt;
72 }
73
74 bool hasInboundsIndices(Operation *) const { return true; }
75};
76
77struct SubgroupMmaStoreMatrixOpImpl final
78 : IndexedAccessOpInterface::ExternalModel<SubgroupMmaStoreMatrixOpImpl,
79 SubgroupMmaStoreMatrixOp> {
80 TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
81 return cast<SubgroupMmaStoreMatrixOp>(op).getDstMemref();
82 }
83
84 Operation::operand_range getIndices(Operation *op) const {
85 return cast<SubgroupMmaStoreMatrixOp>(op).getIndices();
86 }
87
88 /// This returns a 1-D shape so that it's clear that both linearization and
89 /// folding in expand/collapse_shape operations are allowed.
90 SmallVector<int64_t> getAccessedShape(Operation *op) const {
91 auto storeOp = cast<SubgroupMmaStoreMatrixOp>(op);
92 return {get1DAccessSize(storeOp.getSrc().getType(),
93 storeOp.getLeadDimension().getZExtValue(),
94 storeOp.getTranspose().value_or(false))};
95 }
96
97 std::optional<SmallVector<Value>>
98 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
99 ValueRange newIndices) const {
100 auto storeOp = cast<SubgroupMmaStoreMatrixOp>(op);
101 rewriter.modifyOpInPlace(storeOp, [&]() {
102 storeOp.getDstMemrefMutable().assign(newMemref);
103 storeOp.getIndicesMutable().assign(newIndices);
104 });
105 return std::nullopt;
106 }
107
108 bool hasInboundsIndices(Operation *) const { return true; }
109};
110} // namespace
111
113 DialectRegistry &registry) {
114 registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
115 SubgroupMmaLoadMatrixOp::attachInterface<SubgroupMmaLoadMatrixOpImpl>(*ctx);
116 SubgroupMmaStoreMatrixOp::attachInterface<SubgroupMmaStoreMatrixOpImpl>(
117 *ctx);
118 });
119}
static int64_t get1DAccessSize(MMAMatrixType matrixType, int64_t leadingDim, bool transpose)
Given a GPU matrix type that will be loaded or stored, the leading dimension of the matrix in memory, and whether or not the matrix is transposed, compute the size of the linear memory span of the access.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
OperandRange operand_range
Definition Operation.h:397
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition GPUDialect.h:131
ArrayRef< int64_t > getShape() const
Get shape of the matrix.
void registerIndexedAccessOpInterfaceExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494