MLIR 22.0.0git
PipelineDataTransfer.cpp
Go to the documentation of this file.
1//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to pipeline data transfers.
10//
11//===----------------------------------------------------------------------===//
12
14
23#include "mlir/IR/Builders.h"
24#include "llvm/ADT/DenseMap.h"
25#include "llvm/Support/Debug.h"
26
27namespace mlir {
28namespace affine {
29#define GEN_PASS_DEF_AFFINEPIPELINEDATATRANSFER
30#include "mlir/Dialect/Affine/Passes.h.inc"
31} // namespace affine
32} // namespace mlir
33
34#define DEBUG_TYPE "affine-pipeline-data-transfer"
35
36using namespace mlir;
37using namespace mlir::affine;
38
39namespace {
40struct PipelineDataTransfer
42 PipelineDataTransfer> {
43 void runOnOperation() override;
44 void runOnAffineForOp(AffineForOp forOp);
45
46 std::vector<AffineForOp> forOps;
47};
48
49} // namespace
50
51/// Creates a pass to pipeline explicit movement of data across levels of the
52/// memory hierarchy.
53std::unique_ptr<OperationPass<func::FuncOp>>
55 return std::make_unique<PipelineDataTransfer>();
56}
57
58// Returns the position of the tag memref operand given a DMA operation.
59// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
60// added.
61static unsigned getTagMemRefPos(Operation &dmaOp) {
62 assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
63 if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
64 return dmaStartOp.getTagMemRefOperandIndex();
65 }
66 // First operand for a dma finish operation.
67 return 0;
68}
69
70/// Doubles the buffer of the supplied memref on the specified 'affine.for'
71/// operation by adding a leading dimension of size two to the memref.
72/// Replaces all uses of the old memref by the new one while indexing the newly
73/// added dimension by the loop IV of the specified 'affine.for' operation
74/// modulo 2. Returns false if such a replacement cannot be performed.
75static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
76 auto *forBody = forOp.getBody();
77 OpBuilder bInner(forBody, forBody->begin());
78
79 // Doubles the shape with a leading dimension extent of 2.
80 auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
81 // Add the leading dimension in the shape for the double buffer.
82 ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
83 SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
84 newShape[0] = 2;
85 llvm::copy(oldShape, newShape.begin() + 1);
86 return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
87 };
88
89 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
90 auto newMemRefType = doubleShape(oldMemRefType);
91
92 // The double buffer is allocated right before 'forOp'.
93 OpBuilder bOuter(forOp);
94 // Put together alloc operands for any dynamic dimensions of the memref.
95 SmallVector<Value, 4> allocOperands;
96 for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
97 if (dim.value() == ShapedType::kDynamic)
98 allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
99 forOp.getLoc(), oldMemRef, dim.index()));
100 }
101
102 // Create and place the alloc right before the 'affine.for' operation.
103 Value newMemRef = memref::AllocOp::create(bOuter, forOp.getLoc(),
104 newMemRefType, allocOperands);
105
106 // Create 'iv mod 2' value to index the leading dimension.
107 auto d0 = bInner.getAffineDimExpr(0);
108 int64_t step = forOp.getStepAsInt();
109 auto modTwoMap =
110 AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
111 auto ivModTwoOp = AffineApplyOp::create(bInner, forOp.getLoc(), modTwoMap,
112 forOp.getInductionVar());
113
114 // replaceAllMemRefUsesWith will succeed unless the forOp body has
115 // non-dereferencing uses of the memref (dealloc's are fine though).
116 auto userFilterFn = [&](Operation *user) {
117 auto domInfo = std::make_unique<DominanceInfo>(
118 forOp->getParentOfType<FunctionOpInterface>());
119 return domInfo->dominates(&*forOp.getBody()->begin(), user);
120 };
121 if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef,
122 /*extraIndices=*/{ivModTwoOp},
123 /*indexRemap=*/AffineMap(),
124 /*extraOperands=*/{},
125 /*symbolOperands=*/{}, userFilterFn))) {
126 LLVM_DEBUG(
127 forOp.emitError("memref replacement for double buffering failed"));
128 ivModTwoOp.erase();
129 return false;
130 }
131 // Insert the dealloc op right after the for loop.
132 bOuter.setInsertionPointAfter(forOp);
133 memref::DeallocOp::create(bOuter, forOp.getLoc(), newMemRef);
134
135 return true;
136}
137
138/// Returns success if the IR is in a valid state.
139void PipelineDataTransfer::runOnOperation() {
140 // Do a post order walk so that inner loop DMAs are processed first. This is
141 // necessary since 'affine.for' operations nested within would otherwise
142 // become invalid (erased) when the outer loop is pipelined (the pipelined one
143 // gets deleted and replaced by a prologue, a new steady-state loop and an
144 // epilogue).
145 forOps.clear();
146 getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
147 for (auto forOp : forOps)
148 runOnAffineForOp(forOp);
149}
150
151// Check if tags of the dma start op and dma wait op match.
152static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
153 if (startOp.getTagMemRef() != waitOp.getTagMemRef())
154 return false;
155 auto startIndices = startOp.getTagIndices();
156 auto waitIndices = waitOp.getTagIndices();
157 // Both of these have the same number of indices since they correspond to the
158 // same tag memref.
159 for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
160 e = startIndices.end();
161 it != e; ++it, ++wIt) {
162 // Keep it simple for now, just checking if indices match.
163 // TODO: this would in general need to check if there is no
164 // intervening write writing to the same tag location, i.e., memory last
165 // write/data flow analysis. This is however sufficient/powerful enough for
166 // now since the DMA generation pass or the input for it will always have
167 // start/wait with matching tags (same SSA operand indices).
168 if (*it != *wIt)
169 return false;
170 }
171 return true;
172}
173
174// Identify matching DMA start/finish operations to overlap computation with.
176 AffineForOp forOp,
177 SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
178
179 // Collect outgoing DMA operations - needed to check for dependences below.
181 for (auto &op : *forOp.getBody()) {
182 auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
183 if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
184 outgoingDmaOps.push_back(dmaStartOp);
185 }
186
187 SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
188 for (auto &op : *forOp.getBody()) {
189 // Collect DMA finish operations.
190 if (isa<AffineDmaWaitOp>(op)) {
191 dmaFinishInsts.push_back(&op);
192 continue;
193 }
194 auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
195 if (!dmaStartOp)
196 continue;
197
198 // Only DMAs incoming into higher memory spaces are pipelined for now.
199 // TODO: handle outgoing DMA pipelining.
200 if (!dmaStartOp.isDestMemorySpaceFaster())
201 continue;
202
203 // Check for dependence with outgoing DMAs. Doing this conservatively.
204 // TODO: use the dependence analysis to check for
205 // dependences between an incoming and outgoing DMA in the same iteration.
206 auto *it = outgoingDmaOps.begin();
207 for (; it != outgoingDmaOps.end(); ++it) {
208 if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
209 break;
210 }
211 if (it != outgoingDmaOps.end())
212 continue;
213
214 // We only double buffer if the buffer is not live out of loop.
215 auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
216 bool escapingUses = false;
217 for (auto *user : memref.getUsers()) {
218 // We can double buffer regardless of dealloc's outside the loop.
219 if (isa<memref::DeallocOp>(user))
220 continue;
221 if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
222 LLVM_DEBUG(llvm::dbgs()
223 << "can't pipeline: buffer is live out of loop\n";);
224 escapingUses = true;
225 break;
226 }
227 }
228 if (!escapingUses)
229 dmaStartInsts.push_back(&op);
230 }
231
232 // For each start operation, we look for a matching finish operation.
233 for (auto *dmaStartOp : dmaStartInsts) {
234 for (auto *dmaFinishOp : dmaFinishInsts) {
235 if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
236 cast<AffineDmaWaitOp>(dmaFinishOp))) {
237 startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
238 break;
239 }
240 }
241 }
242}
243
244/// Overlap DMA transfers with computation in this loop. If successful,
245/// 'forOp' is deleted, and a prologue, a new pipelined loop, and epilogue are
246/// inserted right before where it was.
247void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
248 auto mayBeConstTripCount = getConstantTripCount(forOp);
249 if (!mayBeConstTripCount) {
250 LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
251 return;
252 }
253
254 SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
255 findMatchingStartFinishInsts(forOp, startWaitPairs);
256
257 if (startWaitPairs.empty()) {
258 LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
259 return;
260 }
261
262 // Double the buffers for the higher memory space memref's.
263 // Identify memref's to replace by scanning through all DMA start
264 // operations. A DMA start operation has two memref's - the one from the
265 // higher level of memory hierarchy is the one to double buffer.
266 // TODO: check whether double-buffering is even necessary.
267 // TODO: make this work with different layouts: assuming here that
268 // the dimension we are adding here for the double buffering is the outermost
269 // dimension.
270 for (auto &pair : startWaitPairs) {
271 auto *dmaStartOp = pair.first;
272 Value oldMemRef = dmaStartOp->getOperand(
273 cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
274 if (!doubleBuffer(oldMemRef, forOp)) {
275 // Normally, double buffering should not fail because we already checked
276 // that there are no uses outside.
277 LLVM_DEBUG(llvm::dbgs()
278 << "double buffering failed for" << dmaStartOp << "\n";);
279 // IR still valid and semantically correct.
280 return;
281 }
282 // If the old memref has no more uses, remove its 'dead' alloc if it was
283 // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
284 // operation could have been used on it if it was dynamically shaped in
285 // order to create the double buffer above.)
286 // '-canonicalize' does this in a more general way, but we'll anyway do the
287 // simple/common case so that the output / test cases looks clear.
288 if (auto *allocOp = oldMemRef.getDefiningOp()) {
289 if (oldMemRef.use_empty()) {
290 allocOp->erase();
291 } else if (oldMemRef.hasOneUse()) {
292 if (auto dealloc =
293 dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
294 dealloc.erase();
295 allocOp->erase();
296 }
297 }
298 }
299 }
300
301 // Double the buffers for tag memrefs.
302 for (auto &pair : startWaitPairs) {
303 auto *dmaFinishOp = pair.second;
304 Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
305 if (!doubleBuffer(oldTagMemRef, forOp)) {
306 LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
307 return;
308 }
309 // If the old tag has no uses or a single dealloc use, remove it.
310 // (canonicalization handles more complex cases).
311 if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
312 if (oldTagMemRef.use_empty()) {
313 tagAllocOp->erase();
314 } else if (oldTagMemRef.hasOneUse()) {
315 if (auto dealloc =
316 dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
317 dealloc.erase();
318 tagAllocOp->erase();
319 }
320 }
321 }
322 }
323
324 // Double buffering would have invalidated all the old DMA start/wait insts.
325 startWaitPairs.clear();
326 findMatchingStartFinishInsts(forOp, startWaitPairs);
327
328 // Store shift for operation for later lookup for AffineApplyOp's.
330 for (auto &pair : startWaitPairs) {
331 auto *dmaStartOp = pair.first;
332 assert(isa<AffineDmaStartOp>(dmaStartOp));
333 instShiftMap[dmaStartOp] = 0;
334 // Set shifts for DMA start op's affine operand computation slices to 0.
335 SmallVector<AffineApplyOp, 4> sliceOps;
336 affine::createAffineComputationSlice(dmaStartOp, &sliceOps);
337 if (!sliceOps.empty()) {
338 for (auto sliceOp : sliceOps) {
339 instShiftMap[sliceOp.getOperation()] = 0;
340 }
341 } else {
342 // If a slice wasn't created, the reachable affine.apply op's from its
343 // operands are the ones that go with it.
344 SmallVector<Operation *, 4> affineApplyInsts;
345 SmallVector<Value, 4> operands(dmaStartOp->getOperands());
346 getReachableAffineApplyOps(operands, affineApplyInsts);
347 for (auto *op : affineApplyInsts) {
348 instShiftMap[op] = 0;
349 }
350 }
351 }
352 // Everything else (including compute ops and dma finish) are shifted by one.
353 for (auto &op : forOp.getBody()->without_terminator())
354 instShiftMap.try_emplace(&op, 1);
355
356 // Get shifts stored in map.
357 SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
358 unsigned s = 0;
359 for (auto &op : forOp.getBody()->without_terminator()) {
360 assert(instShiftMap.contains(&op));
361 shifts[s++] = instShiftMap[&op];
362
363 // Tagging operations with shifts for debugging purposes.
364 LLVM_DEBUG({
365 OpBuilder b(&op);
366 op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
367 });
368 }
369
370 if (!isOpwiseShiftValid(forOp, shifts)) {
371 // Violates dependences.
372 LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
373 return;
374 }
375
376 if (failed(affineForOpBodySkew(forOp, shifts))) {
377 LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
378 return;
379 }
380}
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Outer) or inner_dims_pos (case OuterOrInnerPerm::Inner) of the packing operation.
static unsigned getTagMemRefPos(Operation &dmaOp)
static void findMatchingStartFinishInsts(AffineForOp forOp, SmallVectorImpl< std::pair< Operation *, Operation * > > &startWaitPairs)
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp)
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp)
Doubles the buffer of the supplied memref on the specified 'affine.for' operation by adding a leading...
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:364
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
This class helps build Operations.
Definition Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition Value.h:208
Type getType() const
Return the type of this value.
Definition Value.h:105
user_iterator user_begin() const
Definition Value.h:216
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition AffineOps.h:106
Value getTagMemRef()
Returns the Tag MemRef for this DMA operation.
Definition AffineOps.h:210
operand_range getTagIndices()
Returns the tag memref indices for this DMA operation.
Definition AffineOps.h:231
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition AffineOps.h:330
Value getTagMemRef()
Returns the Tag MemRef associated with the DMA operation being waited on.
Definition AffineOps.h:347
operand_range getTagIndices()
Returns the tag memref index for this DMA operation.
Definition AffineOps.h:361
std::optional< uint64_t > getConstantTripCount(AffineForOp forOp)
Returns the trip count of the loop if it's a constant, std::nullopt otherwise.
LogicalResult affineForOpBodySkew(AffineForOp forOp, ArrayRef< uint64_t > shifts, bool unrollPrologueEpilogue=false)
Skew the operations in an affine.for's body with the specified operation-wise shifts.
void getReachableAffineApplyOps(ArrayRef< Value > operands, SmallVectorImpl< Operation * > &affineApplyOps)
Returns in affineApplyOps, the sequence of those AffineApplyOp Operations that are reachable via a se...
std::unique_ptr< OperationPass< func::FuncOp > > createPipelineDataTransferPass()
Creates a pass to pipeline explicit movement of data across levels of the memory hierarchy.
bool isOpwiseShiftValid(AffineForOp forOp, ArrayRef< uint64_t > shifts)
Checks where SSA dominance would be violated if a for op's body operations are shifted by the specifi...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126