MLIR 23.0.0git
MemorySlot.cpp
Go to the documentation of this file.
1//===- MemorySlot.cpp - Memory Slot interface implementations for SCF -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11
12using namespace mlir;
13using namespace mlir::scf;
14
15//===----------------------------------------------------------------------===//
16// Helper functions
17//===----------------------------------------------------------------------===//
18
19/// Adds the corresponding reaching definition to the terminator of the block if
20/// the terminator is of the provided type.
21template <typename TermTy>
22static void
23updateTerminator(Block *block, Value defaultReachingDef,
24 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd) {
25 Operation *terminator = block->getTerminator();
26 if (!isa<TermTy>(terminator))
27 return;
28 Value blockReachingDef = reachingAtBlockEnd.lookup(block);
29 if (!blockReachingDef) {
30 // Block is dead code or the region is not using the slot, so we use the
31 // default provided reaching definition.
32 blockReachingDef = defaultReachingDef;
33 }
34 terminator->insertOperands(terminator->getNumOperands(), {blockReachingDef});
35}
36
37/// Creates a shallow copy of an operation with new result types, moving the
38/// regions out of the original operation and deleting the original operation.
40 TypeRange resultTypes) {
41 RewriterBase::InsertionGuard guard(rewriter);
42 rewriter.setInsertionPoint(op);
43 Operation *newOp =
44 mlir::cloneWithoutRegions(rewriter, op, resultTypes, op->getOperands());
45 rewriter.startOpModification(newOp);
46 rewriter.startOpModification(op);
47 for (unsigned int i : llvm::seq(op->getNumRegions()))
48 newOp->getRegion(i).takeBody(op->getRegion(i));
49 rewriter.finalizeOpModification(op);
50 rewriter.finalizeOpModification(newOp);
51
52 SmallVector<Value> replacementValues(newOp->getResults().drop_back());
53 rewriter.replaceAllOpUsesWith(op, replacementValues);
54 rewriter.eraseOp(op);
55 return newOp;
56}
57
58//===----------------------------------------------------------------------===//
59// ExecuteRegionOp
60//===----------------------------------------------------------------------===//
61
62bool ExecuteRegionOp::isRegionPromotable(const MemorySlot &slot, Region *region,
63 bool hasValueStores) {
64 return true;
65}
66
67void ExecuteRegionOp::setupPromotion(
68 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
69 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
70 regionsToProcess.insert({&getRegion(), reachingDef});
71}
72
73Value ExecuteRegionOp::finalizePromotion(
74 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
75 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
76 OpBuilder &builder) {
77 if (!hasValueStores)
78 return reachingDef;
79
80 // Update the yield terminators to return the newly defined reaching
81 // definition.
82 for (Block &block : getRegion().getBlocks())
83 updateTerminator<YieldOp>(&block, reachingDef, reachingAtBlockEnd);
84
85 SmallVector<Type> resultTypes(getResultTypes());
86 resultTypes.push_back(slot.elemType);
87
88 IRRewriter rewriter(builder);
89 Operation *newOp =
90 replaceWithNewResults(rewriter, getOperation(), resultTypes);
91 return newOp->getResults().back();
92}
93
94//===----------------------------------------------------------------------===//
95// ForOp
96//===----------------------------------------------------------------------===//
97
98bool ForOp::isRegionPromotable(const MemorySlot &slot, Region *region,
99 bool hasValueStores) {
100 return true;
101}
102
103void ForOp::setupPromotion(
104 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
105 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
106 Region &bodyRegion = getBodyRegion();
107 if (!hasValueStores) {
108 regionsToProcess.insert({&bodyRegion, reachingDef});
109 return;
110 }
111
112 getInitArgsMutable().append(reachingDef);
113 bodyRegion.addArgument(slot.elemType, slot.ptr.getLoc());
114 regionsToProcess.insert({&bodyRegion, bodyRegion.getArguments().back()});
115}
116
117Value ForOp::finalizePromotion(
118 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
119 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
120 OpBuilder &builder) {
121 if (!hasValueStores)
122 return reachingDef;
123
124 // Update the yield terminator to return the newly defined reaching
125 // definition.
126 updateTerminator<YieldOp>(getBody(), reachingDef, reachingAtBlockEnd);
127
128 SmallVector<Type> resultTypes(getResultTypes());
129 resultTypes.push_back(slot.elemType);
130
131 IRRewriter rewriter(builder);
132 Operation *newOp =
133 replaceWithNewResults(rewriter, getOperation(), resultTypes);
134 return newOp->getResults().back();
135}
136
137//===----------------------------------------------------------------------===//
138// ForallOp
139//===----------------------------------------------------------------------===//
140
141bool ForallOp::isRegionPromotable(const MemorySlot &slot, Region *region,
142 bool hasValueStores) {
143 // The ForallOp body can be ran in parallel, thus does not support sequenced
144 // value passing. Therefore only loads can be handled.
145 return !hasValueStores;
146}
147
148void ForallOp::setupPromotion(
149 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
150 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
151 assert(!hasValueStores && "ForallOp does not support stores");
152 regionsToProcess.insert({&getBodyRegion(), reachingDef});
153}
154
155Value ForallOp::finalizePromotion(
156 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
157 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
158 OpBuilder &builder) {
159 assert(!hasValueStores && "ForallOp does not support stores");
160 return reachingDef;
161}
162
163//===----------------------------------------------------------------------===//
164// IfOp
165//===----------------------------------------------------------------------===//
166
167bool IfOp::isRegionPromotable(const MemorySlot &slot, Region *region,
168 bool hasValueStores) {
169 return true;
170}
171
172void IfOp::setupPromotion(
173 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
174 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
175 regionsToProcess.insert({&getThenRegion(), reachingDef});
176 regionsToProcess.insert({&getElseRegion(), reachingDef});
177}
178
179Value IfOp::finalizePromotion(
180 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
181 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
182 OpBuilder &builder) {
183 if (!hasValueStores)
184 return reachingDef;
185
186 IRRewriter rewriter(builder);
187
188 // Update the yield terminators to return the newly defined reaching
189 // definition.
190 updateTerminator<YieldOp>(&getThenRegion().back(), reachingDef,
191 reachingAtBlockEnd);
192 if (getElseRegion().hasOneBlock()) {
193 updateTerminator<YieldOp>(&getElseRegion().back(), reachingDef,
194 reachingAtBlockEnd);
195 } else {
196 OpBuilder::InsertionGuard guard(rewriter);
197 rewriter.createBlock(&getElseRegion());
198 YieldOp::create(rewriter, getOperation()->getLoc(), reachingDef);
199 }
200
201 SmallVector<Type> resultTypes(getResultTypes());
202 resultTypes.push_back(slot.elemType);
203
204 Operation *newOp =
205 replaceWithNewResults(rewriter, getOperation(), resultTypes);
206 return newOp->getResults().back();
207}
208
209//===----------------------------------------------------------------------===//
210// IndexSwitchOp
211//===----------------------------------------------------------------------===//
212
213bool IndexSwitchOp::isRegionPromotable(const MemorySlot &slot, Region *region,
214 bool hasValueStores) {
215 return true;
216}
217
218void IndexSwitchOp::setupPromotion(
219 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
220 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
221 regionsToProcess.insert({&getDefaultRegion(), reachingDef});
222 for (Region &caseRegion : getCaseRegions())
223 regionsToProcess.insert({&caseRegion, reachingDef});
224}
225
226Value IndexSwitchOp::finalizePromotion(
227 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
228 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
229 OpBuilder &builder) {
230 if (!hasValueStores)
231 return reachingDef;
232
233 IRRewriter rewriter(builder);
234
235 // Update the yield terminators to return the newly defined reaching
236 // definition.
237 updateTerminator<YieldOp>(&getDefaultRegion().back(), reachingDef,
238 reachingAtBlockEnd);
239 for (Region &caseRegion : getCaseRegions())
240 updateTerminator<YieldOp>(&caseRegion.back(), reachingDef,
241 reachingAtBlockEnd);
242
243 SmallVector<Type> resultTypes(getResultTypes());
244 resultTypes.push_back(slot.elemType);
245
246 Operation *newOp =
247 replaceWithNewResults(rewriter, getOperation(), resultTypes);
248 return newOp->getResults().back();
249}
250
251//===----------------------------------------------------------------------===//
252// ParallelOp
253//===----------------------------------------------------------------------===//
254
255bool ParallelOp::isRegionPromotable(const MemorySlot &slot, Region *region,
256 bool hasValueStores) {
257 // The ParallelOp body can be ran in parallel, thus does not support sequenced
258 // value passing. Therefore only loads can be handled.
259 return !hasValueStores;
260}
261
262void ParallelOp::setupPromotion(
263 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
264 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
265 assert(!hasValueStores && "ParallelOp does not support stores");
266 regionsToProcess.insert({&getBodyRegion(), reachingDef});
267}
268
269Value ParallelOp::finalizePromotion(
270 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
271 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
272 OpBuilder &builder) {
273 assert(!hasValueStores && "ParallelOp does not support stores");
274 return reachingDef;
275}
276
277//===----------------------------------------------------------------------===//
278// ReduceOp
279//===----------------------------------------------------------------------===//
280
281bool ReduceOp::isRegionPromotable(const MemorySlot &slot, Region *region,
282 bool hasValueStores) {
283 // The ReduceOp body can be ran in parallel, thus does not support sequenced
284 // value passing. Therefore only loads can be handled.
285 return !hasValueStores;
286}
287
288void ReduceOp::setupPromotion(
289 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
290 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
291 assert(!hasValueStores && "ReduceOp does not support stores");
292 for (Region &reduction : getReductions())
293 regionsToProcess.insert({&reduction, reachingDef});
294}
295
296Value ReduceOp::finalizePromotion(
297 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
298 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
299 OpBuilder &builder) {
300 assert(!hasValueStores && "ReduceOp does not support stores");
301 return reachingDef;
302}
303
304//===----------------------------------------------------------------------===//
305// WhileOp
306//===----------------------------------------------------------------------===//
307
308bool WhileOp::isRegionPromotable(const MemorySlot &slot, Region *region,
309 bool hasValueStores) {
310 return true;
311}
312
313void WhileOp::setupPromotion(
314 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
315 llvm::SmallMapVector<Region *, Value, 2> &regionsToProcess) {
316 Region &beforeRegion = getBefore();
317 Region &afterRegion = getAfter();
318 if (!hasValueStores) {
319 regionsToProcess.insert({&beforeRegion, reachingDef});
320 regionsToProcess.insert({&afterRegion, reachingDef});
321 return;
322 }
323
324 getInitsMutable().append(reachingDef);
325
326 beforeRegion.addArgument(slot.elemType, slot.ptr.getLoc());
327 regionsToProcess.insert({&beforeRegion, beforeRegion.getArguments().back()});
328
329 afterRegion.addArgument(slot.elemType, slot.ptr.getLoc());
330 regionsToProcess.insert({&afterRegion, afterRegion.getArguments().back()});
331}
332
333Value WhileOp::finalizePromotion(
334 const MemorySlot &slot, Value reachingDef, bool hasValueStores,
335 const llvm::DenseMap<Block *, Value> &reachingAtBlockEnd,
336 OpBuilder &builder) {
337 if (!hasValueStores)
338 return reachingDef;
339
340 // Update the yield terminators to return the newly defined reaching
341 // definition.
342 updateTerminator<ConditionOp>(&getBefore().back(),
343 getBefore().getArguments().back(),
344 reachingAtBlockEnd);
346 &getAfter().back(), getAfter().getArguments().back(), reachingAtBlockEnd);
347
348 SmallVector<Type> resultTypes(getResultTypes());
349 resultTypes.push_back(slot.elemType);
350
351 IRRewriter rewriter(builder);
352 Operation *newOp =
353 replaceWithNewResults(rewriter, getOperation(), resultTypes);
354 return newOp->getResults().back();
355}
static void updateTerminator(Block *block, Value defaultReachingDef, const llvm::DenseMap< Block *, Value > &reachingAtBlockEnd)
Adds the corresponding reaching definition to the terminator of the block if the terminator is of the provided type.
static Operation * replaceWithNewResults(RewriterBase &rewriter, Operation *op, TypeRange resultTypes)
Creates a shallow copy of an operation with new result types, moving the regions out of the original operation and deleting the original operation.
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:715
void insertOperands(unsigned index, ValueRange operands)
Insert the given operands into the operand list at the given 'index'.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:703
unsigned getNumOperands()
Definition Operation.h:375
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockArgListType getArguments()
Definition Region.h:81
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition Region.h:252
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Region.h:98
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Include the generated interface declarations.
Operation * cloneWithoutRegions(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
Represents a slot in memory.
Value ptr
Pointer to the memory slot, used by operations to refer to it.
Type elemType
Type of the value contained in the slot.