MLIR 22.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15
16using namespace mlir;
17using namespace mlir::bufferization;
18using namespace mlir::ml_program;
19
20namespace mlir {
21namespace ml_program {
22namespace {
23
24template <typename Interface, typename Op>
25struct ExternalModelBase
26 : public BufferizableOpInterface::ExternalModel<Interface, Op> {
27
28 AliasingValueList getAliasingValues(Operation *, OpOperand &,
29 const AnalysisState &) const {
30 return {};
31 }
32
33 BufferRelation bufferRelation(Operation *, OpResult,
34 const AnalysisState &) const {
35 return BufferRelation::Unknown;
36 }
37};
38
39/// Bufferization of ml_program.global into a memref.global
40struct GlobalOpInterface
41 : public ExternalModelBase<GlobalOpInterface, GlobalOp> {
42
43 bool bufferizesToMemoryRead(Operation *, OpOperand &,
44 const AnalysisState &) const {
45 return false;
46 }
47
48 bool bufferizesToMemoryWrite(Operation *, OpOperand &,
49 const AnalysisState &) const {
50 return false;
51 }
52
53 bool hasTensorSemantics(Operation *) const { return true; }
54
55 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
56 const BufferizationOptions &,
57 BufferizationState &state) const {
58 auto globalOp = cast<GlobalOp>(op);
59 if (!globalOp.getValue().has_value())
60 return globalOp.emitError("global op must have a value");
61
62 bufferization::removeSymbol(globalOp, state);
63
64 auto tensorType = cast<TensorType>(globalOp.getType());
65 auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
66
67 auto replacement = replaceOpWithNewBufferizedOp<memref::GlobalOp>(
68 rewriter, globalOp, globalOp.getSymName(),
69 /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
70 /*type=*/cast<MemRefType>(memrefType),
71 /*initial_value=*/globalOp.getValue().value(),
72 /*constant=*/!globalOp.getIsMutable(),
73 /*alignment=*/nullptr);
74
76 return success();
77 }
78};
79
80/// Bufferization of ml_program.global_load into a memref.get_global
81struct GlobalLoadOpInterface
82 : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
83
84 bool bufferizesToMemoryRead(Operation *, OpOperand &,
85 const AnalysisState &) const {
86 return false;
87 }
88
89 bool bufferizesToMemoryWrite(Operation *, OpOperand &,
90 const AnalysisState &) const {
91 return false;
92 }
93
94 bool isWritable(Operation *, Value, const AnalysisState &) const {
95 return false;
96 }
97
98 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
99 const BufferizationOptions &,
100 BufferizationState &state) const {
101 auto globalLoadOp = cast<GlobalLoadOp>(op);
102
103 auto tensorType = cast<TensorType>(globalLoadOp.getType());
104 auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
105
106 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
107 rewriter, globalLoadOp, memrefType,
108 globalLoadOp.getGlobalAttr().getLeafReference());
109
110 return success();
111 }
112};
113
114/// Bufferization of ml_program.global_store into a memref.get_global and
115/// memcpy
116struct GlobalStoreOpInterface
117 : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
118
119 bool bufferizesToMemoryRead(Operation *, OpOperand &,
120 const AnalysisState &) const {
121 return false;
122 }
123
124 bool bufferizesToMemoryWrite(Operation *, OpOperand &,
125 const AnalysisState &) const {
126 return true;
127 }
128
129 LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
130 const BufferizationOptions &options,
131 BufferizationState &state) const {
132 auto globalStoreOp = cast<GlobalStoreOp>(op);
133
134 auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
135 auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
136
137 auto loc = globalStoreOp.getLoc();
138 auto targetMemref = memref::GetGlobalOp::create(
139 rewriter, loc, memrefType,
140 globalStoreOp.getGlobalAttr().getLeafReference());
141
142 auto sourceMemref =
143 getBuffer(rewriter, globalStoreOp.getValue(), options, state);
144 if (failed(sourceMemref)) {
145 return failure();
146 }
147
148 auto memcpy =
149 options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
150 if (failed(memcpy)) {
151 return failure();
152 }
153 rewriter.eraseOp(globalStoreOp);
154
155 return success();
156 }
157};
158} // namespace
159
161 registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
162 GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
163 GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
164 GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
165 });
166}
167} // namespace ml_program
168} // namespace mlir
return success()
Returns failure if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted. The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, …
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void insertSymbol(Operation *op, BufferizationState &state)
void removeSymbol(Operation *op, BufferizationState &state)
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.