MLIR 23.0.0git
MemorySlot.cpp
Go to the documentation of this file.
1//===- MemorySlot.cpp - Memory Slot interface implementations for SCF -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11
12using namespace mlir;
13using namespace mlir::scf;
14
15//===----------------------------------------------------------------------===//
16// ExecuteRegionOp
17//===----------------------------------------------------------------------===//
18
19bool ExecuteRegionOp::isRegionPromotable(const MemorySlot &slot, Region *region,
20 bool hasValueStores) {
21 return true;
22}
23
24void ExecuteRegionOp::setupPromotion(
25 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
26 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
27 regionsToProcess.insert({&getRegion(), reachingDef});
28}
29
30Value ExecuteRegionOp::finalizePromotion(
31 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
32 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
33 OpBuilder &builder) {
34 if (!hasValueStores)
35 return reachingDef;
36
37 // Update the yield terminators to return the newly defined reaching
38 // definition.
39 for (Block &block : getRegion().getBlocks())
40 if (isa<YieldOp>(block.getTerminator()))
41 memoryslot::updateTerminator(&block, reachingDef, reachingAtBlockEnd);
42
43 SmallVector<Type> resultTypes(getResultTypes());
44 resultTypes.push_back(slot.elemType);
45
46 IRRewriter rewriter(builder);
47 Operation *newOp =
48 memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
49 return newOp->getResults().back();
50}
51
52//===----------------------------------------------------------------------===//
53// ForOp
54//===----------------------------------------------------------------------===//
55
56bool ForOp::isRegionPromotable(const MemorySlot &slot, Region *region,
57 bool hasValueStores) {
58 return true;
59}
60
61void ForOp::setupPromotion(
62 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
63 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
64 Region &bodyRegion = getBodyRegion();
65 if (!hasValueStores) {
66 regionsToProcess.insert({&bodyRegion, reachingDef});
67 return;
68 }
69
70 getInitArgsMutable().append(reachingDef);
71 bodyRegion.addArgument(slot.elemType, slot.ptr.getLoc());
72 regionsToProcess.insert({&bodyRegion, bodyRegion.getArguments().back()});
73}
74
75Value ForOp::finalizePromotion(
76 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
77 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
78 OpBuilder &builder) {
79 if (!hasValueStores)
80 return reachingDef;
81
82 // Update the yield terminator to return the newly defined reaching
83 // definition.
84 memoryslot::updateTerminator(getBody(), reachingDef, reachingAtBlockEnd);
85
86 SmallVector<Type> resultTypes(getResultTypes());
87 resultTypes.push_back(slot.elemType);
88
89 IRRewriter rewriter(builder);
90 Operation *newOp =
91 memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
92 return newOp->getResults().back();
93}
94
95//===----------------------------------------------------------------------===//
96// ForallOp
97//===----------------------------------------------------------------------===//
98
99bool ForallOp::isRegionPromotable(const MemorySlot &slot, Region *region,
100 bool hasValueStores) {
101 // The ForallOp body can be ran in parallel, thus does not support sequenced
102 // value passing. Therefore only loads can be handled.
103 return !hasValueStores;
104}
105
106void ForallOp::setupPromotion(
107 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
108 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
109 assert(!hasValueStores && "ForallOp does not support stores");
110 regionsToProcess.insert({&getBodyRegion(), reachingDef});
111}
112
113Value ForallOp::finalizePromotion(
114 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
115 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
116 OpBuilder &builder) {
117 assert(!hasValueStores && "ForallOp does not support stores");
118 return reachingDef;
119}
120
121//===----------------------------------------------------------------------===//
122// IfOp
123//===----------------------------------------------------------------------===//
124
125bool IfOp::isRegionPromotable(const MemorySlot &slot, Region *region,
126 bool hasValueStores) {
127 return true;
128}
129
130void IfOp::setupPromotion(
131 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
132 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
133 regionsToProcess.insert({&getThenRegion(), reachingDef});
134 regionsToProcess.insert({&getElseRegion(), reachingDef});
135}
136
137Value IfOp::finalizePromotion(
138 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
139 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
140 OpBuilder &builder) {
141 if (!hasValueStores)
142 return reachingDef;
143
144 IRRewriter rewriter(builder);
145
146 // Update the yield terminators to return the newly defined reaching
147 // definition.
148 memoryslot::updateTerminator(&getThenRegion().back(), reachingDef,
149 reachingAtBlockEnd);
150 if (getElseRegion().hasOneBlock()) {
151 memoryslot::updateTerminator(&getElseRegion().back(), reachingDef,
152 reachingAtBlockEnd);
153 } else {
154 OpBuilder::InsertionGuard guard(rewriter);
155 rewriter.createBlock(&getElseRegion());
156 YieldOp::create(rewriter, getOperation()->getLoc(), reachingDef);
157 }
158
159 SmallVector<Type> resultTypes(getResultTypes());
160 resultTypes.push_back(slot.elemType);
161
162 Operation *newOp =
163 memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
164 return newOp->getResults().back();
165}
166
167//===----------------------------------------------------------------------===//
168// IndexSwitchOp
169//===----------------------------------------------------------------------===//
170
171bool IndexSwitchOp::isRegionPromotable(const MemorySlot &slot, Region *region,
172 bool hasValueStores) {
173 return true;
174}
175
176void IndexSwitchOp::setupPromotion(
177 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
178 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
179 regionsToProcess.insert({&getDefaultRegion(), reachingDef});
180 for (Region &caseRegion : getCaseRegions())
181 regionsToProcess.insert({&caseRegion, reachingDef});
182}
183
184Value IndexSwitchOp::finalizePromotion(
185 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
186 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
187 OpBuilder &builder) {
188 if (!hasValueStores)
189 return reachingDef;
190
191 IRRewriter rewriter(builder);
192
193 // Update the yield terminators to return the newly defined reaching
194 // definition.
195 memoryslot::updateTerminator(&getDefaultRegion().back(), reachingDef,
196 reachingAtBlockEnd);
197 for (Region &caseRegion : getCaseRegions())
198 memoryslot::updateTerminator(&caseRegion.back(), reachingDef,
199 reachingAtBlockEnd);
200
201 SmallVector<Type> resultTypes(getResultTypes());
202 resultTypes.push_back(slot.elemType);
203
204 Operation *newOp =
205 memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
206 return newOp->getResults().back();
207}
208
209//===----------------------------------------------------------------------===//
210// ParallelOp
211//===----------------------------------------------------------------------===//
212
213bool ParallelOp::isRegionPromotable(const MemorySlot &slot, Region *region,
214 bool hasValueStores) {
215 // The ParallelOp body can be ran in parallel, thus does not support sequenced
216 // value passing. Therefore only loads can be handled.
217 return !hasValueStores;
218}
219
220void ParallelOp::setupPromotion(
221 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
222 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
223 assert(!hasValueStores && "ParallelOp does not support stores");
224 regionsToProcess.insert({&getBodyRegion(), reachingDef});
225}
226
227Value ParallelOp::finalizePromotion(
228 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
229 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
230 OpBuilder &builder) {
231 assert(!hasValueStores && "ParallelOp does not support stores");
232 return reachingDef;
233}
234
235//===----------------------------------------------------------------------===//
236// ReduceOp
237//===----------------------------------------------------------------------===//
238
239bool ReduceOp::isRegionPromotable(const MemorySlot &slot, Region *region,
240 bool hasValueStores) {
241 // The ReduceOp body can be ran in parallel, thus does not support sequenced
242 // value passing. Therefore only loads can be handled.
243 return !hasValueStores;
244}
245
246void ReduceOp::setupPromotion(
247 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
248 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
249 assert(!hasValueStores && "ReduceOp does not support stores");
250 for (Region &reduction : getReductions())
251 regionsToProcess.insert({&reduction, reachingDef});
252}
253
254Value ReduceOp::finalizePromotion(
255 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
256 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
257 OpBuilder &builder) {
258 assert(!hasValueStores && "ReduceOp does not support stores");
259 return reachingDef;
260}
261
262//===----------------------------------------------------------------------===//
263// WhileOp
264//===----------------------------------------------------------------------===//
265
266bool WhileOp::isRegionPromotable(const MemorySlot &slot, Region *region,
267 bool hasValueStores) {
268 return true;
269}
270
271void WhileOp::setupPromotion(
272 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
273 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
274 Region &beforeRegion = getBefore();
275 Region &afterRegion = getAfter();
276 if (!hasValueStores) {
277 regionsToProcess.insert({&beforeRegion, reachingDef});
278 regionsToProcess.insert({&afterRegion, reachingDef});
279 return;
280 }
281
282 getInitsMutable().append(reachingDef);
283
284 beforeRegion.addArgument(slot.elemType, slot.ptr.getLoc());
285 regionsToProcess.insert({&beforeRegion, beforeRegion.getArguments().back()});
286
287 afterRegion.addArgument(slot.elemType, slot.ptr.getLoc());
288 regionsToProcess.insert({&afterRegion, afterRegion.getArguments().back()});
289}
290
291Value WhileOp::finalizePromotion(
292 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
293 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
294 OpBuilder &builder) {
295 if (!hasValueStores)
296 return reachingDef;
297
298 // Update the yield terminators to return the newly defined reaching
299 // definition.
300 memoryslot::updateTerminator(&getBefore().back(),
301 getBefore().getArguments().back(),
302 reachingAtBlockEnd);
304 &getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd);
305
306 SmallVector<Type> resultTypes(getResultTypes());
307 resultTypes.push_back(slot.elemType);
308
309 IRRewriter rewriter(builder);
310 Operation *newOp =
311 memoryslot::replaceWithNewResults(rewriter, getOperation(), resultTypes);
312 return newOp->getResults().back();
313}
Block represents an ordered list of Operations.
Definition Block.h:33
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep track of the mutations made to the IR.
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
result_range getResults()
Definition Operation.h:441
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockArgListType getArguments()
Definition Region.h:81
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Region.h:98
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * replaceWithNewResults(RewriterBase &rewriter, Operation *op, TypeRange resultTypes)
Creates a shallow copy of an operation with new result types, moving the regions out of the original ...
void updateTerminator(Block *block, Value defaultReachingDef, const DenseMap< Block *, Value > &reachingAtBlockEnd)
Appends the reaching definition for the given block as an operand to its terminator.
Include the generated interface declarations.
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.