MLIR 23.0.0git
MemoryAccessOpInterfacesImpl.cpp
Go to the documentation of this file.
1//===- MemoryAccessOpInterfacesImpl.cpp -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// Implement memref dialect interfaces that enable manipulating memref indexing
9// in passes like FoldMemRefAliasOps.
10//===----------------------------------------------------------------------===//
11
13
16#include "mlir/IR/Dialect.h"
17#include "mlir/IR/Operation.h"
19
20using namespace mlir;
21using namespace mlir::memref;
22using namespace mlir::nvgpu;
23
24namespace {
25struct LdMatrixOpInterface final
26 : IndexedAccessOpInterface::ExternalModel<LdMatrixOpInterface, LdMatrixOp> {
27 TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
28 return cast<LdMatrixOp>(op).getSrcMemref();
29 }
30
31 Operation::operand_range getIndices(Operation *op) const {
32 return cast<LdMatrixOp>(op).getIndices();
33 }
34
35 SmallVector<int64_t> getAccessedShape(Operation *op) const {
36 VectorType vecTy = cast<LdMatrixOp>(op).getRes().getType();
37 // The 2-D nature of the result is an artifact of this operation returning
38 // a struct of vectors and doesn't reflect any strides that need to be
39 // preserved.
40 return SmallVector<int64_t>{vecTy.getNumElements()};
41 }
42
43 std::optional<SmallVector<Value>>
44 updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
45 ValueRange newIndices) const {
46 auto ldMatrixOp = cast<LdMatrixOp>(op);
47 rewriter.modifyOpInPlace(ldMatrixOp, [&]() {
48 ldMatrixOp.getSrcMemrefMutable().assign(newMemref);
49 ldMatrixOp.getIndicesMutable().assign(newIndices);
50 });
51 return std::nullopt;
52 }
53
54 bool hasInboundsIndices(Operation *) const { return true; }
55};
56
57struct DeviceAsyncCopyOpInterface final
58 : IndexedMemCopyOpInterface::ExternalModel<DeviceAsyncCopyOpInterface,
59 DeviceAsyncCopyOp> {
60 TypedValue<MemRefType> getSrc(Operation *op) const {
61 return cast<DeviceAsyncCopyOp>(op).getSrc();
62 }
63
64 Operation::operand_range getSrcIndices(Operation *op) const {
65 return cast<DeviceAsyncCopyOp>(op).getSrcIndices();
66 }
67
68 TypedValue<MemRefType> getDst(Operation *op) const {
69 return cast<DeviceAsyncCopyOp>(op).getDst();
70 }
71
72 Operation::operand_range getDstIndices(Operation *op) const {
73 return cast<DeviceAsyncCopyOp>(op).getDstIndices();
74 }
75
76 void setMemrefsAndIndices(Operation *op, RewriterBase &rewriter, Value newSrc,
77 ValueRange newSrcIndices, Value newDst,
78 ValueRange newDstIndices) const {
79 auto copyOp = cast<DeviceAsyncCopyOp>(op);
80 rewriter.modifyOpInPlace(copyOp, [&]() {
81 copyOp.getSrcMutable().assign(newSrc);
82 copyOp.getSrcIndicesMutable().assign(newSrcIndices);
83 copyOp.getDstMutable().assign(newDst);
84 copyOp.getDstIndicesMutable().assign(newDstIndices);
85 });
86 }
87};
88} // namespace
89
91 DialectRegistry &registry) {
92 registry.addExtension(+[](MLIRContext *ctx, nvgpu::NVGPUDialect *dialect) {
93 LdMatrixOp::attachInterface<LdMatrixOpInterface>(*ctx);
94 DeviceAsyncCopyOp::attachInterface<DeviceAsyncCopyOpInterface>(*ctx);
95 });
96}
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
type_range getType() const
OperandRange operand_range
Definition Operation.h:397
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void registerMemoryAccessOpInterfacesExternalModels(DialectRegistry &registry)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
Include the generated interface declarations.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494