MLIR  22.0.0git
LoopPipelining.cpp
Go to the documentation of this file.
1 //===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop software pipelining
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/Support/DebugLog.h"
23 #include "llvm/Support/MathExtras.h"
24 
25 #define DEBUG_TYPE "scf-loop-pipelining"
26 
27 using namespace mlir;
28 using namespace mlir::scf;
29 
30 namespace {
31 
32 /// Helper to keep internal information during pipelining transformation.
33 struct LoopPipelinerInternal {
34  /// Coarse liverange information for ops used across stages.
35  struct LiverangeInfo {
36  unsigned lastUseStage = 0;
37  unsigned defStage = 0;
38  };
39 
40 protected:
41  ForOp forOp;
42  unsigned maxStage = 0;
44  std::vector<Operation *> opOrder;
45  Value ub;
46  Value lb;
47  Value step;
48  bool dynamicLoop;
49  PipeliningOption::AnnotationlFnType annotateFn = nullptr;
50  bool peelEpilogue;
51  PipeliningOption::PredicateOpFn predicateFn = nullptr;
52 
53  // When peeling the kernel we generate several version of each value for
54  // different stage of the prologue. This map tracks the mapping between
55  // original Values in the loop and the different versions
56  // peeled from the loop.
58 
59  /// Assign a value to `valueMapping`, this means `val` represents the version
60  /// `idx` of `key` in the epilogue.
61  void setValueMapping(Value key, Value el, int64_t idx);
62 
63  /// Return the defining op of the given value, if the Value is an argument of
64  /// the loop return the associated defining op in the loop and its distance to
65  /// the Value.
66  std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
67 
68  /// Return true if the schedule is possible and return false otherwise. A
69  /// schedule is correct if all definitions are scheduled before uses.
70  bool verifySchedule();
71 
72 public:
73  /// Initalize the information for the given `op`, return true if it
74  /// satisfies the pre-condition to apply pipelining.
75  bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
76  /// Emits the prologue, this creates `maxStage - 1` part which will contain
77  /// operations from stages [0; i], where i is the part index.
78  LogicalResult emitPrologue(RewriterBase &rewriter);
79  /// Gather liverange information for Values that are used in a different stage
80  /// than its definition.
81  llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
82  scf::ForOp createKernelLoop(
83  const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
84  RewriterBase &rewriter,
85  llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
86  /// Emits the pipelined kernel. This clones loop operations following user
87  /// order and remaps operands defined in a different stage as their use.
88  LogicalResult createKernel(
89  scf::ForOp newForOp,
90  const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
91  const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
92  RewriterBase &rewriter);
93  /// Emits the epilogue, this creates `maxStage - 1` part which will contain
94  /// operations from stages [i; maxStage], where i is the part index.
95  LogicalResult emitEpilogue(RewriterBase &rewriter,
96  llvm::SmallVector<Value> &returnValues);
97 };
98 
99 bool LoopPipelinerInternal::initializeLoopInfo(
100  ForOp op, const PipeliningOption &options) {
101  LDBG() << "Start initializeLoopInfo";
102  forOp = op;
103  ub = forOp.getUpperBound();
104  lb = forOp.getLowerBound();
105  step = forOp.getStep();
106 
107  std::vector<std::pair<Operation *, unsigned>> schedule;
108  options.getScheduleFn(forOp, schedule);
109  if (schedule.empty()) {
110  LDBG() << "--empty schedule -> BAIL";
111  return false;
112  }
113 
114  opOrder.reserve(schedule.size());
115  for (auto &opSchedule : schedule) {
116  maxStage = std::max(maxStage, opSchedule.second);
117  stages[opSchedule.first] = opSchedule.second;
118  opOrder.push_back(opSchedule.first);
119  }
120 
121  dynamicLoop = true;
122  auto upperBoundCst = getConstantIntValue(ub);
123  auto lowerBoundCst = getConstantIntValue(lb);
124  auto stepCst = getConstantIntValue(step);
125  if (!upperBoundCst || !lowerBoundCst || !stepCst) {
126  if (!options.supportDynamicLoops) {
127  LDBG() << "--dynamic loop not supported -> BAIL";
128  return false;
129  }
130  } else {
131  int64_t ubImm = upperBoundCst.value();
132  int64_t lbImm = lowerBoundCst.value();
133  int64_t stepImm = stepCst.value();
134  if (stepImm <= 0) {
135  LDBG() << "--invalid loop step -> BAIL";
136  return false;
137  }
138  int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
139  if (numIteration >= maxStage) {
140  dynamicLoop = false;
141  } else if (!options.supportDynamicLoops) {
142  LDBG() << "--fewer loop iterations than pipeline stages -> BAIL";
143  return false;
144  }
145  }
146  peelEpilogue = options.peelEpilogue;
147  predicateFn = options.predicateFn;
148  if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
149  LDBG() << "--no epilogue or predicate set -> BAIL";
150  return false;
151  }
152 
153  // All operations need to have a stage.
154  for (Operation &op : forOp.getBody()->without_terminator()) {
155  if (!stages.contains(&op)) {
156  op.emitOpError("not assigned a pipeline stage");
157  LDBG() << "--op not assigned a pipeline stage: " << op << " -> BAIL";
158  return false;
159  }
160  }
161 
162  if (!verifySchedule()) {
163  LDBG() << "--invalid schedule: " << op << " -> BAIL";
164  return false;
165  }
166 
167  // Currently, we do not support assigning stages to ops in nested regions. The
168  // block of all operations assigned a stage should be the single `scf.for`
169  // body block.
170  for (const auto &[op, stageNum] : stages) {
171  (void)stageNum;
172  if (op == forOp.getBody()->getTerminator()) {
173  op->emitError("terminator should not be assigned a stage");
174  LDBG() << "--terminator should not be assigned stage: " << *op
175  << " -> BAIL";
176  return false;
177  }
178  if (op->getBlock() != forOp.getBody()) {
179  op->emitOpError("the owning Block of all operations assigned a stage "
180  "should be the loop body block");
181  LDBG() << "--the owning Block of all operations assigned a stage "
182  "should be the loop body block: "
183  << *op << " -> BAIL";
184  return false;
185  }
186  }
187 
188  // Support only loop-carried dependencies with a distance of one iteration or
189  // those defined outside of the loop. This means that any dependency within a
190  // loop should either be on the immediately preceding iteration, the current
191  // iteration, or on variables whose values are set before entering the loop.
192  if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
193  [this](Value operand) {
194  Operation *def = operand.getDefiningOp();
195  return !def ||
196  (!stages.contains(def) && forOp->isAncestor(def));
197  })) {
198  LDBG() << "--only support loop carried dependency with a distance of 1 or "
199  "defined outside of the loop -> BAIL";
200  return false;
201  }
202  annotateFn = options.annotateFn;
203  return true;
204 }
205 
206 /// Find operands of all the nested operations within `op`.
207 static SetVector<Value> getNestedOperands(Operation *op) {
208  SetVector<Value> operands;
209  op->walk([&](Operation *nestedOp) {
210  operands.insert_range(nestedOp->getOperands());
211  });
212  return operands;
213 }
214 
215 /// Compute unrolled cycles of each op (consumer) and verify that each op is
216 /// scheduled after its operands (producers) while adjusting for the distance
217 /// between producer and consumer.
218 bool LoopPipelinerInternal::verifySchedule() {
219  int64_t numCylesPerIter = opOrder.size();
220  // Pre-compute the unrolled cycle of each op.
221  DenseMap<Operation *, int64_t> unrolledCyles;
222  for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
223  Operation *def = opOrder[cycle];
224  auto it = stages.find(def);
225  assert(it != stages.end());
226  int64_t stage = it->second;
227  unrolledCyles[def] = cycle + stage * numCylesPerIter;
228  }
229  for (Operation *consumer : opOrder) {
230  int64_t consumerCycle = unrolledCyles[consumer];
231  for (Value operand : getNestedOperands(consumer)) {
232  auto [producer, distance] = getDefiningOpAndDistance(operand);
233  if (!producer)
234  continue;
235  auto it = unrolledCyles.find(producer);
236  // Skip producer coming from outside the loop.
237  if (it == unrolledCyles.end())
238  continue;
239  int64_t producerCycle = it->second;
240  if (consumerCycle < producerCycle - numCylesPerIter * distance) {
241  consumer->emitError("operation scheduled before its operands");
242  return false;
243  }
244  }
245  }
246  return true;
247 }
248 
249 /// Clone `op` and call `callback` on the cloned op's oeprands as well as any
250 /// operands of nested ops that:
251 /// 1) aren't defined within the new op or
252 /// 2) are block arguments.
253 static Operation *
254 cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
255  function_ref<void(OpOperand *newOperand)> callback) {
256  Operation *clone = rewriter.clone(*op);
257  clone->walk<WalkOrder::PreOrder>([&](Operation *nested) {
258  // 'clone' itself will be visited first.
259  for (OpOperand &operand : nested->getOpOperands()) {
260  Operation *def = operand.get().getDefiningOp();
261  if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
262  callback(&operand);
263  }
264  });
265  return clone;
266 }
267 
268 LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
269  // Initialize the iteration argument to the loop initial values.
270  for (auto [arg, operand] :
271  llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
272  setValueMapping(arg, operand.get(), 0);
273  }
274  auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
275  Location loc = forOp.getLoc();
276  SmallVector<Value> predicates(maxStage);
277  for (int64_t i = 0; i < maxStage; i++) {
278  if (dynamicLoop) {
279  Type t = ub.getType();
280  // pred = ub > lb + (i * step)
281  Value iv = arith::AddIOp::create(
282  rewriter, loc, lb,
283  arith::MulIOp::create(
284  rewriter, loc, step,
285  arith::ConstantOp::create(rewriter, loc,
286  rewriter.getIntegerAttr(t, i))));
287  predicates[i] = arith::CmpIOp::create(rewriter, loc,
288  arith::CmpIPredicate::slt, iv, ub);
289  }
290 
291  // special handling for induction variable as the increment is implicit.
292  // iv = lb + i * step
293  Type t = lb.getType();
294  Value iv = arith::AddIOp::create(
295  rewriter, loc, lb,
296  arith::MulIOp::create(
297  rewriter, loc, step,
298  arith::ConstantOp::create(rewriter, loc,
299  rewriter.getIntegerAttr(t, i))));
300  setValueMapping(forOp.getInductionVar(), iv, i);
301  for (Operation *op : opOrder) {
302  if (stages[op] > i)
303  continue;
304  Operation *newOp =
305  cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
306  auto it = valueMapping.find(newOperand->get());
307  if (it != valueMapping.end()) {
308  Value replacement = it->second[i - stages[op]];
309  newOperand->set(replacement);
310  }
311  });
312  int predicateIdx = i - stages[op];
313  if (predicates[predicateIdx]) {
314  OpBuilder::InsertionGuard insertGuard(rewriter);
315  newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
316  if (newOp == nullptr)
317  return failure();
318  }
319  if (annotateFn)
320  annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
321  for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
322  Value source = newOp->getResult(destId);
323  // If the value is a loop carried dependency update the loop argument
324  for (OpOperand &operand : yield->getOpOperands()) {
325  if (operand.get() != op->getResult(destId))
326  continue;
327  if (predicates[predicateIdx] &&
328  !forOp.getResult(operand.getOperandNumber()).use_empty()) {
329  // If the value is used outside the loop, we need to make sure we
330  // return the correct version of it.
331  Value prevValue = valueMapping
332  [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
333  [i - stages[op]];
334  source = arith::SelectOp::create(
335  rewriter, loc, predicates[predicateIdx], source, prevValue);
336  }
337  setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
338  source, i - stages[op] + 1);
339  }
340  setValueMapping(op->getResult(destId), newOp->getResult(destId),
341  i - stages[op]);
342  }
343  }
344  }
345  return success();
346 }
347 
348 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
349 LoopPipelinerInternal::analyzeCrossStageValues() {
350  llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
351  for (Operation *op : opOrder) {
352  unsigned stage = stages[op];
353 
354  auto analyzeOperand = [&](OpOperand &operand) {
355  auto [def, distance] = getDefiningOpAndDistance(operand.get());
356  if (!def)
357  return;
358  auto defStage = stages.find(def);
359  if (defStage == stages.end() || defStage->second == stage ||
360  defStage->second == stage + distance)
361  return;
362  assert(stage > defStage->second);
363  LiverangeInfo &info = crossStageValues[operand.get()];
364  info.defStage = defStage->second;
365  info.lastUseStage = std::max(info.lastUseStage, stage);
366  };
367 
368  for (OpOperand &operand : op->getOpOperands())
369  analyzeOperand(operand);
370  visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
371  analyzeOperand(*operand);
372  });
373  }
374  return crossStageValues;
375 }
376 
377 std::pair<Operation *, int64_t>
378 LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
379  int64_t distance = 0;
380  if (auto arg = dyn_cast<BlockArgument>(value)) {
381  if (arg.getOwner() != forOp.getBody())
382  return {nullptr, 0};
383  // Ignore induction variable.
384  if (arg.getArgNumber() == 0)
385  return {nullptr, 0};
386  distance++;
387  value =
388  forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
389  }
390  Operation *def = value.getDefiningOp();
391  if (!def)
392  return {nullptr, 0};
393  return {def, distance};
394 }
395 
396 scf::ForOp LoopPipelinerInternal::createKernelLoop(
397  const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
398  &crossStageValues,
399  RewriterBase &rewriter,
400  llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
401  // Creates the list of initial values associated to values used across
402  // stages. The initial values come from the prologue created above.
403  // Keep track of the kernel argument associated to each version of the
404  // values passed to the kernel.
405  llvm::SmallVector<Value> newLoopArg;
406  // For existing loop argument initialize them with the right version from the
407  // prologue.
408  for (const auto &retVal :
409  llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
410  Operation *def = retVal.value().getDefiningOp();
411  assert(def && "Only support loop carried dependencies of distance of 1 or "
412  "outside the loop");
413  auto defStage = stages.find(def);
414  if (defStage != stages.end()) {
415  Value valueVersion =
416  valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
417  [maxStage - defStage->second];
418  assert(valueVersion);
419  newLoopArg.push_back(valueVersion);
420  } else {
421  newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
422  }
423  }
424  for (auto escape : crossStageValues) {
425  LiverangeInfo &info = escape.second;
426  Value value = escape.first;
427  for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
428  stageIdx++) {
429  Value valueVersion =
430  valueMapping[value][maxStage - info.lastUseStage + stageIdx];
431  assert(valueVersion);
432  newLoopArg.push_back(valueVersion);
433  loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
434  stageIdx)] = newLoopArg.size() - 1;
435  }
436  }
437 
438  // Create the new kernel loop. When we peel the epilgue we need to peel
439  // `numStages - 1` iterations. Then we adjust the upper bound to remove those
440  // iterations.
441  Value newUb = forOp.getUpperBound();
442  if (peelEpilogue) {
443  Type t = ub.getType();
444  Location loc = forOp.getLoc();
445  // newUb = ub - maxStage * step
446  Value maxStageValue = arith::ConstantOp::create(
447  rewriter, loc, rewriter.getIntegerAttr(t, maxStage));
448  Value maxStageByStep =
449  arith::MulIOp::create(rewriter, loc, step, maxStageValue);
450  newUb = arith::SubIOp::create(rewriter, loc, ub, maxStageByStep);
451  }
452  auto newForOp =
453  scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), newUb,
454  forOp.getStep(), newLoopArg);
455  // When there are no iter args, the loop body terminator will be created.
456  // Since we always create it below, remove the terminator if it was created.
457  if (!newForOp.getBody()->empty())
458  rewriter.eraseOp(newForOp.getBody()->getTerminator());
459  return newForOp;
460 }
461 
462 LogicalResult LoopPipelinerInternal::createKernel(
463  scf::ForOp newForOp,
464  const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
465  &crossStageValues,
466  const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
467  RewriterBase &rewriter) {
468  valueMapping.clear();
469 
470  // Create the kernel, we clone instruction based on the order given by
471  // user and remap operands coming from a previous stages.
472  rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
473  IRMapping mapping;
474  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
475  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
476  mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
477  }
478  SmallVector<Value> predicates(maxStage + 1, nullptr);
479  if (!peelEpilogue) {
480  // Create a predicate for each stage except the last stage.
481  Location loc = newForOp.getLoc();
482  Type t = ub.getType();
483  for (unsigned i = 0; i < maxStage; i++) {
484  // c = ub - (maxStage - i) * step
485  Value c = arith::SubIOp::create(
486  rewriter, loc, ub,
487  arith::MulIOp::create(
488  rewriter, loc, step,
489  arith::ConstantOp::create(
490  rewriter, loc,
491  rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
492 
493  Value pred = arith::CmpIOp::create(rewriter, newForOp.getLoc(),
494  arith::CmpIPredicate::slt,
495  newForOp.getInductionVar(), c);
496  predicates[i] = pred;
497  }
498  }
499  for (Operation *op : opOrder) {
500  int64_t useStage = stages[op];
501  auto *newOp = rewriter.clone(*op, mapping);
502  SmallVector<OpOperand *> operands;
503  // Collect all the operands for the cloned op and its nested ops.
504  op->walk([&operands](Operation *nestedOp) {
505  for (OpOperand &operand : nestedOp->getOpOperands()) {
506  operands.push_back(&operand);
507  }
508  });
509  for (OpOperand *operand : operands) {
510  Operation *nestedNewOp = mapping.lookup(operand->getOwner());
511  // Special case for the induction variable uses. We replace it with a
512  // version incremented based on the stage where it is used.
513  if (operand->get() == forOp.getInductionVar()) {
514  rewriter.setInsertionPoint(newOp);
515 
516  // offset = (maxStage - stages[op]) * step
517  Type t = step.getType();
518  Value offset = arith::MulIOp::create(
519  rewriter, forOp.getLoc(), step,
520  arith::ConstantOp::create(
521  rewriter, forOp.getLoc(),
522  rewriter.getIntegerAttr(t, maxStage - stages[op])));
523  Value iv = arith::AddIOp::create(rewriter, forOp.getLoc(),
524  newForOp.getInductionVar(), offset);
525  nestedNewOp->setOperand(operand->getOperandNumber(), iv);
526  rewriter.setInsertionPointAfter(newOp);
527  continue;
528  }
529  Value source = operand->get();
530  auto arg = dyn_cast<BlockArgument>(source);
531  if (arg && arg.getOwner() == forOp.getBody()) {
532  Value ret = forOp.getBody()->getTerminator()->getOperand(
533  arg.getArgNumber() - 1);
534  Operation *dep = ret.getDefiningOp();
535  if (!dep)
536  continue;
537  auto stageDep = stages.find(dep);
538  if (stageDep == stages.end() || stageDep->second == useStage)
539  continue;
540  // If the value is a loop carried value coming from stage N + 1 remap,
541  // it will become a direct use.
542  if (stageDep->second == useStage + 1) {
543  nestedNewOp->setOperand(operand->getOperandNumber(),
544  mapping.lookupOrDefault(ret));
545  continue;
546  }
547  source = ret;
548  }
549  // For operands defined in a previous stage we need to remap it to use
550  // the correct region argument. We look for the right version of the
551  // Value based on the stage where it is used.
552  Operation *def = source.getDefiningOp();
553  if (!def)
554  continue;
555  auto stageDef = stages.find(def);
556  if (stageDef == stages.end() || stageDef->second == useStage)
557  continue;
558  auto remap = loopArgMap.find(
559  std::make_pair(operand->get(), useStage - stageDef->second));
560  assert(remap != loopArgMap.end());
561  nestedNewOp->setOperand(operand->getOperandNumber(),
562  newForOp.getRegionIterArgs()[remap->second]);
563  }
564 
565  if (predicates[useStage]) {
566  OpBuilder::InsertionGuard insertGuard(rewriter);
567  newOp = predicateFn(rewriter, newOp, predicates[useStage]);
568  if (!newOp)
569  return failure();
570  // Remap the results to the new predicated one.
571  for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
572  mapping.map(std::get<0>(values), std::get<1>(values));
573  }
574  if (annotateFn)
575  annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
576  }
577 
578  // Collect the Values that need to be returned by the forOp. For each
579  // value we need to have `LastUseStage - DefStage` number of versions
580  // returned.
581  // We create a mapping between original values and the associated loop
582  // returned values that will be needed by the epilogue.
583  llvm::SmallVector<Value> yieldOperands;
584  for (OpOperand &yieldOperand :
585  forOp.getBody()->getTerminator()->getOpOperands()) {
586  Value source = mapping.lookupOrDefault(yieldOperand.get());
587  // When we don't peel the epilogue and the yield value is used outside the
588  // loop we need to make sure we return the version from numStages -
589  // defStage.
590  if (!peelEpilogue &&
591  !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
592  Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
593  if (def) {
594  auto defStage = stages.find(def);
595  if (defStage != stages.end() && defStage->second < maxStage) {
596  Value pred = predicates[defStage->second];
597  source = arith::SelectOp::create(
598  rewriter, pred.getLoc(), pred, source,
599  newForOp.getBody()
600  ->getArguments()[yieldOperand.getOperandNumber() + 1]);
601  }
602  }
603  }
604  yieldOperands.push_back(source);
605  }
606 
607  for (auto &it : crossStageValues) {
608  int64_t version = maxStage - it.second.lastUseStage + 1;
609  unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
610  // add the original version to yield ops.
611  // If there is a live range spanning across more than 2 stages we need to
612  // add extra arg.
613  for (unsigned i = 1; i < numVersionReturned; i++) {
614  setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
615  version++);
616  yieldOperands.push_back(
617  newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
618  newForOp.getNumInductionVars()]);
619  }
620  setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
621  version++);
622  yieldOperands.push_back(mapping.lookupOrDefault(it.first));
623  }
624  // Map the yield operand to the forOp returned value.
625  for (const auto &retVal :
626  llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
627  Operation *def = retVal.value().getDefiningOp();
628  assert(def && "Only support loop carried dependencies of distance of 1 or "
629  "defined outside the loop");
630  auto defStage = stages.find(def);
631  if (defStage == stages.end()) {
632  for (unsigned int stage = 1; stage <= maxStage; stage++)
633  setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
634  retVal.value(), stage);
635  } else if (defStage->second > 0) {
636  setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
637  newForOp->getResult(retVal.index()),
638  maxStage - defStage->second + 1);
639  }
640  }
641  scf::YieldOp::create(rewriter, forOp.getLoc(), yieldOperands);
642  return success();
643 }
644 
645 LogicalResult
646 LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
647  llvm::SmallVector<Value> &returnValues) {
648  Location loc = forOp.getLoc();
649  Type t = lb.getType();
650 
651  // Emit different versions of the induction variable. They will be
652  // removed by dead code if not used.
653 
654  auto createConst = [&](int v) {
655  return arith::ConstantOp::create(rewriter, loc,
656  rewriter.getIntegerAttr(t, v));
657  };
658 
659  // total_iterations = cdiv(range_diff, step);
660  // - range_diff = ub - lb
661  // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
662  Value zero = createConst(0);
663  Value one = createConst(1);
664  Value stepLessZero = arith::CmpIOp::create(
665  rewriter, loc, arith::CmpIPredicate::slt, step, zero);
666  Value stepDecr = arith::SelectOp::create(rewriter, loc, stepLessZero, one,
667  createConst(-1));
668 
669  Value rangeDiff = arith::SubIOp::create(rewriter, loc, ub, lb);
670  Value rangeIncrStep = arith::AddIOp::create(rewriter, loc, rangeDiff, step);
671  Value rangeDecr =
672  arith::AddIOp::create(rewriter, loc, rangeIncrStep, stepDecr);
673  Value totalIterations =
674  arith::DivSIOp::create(rewriter, loc, rangeDecr, step);
675 
676  // If total_iters < max_stage, start the epilogue at zero to match the
677  // ramp-up in the prologue.
678  // start_iter = max(0, total_iters - max_stage)
679  Value iterI = arith::SubIOp::create(rewriter, loc, totalIterations,
680  createConst(maxStage));
681  iterI = arith::MaxSIOp::create(rewriter, loc, zero, iterI);
682 
683  // Capture predicates for dynamic loops.
684  SmallVector<Value> predicates(maxStage + 1);
685 
686  for (int64_t i = 1; i <= maxStage; i++) {
687  // newLastIter = lb + step * iterI
688  Value newlastIter = arith::AddIOp::create(
689  rewriter, loc, lb, arith::MulIOp::create(rewriter, loc, step, iterI));
690 
691  setValueMapping(forOp.getInductionVar(), newlastIter, i);
692 
693  // increment to next iterI
694  iterI = arith::AddIOp::create(rewriter, loc, iterI, one);
695 
696  if (dynamicLoop) {
697  // Disable stages when `i` is greater than total_iters.
698  // pred = total_iters >= i
699  predicates[i] =
700  arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
701  totalIterations, createConst(i));
702  }
703  }
704 
705  // Emit `maxStage - 1` epilogue part that includes operations from stages
706  // [i; maxStage].
707  for (int64_t i = 1; i <= maxStage; i++) {
708  SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
709  for (Operation *op : opOrder) {
710  if (stages[op] < i)
711  continue;
712  unsigned currentVersion = maxStage - stages[op] + i;
713  unsigned nextVersion = currentVersion + 1;
714  Operation *newOp =
715  cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
716  auto it = valueMapping.find(newOperand->get());
717  if (it != valueMapping.end()) {
718  Value replacement = it->second[currentVersion];
719  newOperand->set(replacement);
720  }
721  });
722  if (dynamicLoop) {
723  OpBuilder::InsertionGuard insertGuard(rewriter);
724  newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
725  if (!newOp)
726  return failure();
727  }
728  if (annotateFn)
729  annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
730 
731  for (auto [opRes, newRes] :
732  llvm::zip(op->getResults(), newOp->getResults())) {
733  setValueMapping(opRes, newRes, currentVersion);
734  // If the value is a loop carried dependency update the loop argument
735  // mapping and keep track of the last version to replace the original
736  // forOp uses.
737  for (OpOperand &operand :
738  forOp.getBody()->getTerminator()->getOpOperands()) {
739  if (operand.get() != opRes)
740  continue;
741  // If the version is greater than maxStage it means it maps to the
742  // original forOp returned value.
743  unsigned ri = operand.getOperandNumber();
744  returnValues[ri] = newRes;
745  Value mapVal = forOp.getRegionIterArgs()[ri];
746  returnMap[ri] = std::make_pair(mapVal, currentVersion);
747  if (nextVersion <= maxStage)
748  setValueMapping(mapVal, newRes, nextVersion);
749  }
750  }
751  }
752  if (dynamicLoop) {
753  // Select return values from this stage (live outs) based on predication.
754  // If the stage is valid select the peeled value, else use previous stage
755  // value.
756  for (auto pair : llvm::enumerate(returnValues)) {
757  unsigned ri = pair.index();
758  auto [mapVal, currentVersion] = returnMap[ri];
759  if (mapVal) {
760  unsigned nextVersion = currentVersion + 1;
761  Value pred = predicates[currentVersion];
762  Value prevValue = valueMapping[mapVal][currentVersion];
763  auto selOp = arith::SelectOp::create(rewriter, loc, pred,
764  pair.value(), prevValue);
765  returnValues[ri] = selOp;
766  if (nextVersion <= maxStage)
767  setValueMapping(mapVal, selOp, nextVersion);
768  }
769  }
770  }
771  }
772  return success();
773 }
774 
775 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
776  auto it = valueMapping.find(key);
777  // If the value is not in the map yet add a vector big enough to store all
778  // versions.
779  if (it == valueMapping.end())
780  it =
781  valueMapping
782  .insert(std::make_pair(key, llvm::SmallVector<Value>(maxStage + 1)))
783  .first;
784  it->second[idx] = el;
785 }
786 
787 } // namespace
788 
789 FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
790  const PipeliningOption &options,
791  bool *modifiedIR) {
792  if (modifiedIR)
793  *modifiedIR = false;
794 
795  // TODO: Add support for unsigned loops.
796  if (forOp.getUnsignedCmp())
797  return failure();
798 
799  LoopPipelinerInternal pipeliner;
800  if (!pipeliner.initializeLoopInfo(forOp, options))
801  return failure();
802 
803  if (modifiedIR)
804  *modifiedIR = true;
805 
806  // 1. Emit prologue.
807  if (failed(pipeliner.emitPrologue(rewriter)))
808  return failure();
809 
810  // 2. Track values used across stages. When a value cross stages it will
811  // need to be passed as loop iteration arguments.
812  // We first collect the values that are used in a different stage than where
813  // they are defined.
814  llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
815  crossStageValues = pipeliner.analyzeCrossStageValues();
816 
817  // Mapping between original loop values used cross stage and the block
818  // arguments associated after pipelining. A Value may map to several
819  // arguments if its liverange spans across more than 2 stages.
820  llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
821  // 3. Create the new kernel loop and return the block arguments mapping.
822  ForOp newForOp =
823  pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
824  // Create the kernel block, order ops based on user choice and remap
825  // operands.
826  if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
827  rewriter)))
828  return failure();
829 
830  llvm::SmallVector<Value> returnValues =
831  newForOp.getResults().take_front(forOp->getNumResults());
832  if (options.peelEpilogue) {
833  // 4. Emit the epilogue after the new forOp.
834  rewriter.setInsertionPointAfter(newForOp);
835  if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
836  return failure();
837  }
838  // 5. Erase the original loop and replace the uses with the epilogue output.
839  if (forOp->getNumResults() > 0)
840  rewriter.replaceOp(forOp, returnValues);
841  else
842  rewriter.eraseOp(forOp);
843 
844  return newForOp;
845 }
846 
850 }
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition: ExpandOps.cpp:27
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:65
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Definition: Builders.cpp:552
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:412
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one), block or region, depending on the callback provided.
Definition: Operation.h:797
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
Definition: Operation.h:263
result_range getResults()
Definition: Operation.h:415
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match. The original op is erased.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options)
Populate patterns for SCF software pipelining transformation.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ancestors of the limit.
Definition: RegionUtils.cpp:43
Options to dictate how loops should be pipelined.
Definition: Transforms.h:129
std::function< void(Operation *, PipelinerPart, unsigned)> AnnotationlFnType
Lambda called by the pipeliner to allow the user to annotate the IR while it is generated.
Definition: Transforms.h:147
std::function< Operation *(RewriterBase &, Operation *, Value)> PredicateOpFn
Definition: Transforms.h:170