21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/Support/DebugLog.h"
23 #include "llvm/Support/MathExtras.h"
25 #define DEBUG_TYPE "scf-loop-pipelining"
33 struct LoopPipelinerInternal {
35 struct LiverangeInfo {
36 unsigned lastUseStage = 0;
37 unsigned defStage = 0;
42 unsigned maxStage = 0;
44 std::vector<Operation *> opOrder;
61 void setValueMapping(
Value key,
Value el, int64_t idx);
66 std::pair<Operation *, int64_t> getDefiningOpAndDistance(
Value value);
70 bool verifySchedule();
81 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
82 scf::ForOp createKernelLoop(
83 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
88 LogicalResult createKernel(
90 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
91 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
99 bool LoopPipelinerInternal::initializeLoopInfo(
101 LDBG() <<
"Start initializeLoopInfo";
103 ub = forOp.getUpperBound();
104 lb = forOp.getLowerBound();
105 step = forOp.getStep();
107 std::vector<std::pair<Operation *, unsigned>> schedule;
108 options.getScheduleFn(forOp, schedule);
109 if (schedule.empty()) {
110 LDBG() <<
"--empty schedule -> BAIL";
114 opOrder.reserve(schedule.size());
115 for (
auto &opSchedule : schedule) {
116 maxStage =
std::max(maxStage, opSchedule.second);
117 stages[opSchedule.first] = opSchedule.second;
118 opOrder.push_back(opSchedule.first);
125 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
126 if (!
options.supportDynamicLoops) {
127 LDBG() <<
"--dynamic loop not supported -> BAIL";
131 int64_t ubImm = upperBoundCst.value();
132 int64_t lbImm = lowerBoundCst.value();
133 int64_t stepImm = stepCst.value();
135 LDBG() <<
"--invalid loop step -> BAIL";
138 int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
139 if (numIteration >= maxStage) {
141 }
else if (!
options.supportDynamicLoops) {
142 LDBG() <<
"--fewer loop iterations than pipeline stages -> BAIL";
146 peelEpilogue =
options.peelEpilogue;
147 predicateFn =
options.predicateFn;
148 if ((!peelEpilogue || dynamicLoop) && predicateFn ==
nullptr) {
149 LDBG() <<
"--no epilogue or predicate set -> BAIL";
154 for (
Operation &op : forOp.getBody()->without_terminator()) {
155 if (!stages.contains(&op)) {
156 op.emitOpError(
"not assigned a pipeline stage");
157 LDBG() <<
"--op not assigned a pipeline stage: " << op <<
" -> BAIL";
162 if (!verifySchedule()) {
163 LDBG() <<
"--invalid schedule: " << op <<
" -> BAIL";
170 for (
const auto &[op, stageNum] : stages) {
172 if (op == forOp.getBody()->getTerminator()) {
173 op->emitError(
"terminator should not be assigned a stage");
174 LDBG() <<
"--terminator should not be assigned stage: " << *op
178 if (op->getBlock() != forOp.getBody()) {
179 op->emitOpError(
"the owning Block of all operations assigned a stage "
180 "should be the loop body block");
181 LDBG() <<
"--the owning Block of all operations assigned a stage "
182 "should be the loop body block: "
183 << *op <<
" -> BAIL";
192 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
193 [
this](
Value operand) {
194 Operation *def = operand.getDefiningOp();
196 (!stages.contains(def) && forOp->isAncestor(def));
198 LDBG() <<
"--only support loop carried dependency with a distance of 1 or "
199 "defined outside of the loop -> BAIL";
202 annotateFn =
options.annotateFn;
218 bool LoopPipelinerInternal::verifySchedule() {
219 int64_t numCylesPerIter = opOrder.size();
222 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
224 auto it = stages.find(def);
225 assert(it != stages.end());
226 int64_t stage = it->second;
227 unrolledCyles[def] = cycle + stage * numCylesPerIter;
230 int64_t consumerCycle = unrolledCyles[consumer];
231 for (
Value operand : getNestedOperands(consumer)) {
232 auto [producer, distance] = getDefiningOpAndDistance(operand);
235 auto it = unrolledCyles.find(producer);
237 if (it == unrolledCyles.end())
239 int64_t producerCycle = it->second;
240 if (consumerCycle < producerCycle - numCylesPerIter * distance) {
241 consumer->emitError(
"operation scheduled before its operands");
259 for (
OpOperand &operand : nested->getOpOperands()) {
260 Operation *def = operand.get().getDefiningOp();
261 if ((def && !
clone->
isAncestor(def)) || isa<BlockArgument>(operand.get()))
268 LogicalResult LoopPipelinerInternal::emitPrologue(
RewriterBase &rewriter) {
270 for (
auto [arg, operand] :
271 llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
272 setValueMapping(arg, operand.get(), 0);
274 auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
277 for (int64_t i = 0; i < maxStage; i++) {
279 Type t = ub.getType();
281 Value iv = arith::AddIOp::create(
283 arith::MulIOp::create(
285 arith::ConstantOp::create(rewriter, loc,
287 predicates[i] = arith::CmpIOp::create(rewriter, loc,
288 arith::CmpIPredicate::slt, iv, ub);
293 Type t = lb.getType();
294 Value iv = arith::AddIOp::create(
296 arith::MulIOp::create(
298 arith::ConstantOp::create(rewriter, loc,
300 setValueMapping(forOp.getInductionVar(), iv, i);
305 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
306 auto it = valueMapping.find(newOperand->
get());
307 if (it != valueMapping.end()) {
308 Value replacement = it->second[i - stages[op]];
309 newOperand->set(replacement);
312 int predicateIdx = i - stages[op];
313 if (predicates[predicateIdx]) {
315 newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
316 if (newOp ==
nullptr)
321 for (
unsigned destId : llvm::seq(
unsigned(0), op->getNumResults())) {
324 for (
OpOperand &operand : yield->getOpOperands()) {
325 if (operand.get() != op->getResult(destId))
327 if (predicates[predicateIdx] &&
328 !forOp.getResult(operand.getOperandNumber()).use_empty()) {
331 Value prevValue = valueMapping
332 [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
334 source = arith::SelectOp::create(
335 rewriter, loc, predicates[predicateIdx], source, prevValue);
337 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
338 source, i - stages[op] + 1);
340 setValueMapping(op->getResult(destId), newOp->
getResult(destId),
348 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
349 LoopPipelinerInternal::analyzeCrossStageValues() {
350 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
352 unsigned stage = stages[op];
354 auto analyzeOperand = [&](
OpOperand &operand) {
355 auto [def, distance] = getDefiningOpAndDistance(operand.get());
358 auto defStage = stages.find(def);
359 if (defStage == stages.end() || defStage->second == stage ||
360 defStage->second == stage + distance)
362 assert(stage > defStage->second);
363 LiverangeInfo &info = crossStageValues[operand.get()];
364 info.defStage = defStage->second;
365 info.lastUseStage =
std::max(info.lastUseStage, stage);
368 for (
OpOperand &operand : op->getOpOperands())
369 analyzeOperand(operand);
371 analyzeOperand(*operand);
374 return crossStageValues;
377 std::pair<Operation *, int64_t>
378 LoopPipelinerInternal::getDefiningOpAndDistance(
Value value) {
379 int64_t distance = 0;
380 if (
auto arg = dyn_cast<BlockArgument>(value)) {
381 if (arg.getOwner() != forOp.getBody())
384 if (arg.getArgNumber() == 0)
388 forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
393 return {def, distance};
396 scf::ForOp LoopPipelinerInternal::createKernelLoop(
397 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
400 llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap) {
408 for (
const auto &retVal :
410 Operation *def = retVal.value().getDefiningOp();
411 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
413 auto defStage = stages.find(def);
414 if (defStage != stages.end()) {
416 valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
417 [maxStage - defStage->second];
418 assert(valueVersion);
419 newLoopArg.push_back(valueVersion);
421 newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
424 for (
auto escape : crossStageValues) {
425 LiverangeInfo &info = escape.second;
426 Value value = escape.first;
427 for (
unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
430 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
431 assert(valueVersion);
432 newLoopArg.push_back(valueVersion);
433 loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
434 stageIdx)] = newLoopArg.size() - 1;
441 Value newUb = forOp.getUpperBound();
443 Type t = ub.getType();
446 Value maxStageValue = arith::ConstantOp::create(
448 Value maxStageByStep =
449 arith::MulIOp::create(rewriter, loc, step, maxStageValue);
450 newUb = arith::SubIOp::create(rewriter, loc, ub, maxStageByStep);
453 scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), newUb,
454 forOp.getStep(), newLoopArg);
457 if (!newForOp.getBody()->empty())
458 rewriter.
eraseOp(newForOp.getBody()->getTerminator());
462 LogicalResult LoopPipelinerInternal::createKernel(
464 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
466 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
468 valueMapping.clear();
474 mapping.
map(forOp.getInductionVar(), newForOp.getInductionVar());
476 mapping.
map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
482 Type t = ub.getType();
483 for (
unsigned i = 0; i < maxStage; i++) {
485 Value c = arith::SubIOp::create(
487 arith::MulIOp::create(
489 arith::ConstantOp::create(
493 Value pred = arith::CmpIOp::create(rewriter, newForOp.getLoc(),
494 arith::CmpIPredicate::slt,
495 newForOp.getInductionVar(), c);
496 predicates[i] = pred;
500 int64_t useStage = stages[op];
501 auto *newOp = rewriter.
clone(*op, mapping);
504 op->walk([&operands](
Operation *nestedOp) {
506 operands.push_back(&operand);
513 if (operand->get() == forOp.getInductionVar()) {
517 Type t = step.getType();
518 Value offset = arith::MulIOp::create(
519 rewriter, forOp.getLoc(), step,
520 arith::ConstantOp::create(
521 rewriter, forOp.getLoc(),
523 Value iv = arith::AddIOp::create(rewriter, forOp.getLoc(),
524 newForOp.getInductionVar(), offset);
525 nestedNewOp->
setOperand(operand->getOperandNumber(), iv);
529 Value source = operand->get();
530 auto arg = dyn_cast<BlockArgument>(source);
531 if (arg && arg.getOwner() == forOp.getBody()) {
532 Value ret = forOp.getBody()->getTerminator()->getOperand(
533 arg.getArgNumber() - 1);
537 auto stageDep = stages.find(dep);
538 if (stageDep == stages.end() || stageDep->second == useStage)
542 if (stageDep->second == useStage + 1) {
543 nestedNewOp->
setOperand(operand->getOperandNumber(),
555 auto stageDef = stages.find(def);
556 if (stageDef == stages.end() || stageDef->second == useStage)
558 auto remap = loopArgMap.find(
559 std::make_pair(operand->get(), useStage - stageDef->second));
560 assert(remap != loopArgMap.end());
561 nestedNewOp->
setOperand(operand->getOperandNumber(),
562 newForOp.getRegionIterArgs()[remap->second]);
565 if (predicates[useStage]) {
567 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
571 for (
auto values : llvm::zip(op->getResults(), newOp->getResults()))
572 mapping.
map(std::get<0>(values), std::get<1>(values));
585 forOp.getBody()->getTerminator()->getOpOperands()) {
591 !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
592 Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
594 auto defStage = stages.find(def);
595 if (defStage != stages.end() && defStage->second < maxStage) {
596 Value pred = predicates[defStage->second];
597 source = arith::SelectOp::create(
598 rewriter, pred.
getLoc(), pred, source,
600 ->getArguments()[yieldOperand.getOperandNumber() + 1]);
604 yieldOperands.push_back(source);
607 for (
auto &it : crossStageValues) {
608 int64_t version = maxStage - it.second.lastUseStage + 1;
609 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
613 for (
unsigned i = 1; i < numVersionReturned; i++) {
614 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
616 yieldOperands.push_back(
617 newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
618 newForOp.getNumInductionVars()]);
620 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
625 for (
const auto &retVal :
627 Operation *def = retVal.value().getDefiningOp();
628 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
629 "defined outside the loop");
630 auto defStage = stages.find(def);
631 if (defStage == stages.end()) {
632 for (
unsigned int stage = 1; stage <= maxStage; stage++)
633 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
634 retVal.value(), stage);
635 }
else if (defStage->second > 0) {
636 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
637 newForOp->getResult(retVal.index()),
638 maxStage - defStage->second + 1);
641 scf::YieldOp::create(rewriter, forOp.getLoc(), yieldOperands);
646 LoopPipelinerInternal::emitEpilogue(
RewriterBase &rewriter,
649 Type t = lb.getType();
655 return arith::ConstantOp::create(rewriter, loc,
664 Value stepLessZero = arith::CmpIOp::create(
665 rewriter, loc, arith::CmpIPredicate::slt, step, zero);
666 Value stepDecr = arith::SelectOp::create(rewriter, loc, stepLessZero, one,
669 Value rangeDiff = arith::SubIOp::create(rewriter, loc, ub, lb);
670 Value rangeIncrStep = arith::AddIOp::create(rewriter, loc, rangeDiff, step);
672 arith::AddIOp::create(rewriter, loc, rangeIncrStep, stepDecr);
673 Value totalIterations =
674 arith::DivSIOp::create(rewriter, loc, rangeDecr, step);
679 Value iterI = arith::SubIOp::create(rewriter, loc, totalIterations,
681 iterI = arith::MaxSIOp::create(rewriter, loc, zero, iterI);
686 for (int64_t i = 1; i <= maxStage; i++) {
688 Value newlastIter = arith::AddIOp::create(
689 rewriter, loc, lb, arith::MulIOp::create(rewriter, loc, step, iterI));
691 setValueMapping(forOp.getInductionVar(), newlastIter, i);
694 iterI = arith::AddIOp::create(rewriter, loc, iterI, one);
700 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
707 for (int64_t i = 1; i <= maxStage; i++) {
712 unsigned currentVersion = maxStage - stages[op] + i;
713 unsigned nextVersion = currentVersion + 1;
715 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
716 auto it = valueMapping.find(newOperand->
get());
717 if (it != valueMapping.end()) {
718 Value replacement = it->second[currentVersion];
719 newOperand->set(replacement);
724 newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
731 for (
auto [opRes, newRes] :
732 llvm::zip(op->getResults(), newOp->
getResults())) {
733 setValueMapping(opRes, newRes, currentVersion);
738 forOp.getBody()->getTerminator()->getOpOperands()) {
739 if (operand.get() != opRes)
743 unsigned ri = operand.getOperandNumber();
744 returnValues[ri] = newRes;
745 Value mapVal = forOp.getRegionIterArgs()[ri];
746 returnMap[ri] = std::make_pair(mapVal, currentVersion);
747 if (nextVersion <= maxStage)
748 setValueMapping(mapVal, newRes, nextVersion);
757 unsigned ri = pair.index();
758 auto [mapVal, currentVersion] = returnMap[ri];
760 unsigned nextVersion = currentVersion + 1;
761 Value pred = predicates[currentVersion];
762 Value prevValue = valueMapping[mapVal][currentVersion];
763 auto selOp = arith::SelectOp::create(rewriter, loc, pred,
764 pair.value(), prevValue);
765 returnValues[ri] = selOp;
766 if (nextVersion <= maxStage)
767 setValueMapping(mapVal, selOp, nextVersion);
775 void LoopPipelinerInternal::setValueMapping(
Value key,
Value el, int64_t idx) {
776 auto it = valueMapping.find(key);
779 if (it == valueMapping.end())
784 it->second[idx] = el;
796 if (forOp.getUnsignedCmp())
799 LoopPipelinerInternal pipeliner;
800 if (!pipeliner.initializeLoopInfo(forOp,
options))
807 if (
failed(pipeliner.emitPrologue(rewriter)))
814 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
815 crossStageValues = pipeliner.analyzeCrossStageValues();
823 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
826 if (
failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
831 newForOp.getResults().take_front(forOp->getNumResults());
835 if (
failed(pipeliner.emitEpilogue(rewriter, returnValues)))
839 if (forOp->getNumResults() > 0)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options)
Populate patterns for SCF software pipelining transformation.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Options to dictate how loops should be pipelined.
std::function< void(Operation *, PipelinerPart, unsigned)> AnnotationlFnType
Lambda called by the pipeliner to allow the user to annotate the IR while it is generated.
std::function< Operation *(RewriterBase &, Operation *, Value)> PredicateOpFn