//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to pipeline data transfers.
//
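// The pass finds matching DMA start/wait pairs in an 'affine.for' body,
// double-buffers the fast memory space buffers (and the DMA tags) they use,
// and then skews the loop body so that the DMA for the next iteration is
// issued while the current iteration's compute executes.
//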
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINEPIPELINEDATATRANSFER
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "affine-pipeline-data-transfer"

using namespace mlir;
using namespace mlir::affine;

namespace {
struct PipelineDataTransfer
    : public affine::impl::AffinePipelineDataTransferBase<
          PipelineDataTransfer> {
  void runOnOperation() override;
  void runOnAffineForOp(AffineForOp forOp);

  std::vector<AffineForOp> forOps;
};

} // namespace

/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
std::unique_ptr<OperationPass<func::FuncOp>>
mlir::affine::createPipelineDataTransferPass() {
  return std::make_unique<PipelineDataTransfer>();
}
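
// A minimal usage sketch (assuming the usual pass registration, where this
// pass is exposed under the -affine-pipeline-data-transfer flag):
//   mlir-opt input.mlir -affine-pipeline-data-transfer
// or, programmatically:
//   pm.addNestedPass<func::FuncOp>(affine::createPipelineDataTransferPass());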

// Returns the position of the tag memref operand given a DMA operation.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract ops are
// added.
static unsigned getTagMemRefPos(Operation &dmaOp) {
  assert((isa<AffineDmaStartOp, AffineDmaWaitOp>(dmaOp)));
  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaOp)) {
    return dmaStartOp.getTagMemRefOperandIndex();
  }
  // First operand for a dma finish operation.
  return 0;
}

/// Doubles the buffer of the supplied memref on the specified 'affine.for'
/// operation by adding a leading dimension of size two to the memref.
/// Replaces all uses of the old memref by the new one while indexing the newly
/// added dimension by the loop IV of the specified 'affine.for' operation
/// modulo 2. Returns false if such a replacement cannot be performed.
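///
/// For example (illustrative, assuming a unit step): a buffer of type
/// memref<256xf32, 1> becomes memref<2x256xf32, 1>, and a use 'A[%j]' inside
/// the loop body becomes 'A[%iv mod 2, %j]', so successive iterations
/// alternate between the two halves of the new buffer.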
static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
  auto *forBody = forOp.getBody();
  OpBuilder bInner(forBody, forBody->begin());

  // Doubles the shape with a leading dimension extent of 2.
  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
    // Add the leading dimension in the shape for the double buffer.
    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
    newShape[0] = 2;
    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
    return MemRefType::Builder(oldMemRefType).setShape(newShape).setLayout({});
  };

  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
  auto newMemRefType = doubleShape(oldMemRefType);

  // The double buffer is allocated right before 'forOp'.
  OpBuilder bOuter(forOp);
  // Put together alloc operands for any dynamic dimensions of the memref.
  SmallVector<Value, 4> allocOperands;
  for (const auto &dim : llvm::enumerate(oldMemRefType.getShape())) {
    if (dim.value() == ShapedType::kDynamic)
      allocOperands.push_back(bOuter.createOrFold<memref::DimOp>(
          forOp.getLoc(), oldMemRef, dim.index()));
  }

  // Create and place the alloc right before the 'affine.for' operation.
  Value newMemRef = bOuter.create<memref::AllocOp>(
      forOp.getLoc(), newMemRefType, allocOperands);

  // Create 'iv mod 2' value to index the leading dimension.
  auto d0 = bInner.getAffineDimExpr(0);
  int64_t step = forOp.getStepAsInt();
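  // Normalizing the IV by the step makes consecutive iterations alternate
  // buffers even for non-unit steps; e.g., for step 4, IVs 0, 4, 8, 12 map to
  // 0, 1, 0, 1 under (d0 floordiv 4) mod 2.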
  auto modTwoMap =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
                                                 forOp.getInductionVar());

  // replaceAllMemRefUsesWith will succeed unless the forOp body has
  // non-dereferencing uses of the memref (deallocs are fine though).
  if (failed(replaceAllMemRefUsesWith(
          oldMemRef, newMemRef,
          /*extraIndices=*/{ivModTwoOp},
          /*indexRemap=*/AffineMap(),
          /*extraOperands=*/{},
          /*symbolOperands=*/{},
          /*domOpFilter=*/&*forOp.getBody()->begin()))) {
    LLVM_DEBUG(
        forOp.emitError("memref replacement for double buffering failed"));
    ivModTwoOp.erase();
    return false;
  }
  // Insert the dealloc op right after the for loop.
  bOuter.setInsertionPointAfter(forOp);
  bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);

  return true;
}

/// Walks all 'affine.for' ops in the function and attempts to pipeline the
/// DMAs in each one.
void PipelineDataTransfer::runOnOperation() {
  // Do a post order walk so that inner loop DMAs are processed first. This is
  // necessary since 'affine.for' operations nested within would otherwise
  // become invalid (erased) when the outer loop is pipelined (the pipelined one
  // gets deleted and replaced by a prologue, a new steady-state loop and an
  // epilogue).
  forOps.clear();
  getOperation().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
  for (auto forOp : forOps)
    runOnAffineForOp(forOp);
}

// Check whether the tags of the DMA start op and the DMA wait op match.
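// For example (illustrative affine DMA IR), the following start/wait pair
// matches on the tag memref %tag and the tag index %c0:
//   affine.dma_start %src[%i], %buf[%i], %tag[%c0], %num
//       : memref<256xf32>, memref<256xf32, 1>, memref<1xi32, 2>
//   affine.dma_wait %tag[%c0], %num : memref<1xi32, 2>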
static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
    return false;
  auto startIndices = startOp.getTagIndices();
  auto waitIndices = waitOp.getTagIndices();
  // Both of these have the same number of indices since they correspond to the
  // same tag memref.
  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
            e = startIndices.end();
       it != e; ++it, ++wIt) {
    // Keep it simple for now, just checking if the indices match.
    // TODO: this would in general need to check if there is no
    // intervening write writing to the same tag location, i.e., memory last
    // write/data flow analysis. This is however sufficient/powerful enough for
    // now since the DMA generation pass or the input for it will always have
    // start/wait with matching tags (same SSA operand indices).
    if (*it != *wIt)
      return false;
  }
  return true;
}

// Identify matching DMA start/finish operations to overlap computation with.
static void findMatchingStartFinishInsts(
    AffineForOp forOp,
    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {

  // Collect outgoing DMA operations - needed to check for dependences below.
  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
  for (auto &op : *forOp.getBody()) {
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
      outgoingDmaOps.push_back(dmaStartOp);
  }

  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
  for (auto &op : *forOp.getBody()) {
    // Collect DMA finish operations.
    if (isa<AffineDmaWaitOp>(op)) {
      dmaFinishInsts.push_back(&op);
      continue;
    }
    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
    if (!dmaStartOp)
      continue;

    // Only DMAs incoming into higher memory spaces are pipelined for now.
    // TODO: handle outgoing DMA pipelining.
    if (!dmaStartOp.isDestMemorySpaceFaster())
      continue;

    // Check for dependence with outgoing DMAs. Doing this conservatively.
    // TODO: use the dependence analysis to check for
    // dependences between an incoming and outgoing DMA in the same iteration.
    auto *it = outgoingDmaOps.begin();
    for (; it != outgoingDmaOps.end(); ++it) {
      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
        break;
    }
    if (it != outgoingDmaOps.end())
      continue;

    // We only double buffer if the buffer is not live out of the loop.
    auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
    bool escapingUses = false;
    for (auto *user : memref.getUsers()) {
      // We can double buffer regardless of deallocs outside the loop.
      if (isa<memref::DeallocOp>(user))
        continue;
      if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "can't pipeline: buffer is live out of loop\n";);
        escapingUses = true;
        break;
      }
    }
    if (!escapingUses)
      dmaStartInsts.push_back(&op);
  }

  // For each start operation, we look for a matching finish operation.
  for (auto *dmaStartOp : dmaStartInsts) {
    for (auto *dmaFinishOp : dmaFinishInsts) {
      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartOp),
                        cast<AffineDmaWaitOp>(dmaFinishOp))) {
        startWaitPairs.push_back({dmaStartOp, dmaFinishOp});
        break;
      }
    }
  }
}

/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
/// are inserted right before where it was.
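///
/// Schematically (an illustrative sketch, not the exact output), for a loop
/// with constant trip count N:
///   <DMA for iteration 0>                                    // prologue
///   for i = 1 .. N-1: <DMA for i> and <compute for i - 1>    // steady state
///   <compute for iteration N - 1>                            // epilogue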
void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
  auto mayBeConstTripCount = getConstantTripCount(forOp);
  if (!mayBeConstTripCount) {
    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
    return;
  }

  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  if (startWaitPairs.empty()) {
    LLVM_DEBUG(forOp.emitRemark("no DMA start/finish pairs"));
    return;
  }

  // Double the buffers for the higher memory space memrefs.
  // Identify memrefs to replace by scanning through all DMA start
  // operations. A DMA start operation has two memrefs - the one from the
  // higher level of memory hierarchy is the one to double buffer.
  // TODO: check whether double-buffering is even necessary.
  // TODO: make this work with different layouts: assuming here that
  // the dimension we are adding here for the double buffering is the outermost
  // dimension.
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    Value oldMemRef = dmaStartOp->getOperand(
        cast<AffineDmaStartOp>(dmaStartOp).getFasterMemPos());
    if (!doubleBuffer(oldMemRef, forOp)) {
      // Normally, double buffering should not fail because we already checked
      // that there are no uses outside the loop.
      LLVM_DEBUG(llvm::dbgs()
                 << "double buffering failed for " << *dmaStartOp << "\n";);
      // The IR is still valid and semantically correct.
      return;
    }
    // If the old memref has no more uses, remove its 'dead' alloc if it was
    // alloc'ed. (Note: DMA buffers are rarely function live-in; but a 'dim'
    // operation could have been used on it if it was dynamically shaped in
    // order to create the double buffer above.)
    // '-canonicalize' does this in a more general way, but we'll anyway do the
    // simple/common case here so that the output / test cases look clear.
    if (auto *allocOp = oldMemRef.getDefiningOp()) {
      if (oldMemRef.use_empty()) {
        allocOp->erase();
      } else if (oldMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldMemRef.user_begin())) {
          dealloc.erase();
          allocOp->erase();
        }
      }
    }
  }

  // Double the buffers for the tag memrefs.
  for (auto &pair : startWaitPairs) {
    auto *dmaFinishOp = pair.second;
    Value oldTagMemRef = dmaFinishOp->getOperand(getTagMemRefPos(*dmaFinishOp));
    if (!doubleBuffer(oldTagMemRef, forOp)) {
      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
      return;
    }
    // If the old tag has no uses or a single dealloc use, remove it.
    // (Canonicalization handles more complex cases.)
    if (auto *tagAllocOp = oldTagMemRef.getDefiningOp()) {
      if (oldTagMemRef.use_empty()) {
        tagAllocOp->erase();
      } else if (oldTagMemRef.hasOneUse()) {
        if (auto dealloc =
                dyn_cast<memref::DeallocOp>(*oldTagMemRef.user_begin())) {
          dealloc.erase();
          tagAllocOp->erase();
        }
      }
    }
  }

  // Double buffering would have invalidated all the old DMA start/wait insts.
  startWaitPairs.clear();
  findMatchingStartFinishInsts(forOp, startWaitPairs);

  // Store the shift for each operation for a later lookup when handling
  // AffineApplyOps.
  DenseMap<Operation *, unsigned> instShiftMap;
  for (auto &pair : startWaitPairs) {
    auto *dmaStartOp = pair.first;
    assert(isa<AffineDmaStartOp>(dmaStartOp));
    instShiftMap[dmaStartOp] = 0;
    // Set shifts for DMA start op's affine operand computation slices to 0.
    SmallVector<AffineApplyOp, 4> sliceOps;
    affine::createAffineComputationSlice(dmaStartOp, &sliceOps);
    if (!sliceOps.empty()) {
      for (auto sliceOp : sliceOps) {
        instShiftMap[sliceOp.getOperation()] = 0;
      }
    } else {
      // If a slice wasn't created, the reachable affine.apply ops from its
      // operands are the ones that go with it.
      SmallVector<Operation *, 4> affineApplyInsts;
      SmallVector<Value, 4> operands(dmaStartOp->getOperands());
      getReachableAffineApplyOps(operands, affineApplyInsts);
      for (auto *op : affineApplyInsts) {
        instShiftMap[op] = 0;
      }
    }
  }
  // Everything else (including compute ops and dma finish) is shifted by one.
  for (auto &op : forOp.getBody()->without_terminator())
    if (!instShiftMap.contains(&op))
      instShiftMap[&op] = 1;
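
  // With DMA starts at shift 0 and everything else at shift 1, the body skew
  // below overlaps iteration i's compute with iteration i + 1's DMA in the
  // steady-state loop.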

  // Get the shifts stored in the map.
  SmallVector<uint64_t, 8> shifts(forOp.getBody()->getOperations().size());
  unsigned s = 0;
  for (auto &op : forOp.getBody()->without_terminator()) {
    assert(instShiftMap.contains(&op));
    shifts[s++] = instShiftMap[&op];

    // Tag operations with their shifts for debugging purposes.
    LLVM_DEBUG({
      OpBuilder b(&op);
      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
    });
  }

  if (!isOpwiseShiftValid(forOp, shifts)) {
    // Violates dependences.
    LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
    return;
  }

  if (failed(affineForOpBodySkew(forOp, shifts))) {
    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
    return;
  }
}