MLIR 23.0.0git
MemoryAccessOpInterfacesImpl.cpp
Go to the documentation of this file.
1//===- MemoryAccessOpInterfacesImpl.cpp -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// Implement memref dialect interfaces that enable manipulating memref indexing
9// in passes like FoldMemRefAliasOps.
10//===----------------------------------------------------------------------===//
11
13
17#include "mlir/IR/Dialect.h"
18#include "mlir/IR/Operation.h"
20
21using namespace mlir;
22using namespace mlir::amdgpu;
23using namespace mlir::memref;
24
25namespace {
26template <typename OpTy>
27struct TransposeLoadAccess final
28 : IndexedAccessOpInterface::ExternalModel<TransposeLoadAccess<OpTy>, OpTy> {
29 TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
30 return cast<TypedValue<MemRefType>>(cast<OpTy>(op).getSrc());
31 }
32
33 Operation::operand_range getIndices(Operation *op) const {
34 return cast<OpTy>(op).getSrcIndices();
35 }
36
37 SmallVector<int64_t> getAccessedShape(Operation *op) const {
38 return {cast<VectorType>(cast<OpTy>(op).getResult().getType())
39 .getNumElements()};
40 }
41
42 std::optional<SmallVector<Value>>
43 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
44 ValueRange newIndices) const {
45 auto accessOp = cast<OpTy>(op);
46 rewriter.modifyOpInPlace(accessOp, [&]() {
47 accessOp.getSrcMutable().assign(newMemref);
48 accessOp.getSrcIndicesMutable().assign(newIndices);
49 });
50 return std::nullopt;
51 }
52
53 bool hasInboundsIndices(Operation *) const { return true; }
54};
55
56template <typename OpTy>
57struct BaseAndIndicesAccess final
58 : IndexedAccessOpInterface::ExternalModel<BaseAndIndicesAccess<OpTy>,
59 OpTy> {
60 TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
61 return cast<TypedValue<MemRefType>>(cast<OpTy>(op).getBase());
62 }
63
64 Operation::operand_range getIndices(Operation *op) const {
65 return cast<OpTy>(op).getIndices();
66 }
67
68 SmallVector<int64_t> getAccessedShape(Operation *) const { return {}; }
69
70 std::optional<SmallVector<Value>>
71 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
72 ValueRange newIndices) const {
73 auto accessOp = cast<OpTy>(op);
74 rewriter.modifyOpInPlace(accessOp, [&]() {
75 accessOp.getBaseMutable().assign(newMemref);
76 accessOp.getIndicesMutable().assign(newIndices);
77 });
78 return std::nullopt;
79 }
80
81 bool hasInboundsIndices(Operation *) const { return true; }
82};
83
84template <typename OpTy>
85struct DescriptorAtomicBarrierAccess final
86 : IndexedAccessOpInterface::ExternalModel<
87 DescriptorAtomicBarrierAccess<OpTy>, OpTy> {
88 TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
89 Value memref = cast<OpTy>(op).getAtomicBarrierAddress();
90 if (!memref)
91 return {};
92 return cast<TypedValue<MemRefType>>(memref);
93 }
94
95 Operation::operand_range getIndices(Operation *op) const {
96 return cast<OpTy>(op).getAtomicBarrierIndices();
97 }
98
99 SmallVector<int64_t> getAccessedShape(Operation *) const { return {}; }
100
101 std::optional<SmallVector<Value>>
102 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
103 ValueRange newIndices) const {
104 auto accessOp = cast<OpTy>(op);
105 rewriter.modifyOpInPlace(accessOp, [&]() {
106 accessOp.getAtomicBarrierAddressMutable().assign(newMemref);
107 accessOp.getAtomicBarrierIndicesMutable().assign(newIndices);
108 });
109 return std::nullopt;
110 }
111
112 bool hasInboundsIndices(Operation *) const { return true; }
113};
114
115struct GatherToLDSCopy final
116 : IndexedMemCopyOpInterface::ExternalModel<GatherToLDSCopy, GatherToLDSOp> {
117 TypedValue<MemRefType> getSrc(Operation *op) const {
118 return cast<TypedValue<MemRefType>>(cast<GatherToLDSOp>(op).getSrc());
119 }
120
121 Operation::operand_range getSrcIndices(Operation *op) const {
122 return cast<GatherToLDSOp>(op).getSrcIndices();
123 }
124
125 TypedValue<MemRefType> getDst(Operation *op) const {
126 return cast<TypedValue<MemRefType>>(cast<GatherToLDSOp>(op).getDst());
127 }
128
129 Operation::operand_range getDstIndices(Operation *op) const {
130 return cast<GatherToLDSOp>(op).getDstIndices();
131 }
132
133 void setMemrefsAndIndices(Operation *op, RewriterBase &rewriter, Value newSrc,
134 ValueRange newSrcIndices, Value newDst,
135 ValueRange newDstIndices) const {
136 auto copyOp = cast<GatherToLDSOp>(op);
137 rewriter.modifyOpInPlace(copyOp, [&]() {
138 copyOp.getSrcMutable().assign(newSrc);
139 copyOp.getSrcIndicesMutable().assign(newSrcIndices);
140 copyOp.getDstMutable().assign(newDst);
141 copyOp.getDstIndicesMutable().assign(newDstIndices);
142 });
143 }
144
145 bool hasInboundsSrcIndices(Operation *op) const {
146 MemRefType srcType = cast<GatherToLDSOp>(op).getSrc().getType();
147 return !isFatRawBufferMemorySpace(srcType.getMemorySpace());
148 }
149
150 bool hasInboundsDstIndices(Operation *) const { return true; }
151};
152
153struct GlobalLoadAsyncToLDSCopy final
154 : IndexedMemCopyOpInterface::ExternalModel<GlobalLoadAsyncToLDSCopy,
155 GlobalLoadAsyncToLDSOp> {
156 TypedValue<MemRefType> getSrc(Operation *op) const {
157 return cast<TypedValue<MemRefType>>(
158 cast<GlobalLoadAsyncToLDSOp>(op).getSrc());
159 }
160
161 Operation::operand_range getSrcIndices(Operation *op) const {
162 return cast<GlobalLoadAsyncToLDSOp>(op).getSrcIndices();
163 }
164
165 TypedValue<MemRefType> getDst(Operation *op) const {
166 return cast<TypedValue<MemRefType>>(
167 cast<GlobalLoadAsyncToLDSOp>(op).getDst());
168 }
169
170 Operation::operand_range getDstIndices(Operation *op) const {
171 return cast<GlobalLoadAsyncToLDSOp>(op).getDstIndices();
172 }
173
174 void setMemrefsAndIndices(Operation *op, RewriterBase &rewriter, Value newSrc,
175 ValueRange newSrcIndices, Value newDst,
176 ValueRange newDstIndices) const {
177 auto copyOp = cast<GlobalLoadAsyncToLDSOp>(op);
178 rewriter.modifyOpInPlace(copyOp, [&]() {
179 copyOp.getSrcMutable().assign(newSrc);
180 copyOp.getSrcIndicesMutable().assign(newSrcIndices);
181 copyOp.getDstMutable().assign(newDst);
182 copyOp.getDstIndicesMutable().assign(newDstIndices);
183 });
184 }
185
186 bool hasInboundsSrcIndices(Operation *) const { return true; }
187
188 bool hasInboundsDstIndices(Operation *op) const {
189 // Masked lanes may carry out-of-bounds destination indices; lowering
190 // replaces their destination pointer with -1 before the instruction uses
191 // it.
192 return !cast<GlobalLoadAsyncToLDSOp>(op).getMask();
193 }
194};
195
196template <typename OpTy>
197struct DmaBaseCopy final
198 : IndexedMemCopyOpInterface::ExternalModel<DmaBaseCopy<OpTy>, OpTy> {
199 TypedValue<MemRefType> getSrc(Operation *op) const {
200 return cast<TypedValue<MemRefType>>(cast<OpTy>(op).getGlobal());
201 }
202
203 Operation::operand_range getSrcIndices(Operation *op) const {
204 return cast<OpTy>(op).getGlobalIndices();
205 }
206
207 TypedValue<MemRefType> getDst(Operation *op) const {
208 return cast<TypedValue<MemRefType>>(cast<OpTy>(op).getLds());
209 }
210
211 Operation::operand_range getDstIndices(Operation *op) const {
212 return cast<OpTy>(op).getLdsIndices();
213 }
214
215 void setMemrefsAndIndices(Operation *op, RewriterBase &rewriter, Value newSrc,
216 ValueRange newSrcIndices, Value newDst,
217 ValueRange newDstIndices) const {
218 auto copyOp = cast<OpTy>(op);
219 rewriter.modifyOpInPlace(copyOp, [&]() {
220 copyOp.getGlobalMutable().assign(newSrc);
221 copyOp.getGlobalIndicesMutable().assign(newSrcIndices);
222 copyOp.getLdsMutable().assign(newDst);
223 copyOp.getLdsIndicesMutable().assign(newDstIndices);
224 });
225 }
226
227 bool hasInboundsSrcIndices(Operation *) const { return true; }
228
229 bool hasInboundsDstIndices(Operation *) const { return true; }
230};
231} // namespace
232
234 DialectRegistry &registry) {
235 registry.addExtension(+[](MLIRContext *ctx, amdgpu::AMDGPUDialect *) {
236 TransposeLoadOp::attachInterface<TransposeLoadAccess<TransposeLoadOp>>(
237 *ctx);
238 GlobalTransposeLoadOp::attachInterface<
239 TransposeLoadAccess<GlobalTransposeLoadOp>>(*ctx);
240 MakeDmaDescriptorOp::attachInterface<
241 DescriptorAtomicBarrierAccess<MakeDmaDescriptorOp>>(*ctx);
242 MakeGatherDmaDescriptorOp::attachInterface<
243 DescriptorAtomicBarrierAccess<MakeGatherDmaDescriptorOp>>(*ctx);
244 DsBarrierInitOp::attachInterface<BaseAndIndicesAccess<DsBarrierInitOp>>(
245 *ctx);
246 DsBarrierPollStateOp::attachInterface<
247 BaseAndIndicesAccess<DsBarrierPollStateOp>>(*ctx);
248 DsAsyncBarrierArriveOp::attachInterface<
249 BaseAndIndicesAccess<DsAsyncBarrierArriveOp>>(*ctx);
250 DsBarrierArriveOp::attachInterface<BaseAndIndicesAccess<DsBarrierArriveOp>>(
251 *ctx);
252 GatherToLDSOp::attachInterface<GatherToLDSCopy>(*ctx);
253 GlobalLoadAsyncToLDSOp::attachInterface<GlobalLoadAsyncToLDSCopy>(*ctx);
254 MakeDmaBaseOp::attachInterface<DmaBaseCopy<MakeDmaBaseOp>>(*ctx);
255 MakeGatherDmaBaseOp::attachInterface<DmaBaseCopy<MakeGatherDmaBaseOp>>(
256 *ctx);
257 });
258}
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
OperandRange operand_range
Definition Operation.h:396
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void registerMemoryAccessOpInterfacesExternalModels(DialectRegistry &registry)
bool isFatRawBufferMemorySpace(Attribute memorySpace)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494