26template <
typename OpTy>
27struct TransposeLoadAccess final
28 : IndexedAccessOpInterface::ExternalModel<TransposeLoadAccess<OpTy>, OpTy> {
30 return cast<TypedValue<MemRefType>>(cast<OpTy>(op).getSrc());
34 return cast<OpTy>(op).getSrcIndices();
37 SmallVector<int64_t> getAccessedShape(Operation *op)
const {
38 return {cast<VectorType>(cast<OpTy>(op).getResult().
getType())
42 std::optional<SmallVector<Value>>
43 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
45 auto accessOp = cast<OpTy>(op);
47 accessOp.getSrcMutable().assign(newMemref);
48 accessOp.getSrcIndicesMutable().assign(newIndices);
53 bool hasInboundsIndices(Operation *)
const {
return true; }
56template <
typename OpTy>
57struct BaseAndIndicesAccess final
58 : IndexedAccessOpInterface::ExternalModel<BaseAndIndicesAccess<OpTy>,
61 return cast<TypedValue<MemRefType>>(cast<OpTy>(op).getBase());
65 return cast<OpTy>(op).getIndices();
68 SmallVector<int64_t> getAccessedShape(Operation *)
const {
return {}; }
70 std::optional<SmallVector<Value>>
71 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
73 auto accessOp = cast<OpTy>(op);
75 accessOp.getBaseMutable().assign(newMemref);
76 accessOp.getIndicesMutable().assign(newIndices);
81 bool hasInboundsIndices(Operation *)
const {
return true; }
84template <
typename OpTy>
85struct DescriptorAtomicBarrierAccess final
86 : IndexedAccessOpInterface::ExternalModel<
87 DescriptorAtomicBarrierAccess<OpTy>, OpTy> {
89 Value memref = cast<OpTy>(op).getAtomicBarrierAddress();
92 return cast<TypedValue<MemRefType>>(memref);
96 return cast<OpTy>(op).getAtomicBarrierIndices();
99 SmallVector<int64_t> getAccessedShape(Operation *)
const {
return {}; }
101 std::optional<SmallVector<Value>>
102 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
104 auto accessOp = cast<OpTy>(op);
106 accessOp.getAtomicBarrierAddressMutable().assign(newMemref);
107 accessOp.getAtomicBarrierIndicesMutable().assign(newIndices);
112 bool hasInboundsIndices(Operation *)
const {
return true; }
115struct GatherToLDSCopy final
116 : IndexedMemCopyOpInterface::ExternalModel<GatherToLDSCopy, GatherToLDSOp> {
118 return cast<TypedValue<MemRefType>>(cast<GatherToLDSOp>(op).getSrc());
122 return cast<GatherToLDSOp>(op).getSrcIndices();
126 return cast<TypedValue<MemRefType>>(cast<GatherToLDSOp>(op).getDst());
130 return cast<GatherToLDSOp>(op).getDstIndices();
133 void setMemrefsAndIndices(Operation *op, RewriterBase &rewriter, Value newSrc,
136 auto copyOp = cast<GatherToLDSOp>(op);
138 copyOp.getSrcMutable().assign(newSrc);
139 copyOp.getSrcIndicesMutable().assign(newSrcIndices);
140 copyOp.getDstMutable().assign(newDst);
141 copyOp.getDstIndicesMutable().assign(newDstIndices);
145 bool hasInboundsSrcIndices(Operation *op)
const {
146 MemRefType srcType = cast<GatherToLDSOp>(op).getSrc().getType();
150 bool hasInboundsDstIndices(Operation *)
const {
return true; }
153struct GlobalLoadAsyncToLDSCopy final
154 : IndexedMemCopyOpInterface::ExternalModel<GlobalLoadAsyncToLDSCopy,
155 GlobalLoadAsyncToLDSOp> {
157 return cast<TypedValue<MemRefType>>(
158 cast<GlobalLoadAsyncToLDSOp>(op).getSrc());
162 return cast<GlobalLoadAsyncToLDSOp>(op).getSrcIndices();
166 return cast<TypedValue<MemRefType>>(
167 cast<GlobalLoadAsyncToLDSOp>(op).getDst());
171 return cast<GlobalLoadAsyncToLDSOp>(op).getDstIndices();
174 void setMemrefsAndIndices(Operation *op, RewriterBase &rewriter, Value newSrc,
177 auto copyOp = cast<GlobalLoadAsyncToLDSOp>(op);
179 copyOp.getSrcMutable().assign(newSrc);
180 copyOp.getSrcIndicesMutable().assign(newSrcIndices);
181 copyOp.getDstMutable().assign(newDst);
182 copyOp.getDstIndicesMutable().assign(newDstIndices);
186 bool hasInboundsSrcIndices(Operation *)
const {
return true; }
188 bool hasInboundsDstIndices(Operation *op)
const {
192 return !cast<GlobalLoadAsyncToLDSOp>(op).getMask();
196template <
typename OpTy>
197struct DmaBaseCopy final
198 : IndexedMemCopyOpInterface::ExternalModel<DmaBaseCopy<OpTy>, OpTy> {
200 return cast<TypedValue<MemRefType>>(cast<OpTy>(op).getGlobal());
204 return cast<OpTy>(op).getGlobalIndices();
208 return cast<TypedValue<MemRefType>>(cast<OpTy>(op).getLds());
212 return cast<OpTy>(op).getLdsIndices();
215 void setMemrefsAndIndices(Operation *op, RewriterBase &rewriter, Value newSrc,
218 auto copyOp = cast<OpTy>(op);
220 copyOp.getGlobalMutable().assign(newSrc);
221 copyOp.getGlobalIndicesMutable().assign(newSrcIndices);
222 copyOp.getLdsMutable().assign(newDst);
223 copyOp.getLdsIndicesMutable().assign(newDstIndices);
227 bool hasInboundsSrcIndices(Operation *)
const {
return true; }
229 bool hasInboundsDstIndices(Operation *)
const {
return true; }
236 TransposeLoadOp::attachInterface<TransposeLoadAccess<TransposeLoadOp>>(
238 GlobalTransposeLoadOp::attachInterface<
239 TransposeLoadAccess<GlobalTransposeLoadOp>>(*ctx);
240 MakeDmaDescriptorOp::attachInterface<
241 DescriptorAtomicBarrierAccess<MakeDmaDescriptorOp>>(*ctx);
242 MakeGatherDmaDescriptorOp::attachInterface<
243 DescriptorAtomicBarrierAccess<MakeGatherDmaDescriptorOp>>(*ctx);
244 DsBarrierInitOp::attachInterface<BaseAndIndicesAccess<DsBarrierInitOp>>(
246 DsBarrierPollStateOp::attachInterface<
247 BaseAndIndicesAccess<DsBarrierPollStateOp>>(*ctx);
248 DsAsyncBarrierArriveOp::attachInterface<
249 BaseAndIndicesAccess<DsAsyncBarrierArriveOp>>(*ctx);
250 DsBarrierArriveOp::attachInterface<BaseAndIndicesAccess<DsBarrierArriveOp>>(
252 GatherToLDSOp::attachInterface<GatherToLDSCopy>(*ctx);
253 GlobalLoadAsyncToLDSOp::attachInterface<GlobalLoadAsyncToLDSCopy>(*ctx);
254 MakeDmaBaseOp::attachInterface<DmaBaseCopy<MakeDmaBaseOp>>(*ctx);
255 MakeGatherDmaBaseOp::attachInterface<DmaBaseCopy<MakeGatherDmaBaseOp>>(
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
OperandRange operand_range
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void registerMemoryAccessOpInterfacesExternalModels(DialectRegistry &registry)
bool isFatRawBufferMemorySpace(Attribute memorySpace)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.