MLIR 22.0.0git
LoopPipelining.cpp
Go to the documentation of this file.
1//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements loop software pipelining
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/IRMapping.h"
21#include "llvm/ADT/MapVector.h"
22#include "llvm/Support/DebugLog.h"
23#include "llvm/Support/MathExtras.h"
24
25#define DEBUG_TYPE "scf-loop-pipelining"
26
27using namespace mlir;
28using namespace mlir::scf;
29
30namespace {
31
32/// Helper to keep internal information during pipelining transformation.
33struct LoopPipelinerInternal {
34 /// Coarse liverange information for ops used across stages.
35 struct LiverangeInfo {
36 unsigned lastUseStage = 0;
37 unsigned defStage = 0;
38 };
39
40protected:
41 ForOp forOp;
42 unsigned maxStage = 0;
44 std::vector<Operation *> opOrder;
45 Value ub;
46 Value lb;
47 Value step;
48 bool dynamicLoop;
49 PipeliningOption::AnnotationlFnType annotateFn = nullptr;
50 bool peelEpilogue;
51 PipeliningOption::PredicateOpFn predicateFn = nullptr;
52
53 // When peeling the kernel we generate several version of each value for
54 // different stage of the prologue. This map tracks the mapping between
55 // original Values in the loop and the different versions
56 // peeled from the loop.
58
59 /// Assign a value to `valueMapping`, this means `val` represents the version
60 /// `idx` of `key` in the epilogue.
61 void setValueMapping(Value key, Value el, int64_t idx);
62
63 /// Return the defining op of the given value, if the Value is an argument of
64 /// the loop return the associated defining op in the loop and its distance to
65 /// the Value.
66 std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
67
68 /// Return true if the schedule is possible and return false otherwise. A
69 /// schedule is correct if all definitions are scheduled before uses.
70 bool verifySchedule();
71
72public:
73 /// Initalize the information for the given `op`, return true if it
74 /// satisfies the pre-condition to apply pipelining.
75 bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
76 /// Emits the prologue, this creates `maxStage - 1` part which will contain
77 /// operations from stages [0; i], where i is the part index.
78 LogicalResult emitPrologue(RewriterBase &rewriter);
79 /// Gather liverange information for Values that are used in a different stage
80 /// than its definition.
81 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
82 scf::ForOp createKernelLoop(
83 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
84 RewriterBase &rewriter,
85 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
86 /// Emits the pipelined kernel. This clones loop operations following user
87 /// order and remaps operands defined in a different stage as their use.
88 LogicalResult createKernel(
89 scf::ForOp newForOp,
90 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
91 const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
92 RewriterBase &rewriter);
93 /// Emits the epilogue, this creates `maxStage - 1` part which will contain
94 /// operations from stages [i; maxStage], where i is the part index.
95 LogicalResult emitEpilogue(RewriterBase &rewriter,
96 llvm::SmallVector<Value> &returnValues);
97};
98
99bool LoopPipelinerInternal::initializeLoopInfo(
100 ForOp op, const PipeliningOption &options) {
101 LDBG() << "Start initializeLoopInfo";
102 forOp = op;
103 ub = forOp.getUpperBound();
104 lb = forOp.getLowerBound();
105 step = forOp.getStep();
106
107 std::vector<std::pair<Operation *, unsigned>> schedule;
108 options.getScheduleFn(forOp, schedule);
109 if (schedule.empty()) {
110 LDBG() << "--empty schedule -> BAIL";
111 return false;
112 }
113
114 opOrder.reserve(schedule.size());
115 for (auto &opSchedule : schedule) {
116 maxStage = std::max(maxStage, opSchedule.second);
117 stages[opSchedule.first] = opSchedule.second;
118 opOrder.push_back(opSchedule.first);
119 }
120
121 dynamicLoop = true;
122 auto upperBoundCst = getConstantIntValue(ub);
123 auto lowerBoundCst = getConstantIntValue(lb);
124 auto stepCst = getConstantIntValue(step);
125 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
126 if (!options.supportDynamicLoops) {
127 LDBG() << "--dynamic loop not supported -> BAIL";
128 return false;
129 }
130 } else {
131 int64_t ubImm = upperBoundCst.value();
132 int64_t lbImm = lowerBoundCst.value();
133 int64_t stepImm = stepCst.value();
134 if (stepImm <= 0) {
135 LDBG() << "--invalid loop step -> BAIL";
136 return false;
137 }
138 int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
139 if (numIteration >= maxStage) {
140 dynamicLoop = false;
141 } else if (!options.supportDynamicLoops) {
142 LDBG() << "--fewer loop iterations than pipeline stages -> BAIL";
143 return false;
144 }
145 }
146 peelEpilogue = options.peelEpilogue;
147 predicateFn = options.predicateFn;
148 if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
149 LDBG() << "--no epilogue or predicate set -> BAIL";
150 return false;
151 }
152
153 // All operations need to have a stage.
154 for (Operation &op : forOp.getBody()->without_terminator()) {
155 if (!stages.contains(&op)) {
156 op.emitOpError("not assigned a pipeline stage");
157 LDBG() << "--op not assigned a pipeline stage: " << op << " -> BAIL";
158 return false;
159 }
160 }
161
162 if (!verifySchedule()) {
163 LDBG() << "--invalid schedule: " << op << " -> BAIL";
164 return false;
165 }
166
167 // Currently, we do not support assigning stages to ops in nested regions. The
168 // block of all operations assigned a stage should be the single `scf.for`
169 // body block.
170 for (const auto &[op, stageNum] : stages) {
171 (void)stageNum;
172 if (op == forOp.getBody()->getTerminator()) {
173 op->emitError("terminator should not be assigned a stage");
174 LDBG() << "--terminator should not be assigned stage: " << *op
175 << " -> BAIL";
176 return false;
177 }
178 if (op->getBlock() != forOp.getBody()) {
179 op->emitOpError("the owning Block of all operations assigned a stage "
180 "should be the loop body block");
181 LDBG() << "--the owning Block of all operations assigned a stage "
182 "should be the loop body block: "
183 << *op << " -> BAIL";
184 return false;
185 }
186 }
187
188 // Support only loop-carried dependencies with a distance of one iteration or
189 // those defined outside of the loop. This means that any dependency within a
190 // loop should either be on the immediately preceding iteration, the current
191 // iteration, or on variables whose values are set before entering the loop.
192 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
193 [this](Value operand) {
194 Operation *def = operand.getDefiningOp();
195 return !def ||
196 (!stages.contains(def) && forOp->isAncestor(def));
197 })) {
198 LDBG() << "--only support loop carried dependency with a distance of 1 or "
199 "defined outside of the loop -> BAIL";
200 return false;
201 }
202 annotateFn = options.annotateFn;
203 return true;
204}
205
206/// Find operands of all the nested operations within `op`.
207static SetVector<Value> getNestedOperands(Operation *op) {
208 SetVector<Value> operands;
209 op->walk([&](Operation *nestedOp) {
210 operands.insert_range(nestedOp->getOperands());
211 });
212 return operands;
213}
214
215/// Compute unrolled cycles of each op (consumer) and verify that each op is
216/// scheduled after its operands (producers) while adjusting for the distance
217/// between producer and consumer.
218bool LoopPipelinerInternal::verifySchedule() {
219 int64_t numCylesPerIter = opOrder.size();
220 // Pre-compute the unrolled cycle of each op.
221 DenseMap<Operation *, int64_t> unrolledCyles;
222 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
223 Operation *def = opOrder[cycle];
224 auto it = stages.find(def);
225 assert(it != stages.end());
226 int64_t stage = it->second;
227 unrolledCyles[def] = cycle + stage * numCylesPerIter;
228 }
229 for (Operation *consumer : opOrder) {
230 int64_t consumerCycle = unrolledCyles[consumer];
231 for (Value operand : getNestedOperands(consumer)) {
232 auto [producer, distance] = getDefiningOpAndDistance(operand);
233 if (!producer)
234 continue;
235 auto it = unrolledCyles.find(producer);
236 // Skip producer coming from outside the loop.
237 if (it == unrolledCyles.end())
238 continue;
239 int64_t producerCycle = it->second;
240 if (consumerCycle < producerCycle - numCylesPerIter * distance) {
241 consumer->emitError("operation scheduled before its operands");
242 return false;
243 }
244 }
245 }
246 return true;
247}
248
249/// Clone `op` and call `callback` on the cloned op's oeprands as well as any
250/// operands of nested ops that:
251/// 1) aren't defined within the new op or
252/// 2) are block arguments.
253static Operation *
254cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
255 function_ref<void(OpOperand *newOperand)> callback) {
256 Operation *clone = rewriter.clone(*op);
257 clone->walk<WalkOrder::PreOrder>([&](Operation *nested) {
258 // 'clone' itself will be visited first.
259 for (OpOperand &operand : nested->getOpOperands()) {
260 Operation *def = operand.get().getDefiningOp();
261 if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
262 callback(&operand);
263 }
264 });
265 return clone;
266}
267
/// Emit the prologue as `maxStage` parts, i = 0 .. maxStage - 1. Part i
/// clones, in schedule order, every op whose stage is <= i. An op of stage s
/// cloned in part i reads and produces version `i - s` of the values tracked
/// in `valueMapping`. When the loop is dynamic, an op cloned for version k is
/// predicated on `lb + k * step < ub`. Returns failure if `predicateFn` fails.
LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
  // Initialize the iteration argument to the loop initial values.
  for (auto [arg, operand] :
       llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
    setValueMapping(arg, operand.get(), 0);
  }
  auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  Location loc = forOp.getLoc();
  // predicates[k] stays null unless the loop is dynamic.
  SmallVector<Value> predicates(maxStage);
  for (int64_t i = 0; i < maxStage; i++) {
    if (dynamicLoop) {
      Type t = ub.getType();
      // pred = ub > lb + (i * step)
      Value iv = arith::AddIOp::create(
          rewriter, loc, lb,
          arith::MulIOp::create(
              rewriter, loc, step,
              arith::ConstantOp::create(rewriter, loc,
                                        rewriter.getIntegerAttr(t, i))));
      predicates[i] = arith::CmpIOp::create(rewriter, loc,
                                            arith::CmpIPredicate::slt, iv, ub);
    }

    // special handling for induction variable as the increment is implicit.
    // iv = lb + i * step
    Type t = lb.getType();
    Value iv = arith::AddIOp::create(
        rewriter, loc, lb,
        arith::MulIOp::create(
            rewriter, loc, step,
            arith::ConstantOp::create(rewriter, loc,
                                      rewriter.getIntegerAttr(t, i))));
    setValueMapping(forOp.getInductionVar(), iv, i);
    for (Operation *op : opOrder) {
      // Ops of later stages do not start until a later prologue part.
      if (stages[op] > i)
        continue;
      // Remap operands to version `i - stage` of any value we track.
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            auto it = valueMapping.find(newOperand->get());
            if (it != valueMapping.end()) {
              Value replacement = it->second[i - stages[op]];
              newOperand->set(replacement);
            }
          });
      int predicateIdx = i - stages[op];
      if (predicates[predicateIdx]) {
        OpBuilder::InsertionGuard insertGuard(rewriter);
        newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
        if (newOp == nullptr)
          return failure();
      }
      if (annotateFn)
        annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
        Value source = newOp->getResult(destId);
        // If the value is a loop carried dependency update the loop argument
        // mapping: the yielded value becomes the next version of the iter_arg.
        for (OpOperand &operand : yield->getOpOperands()) {
          if (operand.get() != op->getResult(destId))
            continue;
          if (predicates[predicateIdx] &&
              !forOp.getResult(operand.getOperandNumber()).use_empty()) {
            // If the value is used outside the loop, we need to make sure we
            // return the correct version of it.
            Value prevValue = valueMapping
                [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
                [i - stages[op]];
            source = arith::SelectOp::create(
                rewriter, loc, predicates[predicateIdx], source, prevValue);
          }
          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
                          source, i - stages[op] + 1);
        }
        setValueMapping(op->getResult(destId), newOp->getResult(destId),
                        i - stages[op]);
      }
    }
  }
  return success();
}
347
348llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
349LoopPipelinerInternal::analyzeCrossStageValues() {
350 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
351 for (Operation *op : opOrder) {
352 unsigned stage = stages[op];
353
354 auto analyzeOperand = [&](OpOperand &operand) {
355 auto [def, distance] = getDefiningOpAndDistance(operand.get());
356 if (!def)
357 return;
358 auto defStage = stages.find(def);
359 if (defStage == stages.end() || defStage->second == stage ||
360 defStage->second == stage + distance)
361 return;
362 assert(stage > defStage->second);
363 LiverangeInfo &info = crossStageValues[operand.get()];
364 info.defStage = defStage->second;
365 info.lastUseStage = std::max(info.lastUseStage, stage);
366 };
367
368 for (OpOperand &operand : op->getOpOperands())
369 analyzeOperand(operand);
370 visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
371 analyzeOperand(*operand);
372 });
373 }
374 return crossStageValues;
375}
376
377std::pair<Operation *, int64_t>
378LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
379 int64_t distance = 0;
380 if (auto arg = dyn_cast<BlockArgument>(value)) {
381 if (arg.getOwner() != forOp.getBody())
382 return {nullptr, 0};
383 // Ignore induction variable.
384 if (arg.getArgNumber() == 0)
385 return {nullptr, 0};
386 distance++;
387 value =
388 forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
389 }
390 Operation *def = value.getDefiningOp();
391 if (!def)
392 return {nullptr, 0};
393 return {def, distance};
394}
395
396scf::ForOp LoopPipelinerInternal::createKernelLoop(
397 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
398 &crossStageValues,
399 RewriterBase &rewriter,
400 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
401 // Creates the list of initial values associated to values used across
402 // stages. The initial values come from the prologue created above.
403 // Keep track of the kernel argument associated to each version of the
404 // values passed to the kernel.
405 llvm::SmallVector<Value> newLoopArg;
406 // For existing loop argument initialize them with the right version from the
407 // prologue.
408 for (const auto &retVal :
409 llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
410 Operation *def = retVal.value().getDefiningOp();
411 assert(def && "Only support loop carried dependencies of distance of 1 or "
412 "outside the loop");
413 auto defStage = stages.find(def);
414 if (defStage != stages.end()) {
415 Value valueVersion =
416 valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
417 [maxStage - defStage->second];
418 assert(valueVersion);
419 newLoopArg.push_back(valueVersion);
420 } else {
421 newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
422 }
423 }
424 for (auto escape : crossStageValues) {
425 LiverangeInfo &info = escape.second;
426 Value value = escape.first;
427 for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
428 stageIdx++) {
429 Value valueVersion =
430 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
431 assert(valueVersion);
432 newLoopArg.push_back(valueVersion);
433 loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
434 stageIdx)] = newLoopArg.size() - 1;
435 }
436 }
437
438 // Create the new kernel loop. When we peel the epilgue we need to peel
439 // `numStages - 1` iterations. Then we adjust the upper bound to remove those
440 // iterations.
441 Value newUb = forOp.getUpperBound();
442 if (peelEpilogue) {
443 Type t = ub.getType();
444 Location loc = forOp.getLoc();
445 // newUb = ub - maxStage * step
446 Value maxStageValue = arith::ConstantOp::create(
447 rewriter, loc, rewriter.getIntegerAttr(t, maxStage));
448 Value maxStageByStep =
449 arith::MulIOp::create(rewriter, loc, step, maxStageValue);
450 newUb = arith::SubIOp::create(rewriter, loc, ub, maxStageByStep);
451 }
452 auto newForOp =
453 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), newUb,
454 forOp.getStep(), newLoopArg);
455 // When there are no iter args, the loop body terminator will be created.
456 // Since we always create it below, remove the terminator if it was created.
457 if (!newForOp.getBody()->empty())
458 rewriter.eraseOp(newForOp.getBody()->getTerminator());
459 return newForOp;
460}
461
462LogicalResult LoopPipelinerInternal::createKernel(
463 scf::ForOp newForOp,
464 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
465 &crossStageValues,
466 const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
467 RewriterBase &rewriter) {
468 valueMapping.clear();
469
470 // Create the kernel, we clone instruction based on the order given by
471 // user and remap operands coming from a previous stages.
472 rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
473 IRMapping mapping;
474 mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
475 for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
476 mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
477 }
478 SmallVector<Value> predicates(maxStage + 1, nullptr);
479 if (!peelEpilogue) {
480 // Create a predicate for each stage except the last stage.
481 Location loc = newForOp.getLoc();
482 Type t = ub.getType();
483 for (unsigned i = 0; i < maxStage; i++) {
484 // c = ub - (maxStage - i) * step
485 Value c = arith::SubIOp::create(
486 rewriter, loc, ub,
487 arith::MulIOp::create(
488 rewriter, loc, step,
489 arith::ConstantOp::create(
490 rewriter, loc,
491 rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
492
493 Value pred = arith::CmpIOp::create(rewriter, newForOp.getLoc(),
494 arith::CmpIPredicate::slt,
495 newForOp.getInductionVar(), c);
496 predicates[i] = pred;
497 }
498 }
499 for (Operation *op : opOrder) {
500 int64_t useStage = stages[op];
501 auto *newOp = rewriter.clone(*op, mapping);
503 // Collect all the operands for the cloned op and its nested ops.
504 op->walk([&operands](Operation *nestedOp) {
505 for (OpOperand &operand : nestedOp->getOpOperands()) {
506 operands.push_back(&operand);
507 }
508 });
509 for (OpOperand *operand : operands) {
510 Operation *nestedNewOp = mapping.lookup(operand->getOwner());
511 // Special case for the induction variable uses. We replace it with a
512 // version incremented based on the stage where it is used.
513 if (operand->get() == forOp.getInductionVar()) {
514 rewriter.setInsertionPoint(newOp);
515
516 // offset = (maxStage - stages[op]) * step
517 Type t = step.getType();
518 Value offset = arith::MulIOp::create(
519 rewriter, forOp.getLoc(), step,
520 arith::ConstantOp::create(
521 rewriter, forOp.getLoc(),
522 rewriter.getIntegerAttr(t, maxStage - stages[op])));
523 Value iv = arith::AddIOp::create(rewriter, forOp.getLoc(),
524 newForOp.getInductionVar(), offset);
525 nestedNewOp->setOperand(operand->getOperandNumber(), iv);
526 rewriter.setInsertionPointAfter(newOp);
527 continue;
528 }
529 Value source = operand->get();
530 auto arg = dyn_cast<BlockArgument>(source);
531 if (arg && arg.getOwner() == forOp.getBody()) {
532 Value ret = forOp.getBody()->getTerminator()->getOperand(
533 arg.getArgNumber() - 1);
534 Operation *dep = ret.getDefiningOp();
535 if (!dep)
536 continue;
537 auto stageDep = stages.find(dep);
538 if (stageDep == stages.end() || stageDep->second == useStage)
539 continue;
540 // If the value is a loop carried value coming from stage N + 1 remap,
541 // it will become a direct use.
542 if (stageDep->second == useStage + 1) {
543 nestedNewOp->setOperand(operand->getOperandNumber(),
544 mapping.lookupOrDefault(ret));
545 continue;
546 }
547 source = ret;
548 }
549 // For operands defined in a previous stage we need to remap it to use
550 // the correct region argument. We look for the right version of the
551 // Value based on the stage where it is used.
552 Operation *def = source.getDefiningOp();
553 if (!def)
554 continue;
555 auto stageDef = stages.find(def);
556 if (stageDef == stages.end() || stageDef->second == useStage)
557 continue;
558 auto remap = loopArgMap.find(
559 std::make_pair(operand->get(), useStage - stageDef->second));
560 assert(remap != loopArgMap.end());
561 nestedNewOp->setOperand(operand->getOperandNumber(),
562 newForOp.getRegionIterArgs()[remap->second]);
563 }
564
565 if (predicates[useStage]) {
566 OpBuilder::InsertionGuard insertGuard(rewriter);
567 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
568 if (!newOp)
569 return failure();
570 // Remap the results to the new predicated one.
571 for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
572 mapping.map(std::get<0>(values), std::get<1>(values));
573 }
574 if (annotateFn)
575 annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
576 }
577
578 // Collect the Values that need to be returned by the forOp. For each
579 // value we need to have `LastUseStage - DefStage` number of versions
580 // returned.
581 // We create a mapping between original values and the associated loop
582 // returned values that will be needed by the epilogue.
583 llvm::SmallVector<Value> yieldOperands;
584 for (OpOperand &yieldOperand :
585 forOp.getBody()->getTerminator()->getOpOperands()) {
586 Value source = mapping.lookupOrDefault(yieldOperand.get());
587 // When we don't peel the epilogue and the yield value is used outside the
588 // loop we need to make sure we return the version from numStages -
589 // defStage.
590 if (!peelEpilogue &&
591 !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
592 Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
593 if (def) {
594 auto defStage = stages.find(def);
595 if (defStage != stages.end() && defStage->second < maxStage) {
596 Value pred = predicates[defStage->second];
597 source = arith::SelectOp::create(
598 rewriter, pred.getLoc(), pred, source,
599 newForOp.getBody()
600 ->getArguments()[yieldOperand.getOperandNumber() + 1]);
601 }
602 }
603 }
604 yieldOperands.push_back(source);
605 }
606
607 for (auto &it : crossStageValues) {
608 int64_t version = maxStage - it.second.lastUseStage + 1;
609 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
610 // add the original version to yield ops.
611 // If there is a live range spanning across more than 2 stages we need to
612 // add extra arg.
613 for (unsigned i = 1; i < numVersionReturned; i++) {
614 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
615 version++);
616 yieldOperands.push_back(
617 newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
618 newForOp.getNumInductionVars()]);
619 }
620 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
621 version++);
622 yieldOperands.push_back(mapping.lookupOrDefault(it.first));
623 }
624 // Map the yield operand to the forOp returned value.
625 for (const auto &retVal :
626 llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
627 Operation *def = retVal.value().getDefiningOp();
628 assert(def && "Only support loop carried dependencies of distance of 1 or "
629 "defined outside the loop");
630 auto defStage = stages.find(def);
631 if (defStage == stages.end()) {
632 for (unsigned int stage = 1; stage <= maxStage; stage++)
633 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
634 retVal.value(), stage);
635 } else if (defStage->second > 0) {
636 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
637 newForOp->getResult(retVal.index()),
638 maxStage - defStage->second + 1);
639 }
640 }
641 scf::YieldOp::create(rewriter, forOp.getLoc(), yieldOperands);
642 return success();
643}
644
/// Emit the epilogue after the kernel as `maxStage` parts, i = 1 .. maxStage.
/// Part i clones, in schedule order, every op whose stage is >= i. Those ops
/// read and write version `maxStage - stage + i` in `valueMapping`.
/// `returnValues` is updated in place with the last produced version of each
/// loop-carried value. For dynamic loops, each cloned op is predicated on
/// `total_iters >= version`, and each live-out is selected between the new
/// value and its previous version. Returns failure if `predicateFn` fails.
LogicalResult
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
                                    llvm::SmallVector<Value> &returnValues) {
  Location loc = forOp.getLoc();
  Type t = lb.getType();

  // Emit different versions of the induction variable. They will be
  // removed by dead code if not used.

  // NOTE(review): the lambda takes `int`, so `maxStage` (unsigned) and `i`
  // (int64_t) are implicitly narrowed when passed in.
  auto createConst = [&](int v) {
    return arith::ConstantOp::create(rewriter, loc,
                                     rewriter.getIntegerAttr(t, v));
  };

  // total_iterations = cdiv(range_diff, step);
  // - range_diff = ub - lb
  // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
  Value zero = createConst(0);
  Value one = createConst(1);
  Value stepLessZero = arith::CmpIOp::create(
      rewriter, loc, arith::CmpIPredicate::slt, step, zero);
  Value stepDecr = arith::SelectOp::create(rewriter, loc, stepLessZero, one,
                                           createConst(-1));

  Value rangeDiff = arith::SubIOp::create(rewriter, loc, ub, lb);
  Value rangeIncrStep = arith::AddIOp::create(rewriter, loc, rangeDiff, step);
  Value rangeDecr =
      arith::AddIOp::create(rewriter, loc, rangeIncrStep, stepDecr);
  Value totalIterations =
      arith::DivSIOp::create(rewriter, loc, rangeDecr, step);

  // If total_iters < max_stage, start the epilogue at zero to match the
  // ramp-up in the prologue.
  // start_iter = max(0, total_iters - max_stage)
  Value iterI = arith::SubIOp::create(rewriter, loc, totalIterations,
                                      createConst(maxStage));
  iterI = arith::MaxSIOp::create(rewriter, loc, zero, iterI);

  // Capture predicates for dynamic loops.
  // Indexed by version 1..maxStage; entry 0 is never set.
  SmallVector<Value> predicates(maxStage + 1);

  for (int64_t i = 1; i <= maxStage; i++) {
    // newLastIter = lb + step * iterI
    Value newlastIter = arith::AddIOp::create(
        rewriter, loc, lb, arith::MulIOp::create(rewriter, loc, step, iterI));

    setValueMapping(forOp.getInductionVar(), newlastIter, i);

    // increment to next iterI
    iterI = arith::AddIOp::create(rewriter, loc, iterI, one);

    if (dynamicLoop) {
      // Disable stages when `i` is greater than total_iters.
      // pred = total_iters >= i
      predicates[i] =
          arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
                                totalIterations, createConst(i));
    }
  }

  // Emit `maxStage` epilogue parts; part i includes operations from stages
  // [i; maxStage].
  for (int64_t i = 1; i <= maxStage; i++) {
    // For each loop result: the iter_arg it maps to and the version last
    // written in this part (null iter_arg when not written).
    SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
    for (Operation *op : opOrder) {
      if (stages[op] < i)
        continue;
      // Version index read and written by `op` in part i.
      unsigned currentVersion = maxStage - stages[op] + i;
      unsigned nextVersion = currentVersion + 1;
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            auto it = valueMapping.find(newOperand->get());
            if (it != valueMapping.end()) {
              Value replacement = it->second[currentVersion];
              newOperand->set(replacement);
            }
          });
      if (dynamicLoop) {
        OpBuilder::InsertionGuard insertGuard(rewriter);
        newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
        if (!newOp)
          return failure();
      }
      if (annotateFn)
        annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);

      for (auto [opRes, newRes] :
           llvm::zip(op->getResults(), newOp->getResults())) {
        setValueMapping(opRes, newRes, currentVersion);
        // If the value is a loop carried dependency update the loop argument
        // mapping and keep track of the last version to replace the original
        // forOp uses.
        for (OpOperand &operand :
             forOp.getBody()->getTerminator()->getOpOperands()) {
          if (operand.get() != opRes)
            continue;
          // If the version is greater than maxStage it means it maps to the
          // original forOp returned value.
          unsigned ri = operand.getOperandNumber();
          returnValues[ri] = newRes;
          Value mapVal = forOp.getRegionIterArgs()[ri];
          returnMap[ri] = std::make_pair(mapVal, currentVersion);
          if (nextVersion <= maxStage)
            setValueMapping(mapVal, newRes, nextVersion);
        }
      }
    }
    if (dynamicLoop) {
      // Select return values from this stage (live outs) based on predication.
      // If the stage is valid select the peeled value, else use previous stage
      // value.
      for (auto pair : llvm::enumerate(returnValues)) {
        unsigned ri = pair.index();
        auto [mapVal, currentVersion] = returnMap[ri];
        if (mapVal) {
          unsigned nextVersion = currentVersion + 1;
          Value pred = predicates[currentVersion];
          Value prevValue = valueMapping[mapVal][currentVersion];
          auto selOp = arith::SelectOp::create(rewriter, loc, pred,
                                               pair.value(), prevValue);
          returnValues[ri] = selOp;
          if (nextVersion <= maxStage)
            setValueMapping(mapVal, selOp, nextVersion);
        }
      }
    }
  }
  return success();
}
774
775void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
776 auto it = valueMapping.find(key);
777 // If the value is not in the map yet add a vector big enough to store all
778 // versions.
779 if (it == valueMapping.end())
780 it =
781 valueMapping
782 .insert(std::make_pair(key, llvm::SmallVector<Value>(maxStage + 1)))
783 .first;
784 it->second[idx] = el;
785}
786
787} // namespace
788
789FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
791 bool *modifiedIR) {
792 if (modifiedIR)
793 *modifiedIR = false;
794
795 // TODO: Add support for unsigned loops.
796 if (forOp.getUnsignedCmp())
797 return failure();
798
799 LoopPipelinerInternal pipeliner;
800 if (!pipeliner.initializeLoopInfo(forOp, options))
801 return failure();
802
803 if (modifiedIR)
804 *modifiedIR = true;
805
806 // 1. Emit prologue.
807 if (failed(pipeliner.emitPrologue(rewriter)))
808 return failure();
809
810 // 2. Track values used across stages. When a value cross stages it will
811 // need to be passed as loop iteration arguments.
812 // We first collect the values that are used in a different stage than where
813 // they are defined.
814 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
815 crossStageValues = pipeliner.analyzeCrossStageValues();
816
817 // Mapping between original loop values used cross stage and the block
818 // arguments associated after pipelining. A Value may map to several
819 // arguments if its liverange spans across more than 2 stages.
820 llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
821 // 3. Create the new kernel loop and return the block arguments mapping.
822 ForOp newForOp =
823 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
824 // Create the kernel block, order ops based on user choice and remap
825 // operands.
826 if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
827 rewriter)))
828 return failure();
829
830 llvm::SmallVector<Value> returnValues =
831 newForOp.getResults().take_front(forOp->getNumResults());
832 if (options.peelEpilogue) {
833 // 4. Emit the epilogue after the new forOp.
834 rewriter.setInsertionPointAfter(newForOp);
835 if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
836 return failure();
837 }
838 // 5. Erase the original loop and replace the uses with the epilogue output.
839 if (forOp->getNumResults() > 0)
840 rewriter.replaceOp(forOp, returnValues);
841 else
842 rewriter.eraseOp(forOp);
843
844 return newForOp;
845}
846
return success()
static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter)
Create an integer or index constant.
Definition ExpandOps.cpp:27
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:65
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setOperand(unsigned idx, Value value)
Definition Operation.h:351
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options)
Populate patterns for SCF software pipelining transformation.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
Options to dictate how loops should be pipelined.
Definition Transforms.h:129
std::function< Operation *(RewriterBase &, Operation *, Value)> PredicateOpFn
Definition Transforms.h:169
std::function< void(Operation *, PipelinerPart, unsigned)> AnnotationlFnType
Lambda called by the pipeliner to allow the user to annotate the IR while it is generated.
Definition Transforms.h:146