21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/MathExtras.h"
// Debug-logging helpers for this file. DEBUG_TYPE names the debug channel
// ("scf-loop-pipelining"); DBGS() starts a line on llvm::dbgs() prefixed with
// "[scf-loop-pipelining]: "; LDBG(X) streams X plus a newline through
// LLVM_DEBUG, so it follows LLVM's usual debug-output gating.
25 #define DEBUG_TYPE "scf-loop-pipelining"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
// Internal state and helpers used to pipeline a single scf.for loop.
// NOTE: only part of this declaration is visible here; other members and the
// closing braces are on lines not shown.
35 struct LoopPipelinerInternal {
// Live range of a value used in a later pipeline stage than the stage that
// defines it (filled in by analyzeCrossStageValues).
37 struct LiverangeInfo {
38 unsigned lastUseStage = 0;
39 unsigned defStage = 0;
// Largest stage number in the schedule; computed in initializeLoopInfo from
// the user-provided schedule.
44 unsigned maxStage = 0;
// Scheduled operations, in the order the schedule lists them.
46 std::vector<Operation *> opOrder;
// Records `el` as version `idx` of `key` in valueMapping.
63 void setValueMapping(
Value key,
Value el, int64_t idx);
// Returns the operation that defines `value` and a loop-carried distance;
// loop-body block arguments are traced through the loop's yield operands.
68 std::pair<Operation *, int64_t> getDefiningOpAndDistance(
Value value);
// Reports an error if an operation is scheduled before its operands.
72 bool verifySchedule();
// Collects values that are used in a later stage than they are defined in.
83 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
84 scf::ForOp createKernelLoop(
85 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
90 LogicalResult createKernel(
92 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
93 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
// Checks whether the loop can be pipelined with the given options, and caches
// the bounds, step, schedule (stages/opOrder/maxStage) and user callbacks.
// Every rejection path logs a "-> BAIL" reason through LDBG.
101 bool LoopPipelinerInternal::initializeLoopInfo(
103 LDBG(
"Start initializeLoopInfo");
105 ub = forOp.getUpperBound();
106 lb = forOp.getLowerBound();
107 step = forOp.getStep();
// Non-constant bounds or step are only accepted when the options allow
// dynamic loops.
113 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
114 if (!
options.supportDynamicLoops) {
115 LDBG(
"--dynamic loop not supported -> BAIL");
119 int64_t ubImm = upperBoundCst.value();
120 int64_t lbImm = lowerBoundCst.value();
121 int64_t stepImm = stepCst.value();
123 LDBG(
"--invalid loop step -> BAIL");
// Constant trip count, rounded up.
// NOTE(review): maxStage appears to still hold its in-class default (0) here,
// because it is only assigned from the schedule further below. Confirm that
// this comparison is meant to use the final maxStage.
126 int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
127 if (numIteration > maxStage) {
129 }
else if (!
options.supportDynamicLoops) {
130 LDBG(
"--fewer loop iterations than pipeline stages -> BAIL");
134 peelEpilogue =
options.peelEpilogue;
135 predicateFn =
options.predicateFn;
// If the epilogue is not peeled, or the loop is dynamic, a predicate
// function is required.
136 if ((!peelEpilogue || dynamicLoop) && predicateFn ==
nullptr) {
137 LDBG(
"--no epilogue or predicate set -> BAIL");
// Ask the user callback for the schedule as (operation, stage) pairs.
140 std::vector<std::pair<Operation *, unsigned>> schedule;
141 options.getScheduleFn(forOp, schedule);
142 if (schedule.empty()) {
143 LDBG(
"--empty schedule -> BAIL");
// Record each op's stage and its order, and track the largest stage.
147 opOrder.reserve(schedule.size());
148 for (
auto &opSchedule : schedule) {
149 maxStage =
std::max(maxStage, opSchedule.second);
150 stages[opSchedule.first] = opSchedule.second;
151 opOrder.push_back(opSchedule.first);
// Every non-terminator op in the loop body must have been assigned a stage.
155 for (
Operation &op : forOp.getBody()->without_terminator()) {
156 if (!stages.contains(&op)) {
157 op.emitOpError(
"not assigned a pipeline stage");
158 LDBG(
"--op not assigned a pipeline stage: " << op <<
" -> BAIL");
// NOTE(review): `op` in the log below cannot be the loop variable above,
// whose scope has ended. It presumably names something declared on a line
// not shown (e.g. a parameter), so confirm what it refers to.
163 if (!verifySchedule()) {
164 LDBG(
"--invalid schedule: " << op <<
" -> BAIL");
// Stages may only be assigned to ops in the loop body block, and never to
// its terminator.
171 for (
const auto &[op, stageNum] : stages) {
173 if (op == forOp.getBody()->getTerminator()) {
174 op->emitError(
"terminator should not be assigned a stage");
175 LDBG(
"--terminator should not be assigned stage: " << *op <<
" -> BAIL");
178 if (op->getBlock() != forOp.getBody()) {
179 op->emitOpError(
"the owning Block of all operations assigned a stage "
180 "should be the loop body block");
181 LDBG(
"--the owning Block of all operations assigned a stage "
182 "should be the loop body block: "
183 << *op <<
" -> BAIL");
// Inspect each yielded operand's defining op. The visible part of the
// predicate flags defs nested in the loop that have no assigned stage (the
// first part of the condition is on a line not shown).
192 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
193 [
this](
Value operand) {
194 Operation *def = operand.getDefiningOp();
196 (!stages.contains(def) && forOp->isAncestor(def));
198 LDBG(
"--only support loop carried dependency with a distance of 1 or "
199 "defined outside of the loop -> BAIL");
202 annotateFn =
options.annotateFn;
// Verifies that no operation is scheduled before its operands are available.
// Each op gets an "unrolled cycle", computed as its cycle index plus
// stage * numCylesPerIter. A consumer is rejected if its cycle is earlier
// than the producer's cycle minus numCylesPerIter * distance.
218 bool LoopPipelinerInternal::verifySchedule() {
219 int64_t numCylesPerIter = opOrder.size();
// `def` is fetched on a line not shown (presumably the op at position
// `cycle` in opOrder).
222 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
224 auto it = stages.find(def);
225 assert(it != stages.end());
226 int64_t stage = it->second;
227 unrolledCyles[def] = cycle + stage * numCylesPerIter;
230 int64_t consumerCycle = unrolledCyles[consumer];
// Operands of ops nested inside `consumer` are checked as well.
231 for (
Value operand : getNestedOperands(consumer)) {
232 auto [producer, distance] = getDefiningOpAndDistance(operand);
235 auto it = unrolledCyles.find(producer);
// Producers with no recorded cycle are handled by the branch below (its
// body is not shown; presumably they are skipped).
237 if (it == unrolledCyles.end())
239 int64_t producerCycle = it->second;
// A producer `distance` iterations back is available
// numCylesPerIter * distance cycles earlier.
240 if (consumerCycle < producerCycle - numCylesPerIter * distance) {
241 consumer->emitError(
"operation scheduled before its operands");
// Fragment: the enclosing definition's header is on lines not shown.
// For each operand of a nested op, the condition below selects operands that
// are block arguments, or whose defining op is not nested inside `clone`.
259 for (
OpOperand &operand : nested->getOpOperands()) {
260 Operation *def = operand.get().getDefiningOp();
261 if ((def && !
clone->
isAncestor(def)) || isa<BlockArgument>(operand.get()))
// Emits the pipeline prologue. First, region iter_args are mapped (as
// version 0) to the loop's init values. Then, for each prologue iteration i
// in [0, maxStage), ops are cloned with their operands remapped to version
// `i - stages[op]` from valueMapping, and predicated when a predicate exists.
268 LogicalResult LoopPipelinerInternal::emitPrologue(
RewriterBase &rewriter) {
270 for (
auto [arg, operand] :
271 llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
272 setValueMapping(arg, operand.get(), 0);
274 auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
277 for (int64_t i = 0; i < maxStage; i++) {
// predicates[i]: a signed `iv < ub` check, where iv is built from a
// multiply and constant (operands partly on lines not shown). The guarding
// condition for this section is not visible.
279 Type t = ub.getType();
283 rewriter.
create<arith::MulIOp>(
285 rewriter.
create<arith::ConstantOp>(
287 predicates[i] = rewriter.
create<arith::CmpIOp>(
288 loc, arith::CmpIPredicate::slt, iv, ub);
// Map the original induction variable to its value for iteration i.
293 Type t = lb.getType();
296 rewriter.
create<arith::MulIOp>(
298 rewriter.
create<arith::ConstantOp>(loc,
300 setValueMapping(forOp.getInductionVar(), iv, i);
// Clone the op; every operand that has a valueMapping entry is replaced by
// its version for stage offset i - stages[op].
305 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
306 auto it = valueMapping.find(newOperand->
get());
307 if (it != valueMapping.end()) {
308 Value replacement = it->second[i - stages[op]];
309 newOperand->set(replacement);
// If a predicate exists for this stage offset, predicateFn may rewrite the
// clone; a null result is handled by the check below.
312 int predicateIdx = i - stages[op];
313 if (predicates[predicateIdx]) {
315 newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
316 if (newOp ==
nullptr)
// For each result that feeds the original yield, record it as the next
// version of the corresponding iter_arg. When the op is predicated and the
// loop result has uses, a select keeps a previously recorded version if the
// predicate is false.
321 for (
unsigned destId : llvm::seq(
unsigned(0), op->getNumResults())) {
324 for (
OpOperand &operand : yield->getOpOperands()) {
325 if (operand.get() != op->getResult(destId))
327 if (predicates[predicateIdx] &&
328 !forOp.getResult(operand.getOperandNumber()).use_empty()) {
331 Value prevValue = valueMapping
332 [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
334 source = rewriter.
create<arith::SelectOp>(
335 loc, predicates[predicateIdx], source, prevValue);
337 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
338 source, i - stages[op] + 1);
340 setValueMapping(op->getResult(destId), newOp->
getResult(destId),
// Finds values that are defined in one stage and used in a later stage.
// For each such value it records the defining stage and the latest stage
// that uses it.
348 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
349 LoopPipelinerInternal::analyzeCrossStageValues() {
350 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
352 unsigned stage = stages[op];
// Examines a single operand of an op scheduled in `stage`.
354 auto analyzeOperand = [&](
OpOperand &operand) {
355 auto [def, distance] = getDefiningOpAndDistance(operand.get());
358 auto defStage = stages.find(def);
// Filters out defs with no stage, same-stage uses, and defs whose stage
// equals stage + distance. The branch body is not shown (presumably an
// early return).
359 if (defStage == stages.end() || defStage->second == stage ||
360 defStage->second == stage + distance)
362 assert(stage > defStage->second);
363 LiverangeInfo &info = crossStageValues[operand.get()];
364 info.defStage = defStage->second;
365 info.lastUseStage =
std::max(info.lastUseStage, stage);
// The op's own operands...
368 for (
OpOperand &operand : op->getOpOperands())
369 analyzeOperand(operand);
// ...and further operands collected on lines not shown.
371 analyzeOperand(*operand);
374 return crossStageValues;
// Returns the op that defines `value` and a loop-carried distance, which
// starts at 0. For a block argument owned by the loop body, argument 0 is
// special-cased (branch not shown). Any other argument is traced to the
// terminator operand at index argNumber - 1.
377 std::pair<Operation *, int64_t>
378 LoopPipelinerInternal::getDefiningOpAndDistance(
Value value) {
379 int64_t distance = 0;
380 if (
auto arg = dyn_cast<BlockArgument>(value)) {
381 if (arg.getOwner() != forOp.getBody())
384 if (arg.getArgNumber() == 0)
388 forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
393 return {def, distance};
// Creates the kernel scf.for and its initial iter_args:
//  - one per original iter_arg. If the yielded value's defining op has a
//    stage, this is version `maxStage - defStage` from valueMapping;
//    otherwise it is the original init value;
//  - then one per extra version of each cross-stage value, with loopArgMap
//    recording the iter_arg index for each (value, distance) pair.
// The auto-created terminator is erased; createKernel emits the yield later.
396 scf::ForOp LoopPipelinerInternal::createKernelLoop(
397 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
400 llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap) {
408 for (
const auto &retVal :
410 Operation *def = retVal.value().getDefiningOp();
411 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
413 auto defStage = stages.find(def);
414 if (defStage != stages.end()) {
416 valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
417 [maxStage - defStage->second];
418 assert(valueVersion);
419 newLoopArg.push_back(valueVersion);
421 newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
// Extra iter_args carry each cross-stage value across stage boundaries.
424 for (
auto escape : crossStageValues) {
425 LiverangeInfo &info = escape.second;
426 Value value = escape.first;
427 for (
unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
430 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
431 assert(valueVersion);
432 newLoopArg.push_back(valueVersion);
433 loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
434 stageIdx)] = newLoopArg.size() - 1;
// newUb defaults to the original upper bound. In the branch shown (its
// condition is on lines not visible) it becomes ub - step * maxStage.
441 Value newUb = forOp.getUpperBound();
443 Type t = ub.getType();
446 Value maxStageValue = rewriter.
create<arith::ConstantOp>(
448 Value maxStageByStep =
449 rewriter.
create<arith::MulIOp>(loc, step, maxStageValue);
450 newUb = rewriter.
create<arith::SubIOp>(loc, ub, maxStageByStep);
453 rewriter.
create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
454 forOp.getStep(), newLoopArg);
457 if (!newForOp.getBody()->empty())
458 rewriter.
eraseOp(newForOp.getBody()->getTerminator());
// Fills the kernel loop body. It clones each scheduled op with operands
// remapped into the kernel (induction variable, iter_args, and values
// carried across stages via loopArgMap), predicates ops where required, and
// emits the kernel's scf.yield. valueMapping is rebuilt with versions that
// refer to the kernel loop's results.
462 LogicalResult LoopPipelinerInternal::createKernel(
464 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
466 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
// Prologue versions are discarded; versions recorded below reference
// newForOp's results.
468 valueMapping.clear();
// Map the original induction variable and iter_args to the kernel's.
474 mapping.
map(forOp.getInductionVar(), newForOp.getInductionVar());
476 mapping.
map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
// predicates[i] for i < maxStage: a signed compare of the kernel induction
// variable against `c` (c's construction is partly on lines not shown).
482 Type t = ub.getType();
483 for (
unsigned i = 0; i < maxStage; i++) {
487 rewriter.
create<arith::MulIOp>(
489 rewriter.
create<arith::ConstantOp>(
493 newForOp.getLoc(), arith::CmpIPredicate::slt,
494 newForOp.getInductionVar(), c);
495 predicates[i] = pred;
// Clone the op into the kernel through the IR mapping.
499 int64_t useStage = stages[op];
500 auto *newOp = rewriter.
clone(*op, mapping);
// Gather operands of `op` and its nested ops for individual rewriting.
503 op->walk([&operands](
Operation *nestedOp) {
505 operands.push_back(&operand);
// Uses of the original induction variable become the kernel induction
// variable plus an offset computed from `step` and a constant.
512 if (operand->get() == forOp.getInductionVar()) {
516 Type t = step.getType();
518 forOp.getLoc(), step,
519 rewriter.
create<arith::ConstantOp>(
523 forOp.getLoc(), newForOp.getInductionVar(), offset);
524 nestedNewOp->
setOperand(operand->getOperandNumber(), iv);
// Loop-body block arguments (iter_args) are traced to the yielded value
// that feeds them. Same-stage or unscheduled producers are handled by the
// `if` below; a producer in stage useStage + 1 has the operand reset.
528 Value source = operand->get();
529 auto arg = dyn_cast<BlockArgument>(source);
530 if (arg && arg.getOwner() == forOp.getBody()) {
531 Value ret = forOp.getBody()->getTerminator()->getOperand(
532 arg.getArgNumber() - 1);
536 auto stageDep = stages.find(dep);
537 if (stageDep == stages.end() || stageDep->second == useStage)
541 if (stageDep->second == useStage + 1) {
542 nestedNewOp->
setOperand(operand->getOperandNumber(),
// A value defined in an earlier stage is read from the kernel iter_arg
// registered in loopArgMap for distance useStage - defStage.
554 auto stageDef = stages.find(def);
555 if (stageDef == stages.end() || stageDef->second == useStage)
557 auto remap = loopArgMap.find(
558 std::make_pair(operand->get(), useStage - stageDef->second));
559 assert(remap != loopArgMap.end());
560 nestedNewOp->
setOperand(operand->getOperandNumber(),
561 newForOp.getRegionIterArgs()[remap->second]);
// Ops in a stage that has a predicate are passed to predicateFn.
564 if (predicates[useStage]) {
566 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
// Later clones see this op's results through the mapping.
570 for (
auto values : llvm::zip(op->getResults(), newOp->getResults()))
571 mapping.
map(std::get<0>(values), std::get<1>(values));
// Build yield operands for the original iter_args (the loop header starts
// on a line not shown). When the loop result is used and the def's stage
// has a predicate, a select keeps the incoming kernel argument if the
// predicate is false.
584 forOp.getBody()->getTerminator()->getOpOperands()) {
590 !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
591 Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
593 auto defStage = stages.find(def);
594 if (defStage != stages.end() && defStage->second < maxStage) {
595 Value pred = predicates[defStage->second];
596 source = rewriter.
create<arith::SelectOp>(
597 pred.
getLoc(), pred, source,
599 ->getArguments()[yieldOperand.getOperandNumber() + 1]);
603 yieldOperands.push_back(source);
// Cross-stage values: yield the extra versions carried by the kernel, and
// record the corresponding kernel results in valueMapping.
606 for (
auto &it : crossStageValues) {
607 int64_t version = maxStage - it.second.lastUseStage + 1;
608 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
612 for (
unsigned i = 1; i < numVersionReturned; i++) {
613 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
615 yieldOperands.push_back(
616 newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
617 newForOp.getNumInductionVars()]);
619 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
// Record iter_arg versions after the kernel. If the yielded value is
// unscheduled, every stage 1..maxStage maps to the value itself. If it is
// defined in stage > 0, version maxStage - defStage + 1 maps to the kernel
// result.
624 for (
const auto &retVal :
626 Operation *def = retVal.value().getDefiningOp();
627 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
628 "defined outside the loop");
629 auto defStage = stages.find(def);
630 if (defStage == stages.end()) {
631 for (
unsigned int stage = 1; stage <= maxStage; stage++)
632 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
633 retVal.value(), stage);
634 }
else if (defStage->second > 0) {
635 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
636 newForOp->getResult(retVal.index()),
637 maxStage - defStage->second + 1);
640 rewriter.
create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
// Emits the peeled epilogue after the kernel loop. It recomputes the
// induction value of the remaining iterations, builds predicates from the
// total trip count, clones the remaining ops with operands remapped to the
// right versions, and fills `returnValues` for the original loop's results.
645 LoopPipelinerInternal::emitEpilogue(
RewriterBase &rewriter,
648 Type t = lb.getType();
// Part of a helper that materializes integer constants of type `t`
// (createConst, used below).
654 return rewriter.
create<arith::ConstantOp>(loc,
// totalIterations = (ub - lb + step + stepDecr) / step, using signed
// division. stepDecr depends on stepLessZero and is computed on lines not
// shown.
663 Value stepLessZero = rewriter.
create<arith::CmpIOp>(
664 loc, arith::CmpIPredicate::slt, step, zero);
668 Value rangeDiff = rewriter.
create<arith::SubIOp>(loc, ub, lb);
669 Value rangeIncrStep = rewriter.
create<arith::AddIOp>(loc, rangeDiff, step);
671 rewriter.
create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
672 Value totalIterations = rewriter.
create<arith::DivSIOp>(loc, rangeDecr, step);
// First epilogue iteration index: totalIterations minus an amount on a line
// not shown, clamped to be >= 0.
677 Value iterI = rewriter.
create<arith::SubIOp>(loc, totalIterations,
679 iterI = rewriter.
create<arith::MaxSIOp>(loc, zero, iterI);
// Version i of the induction variable is lb + step * iterI, and iterI
// advances by one each time.
684 for (int64_t i = 1; i <= maxStage; i++) {
686 Value newlastIter = rewriter.
create<arith::AddIOp>(
687 loc, lb, rewriter.
create<arith::MulIOp>(loc, step, iterI));
689 setValueMapping(forOp.getInductionVar(), newlastIter, i);
692 iterI = rewriter.
create<arith::AddIOp>(loc, iterI, one);
// predicates[i] holds when totalIterations >= i (signed). The enclosing loop
// is on lines not shown.
697 predicates[i] = rewriter.
create<arith::CmpIOp>(
698 loc, arith::CmpIPredicate::sge, totalIterations,
createConst(i));
704 for (int64_t i = 1; i <= maxStage; i++) {
// Version read by this clone; its results become the next iter_arg version.
709 unsigned currentVersion = maxStage - stages[op] + i;
710 unsigned nextVersion = currentVersion + 1;
712 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
713 auto it = valueMapping.find(newOperand->
get());
714 if (it != valueMapping.end()) {
715 Value replacement = it->second[currentVersion];
716 newOperand->set(replacement);
// Predication of the clone (the guarding condition is on lines not shown).
721 newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
728 for (
auto [opRes, newRes] :
729 llvm::zip(op->getResults(), newOp->
getResults())) {
730 setValueMapping(opRes, newRes, currentVersion);
// Results that feed the original yield become returnValues. returnMap
// remembers which iter_arg/version they correspond to, and the next version
// is recorded while it is still <= maxStage.
735 forOp.getBody()->getTerminator()->getOpOperands()) {
736 if (operand.get() != opRes)
740 unsigned ri = operand.getOperandNumber();
741 returnValues[ri] = newRes;
742 Value mapVal = forOp.getRegionIterArgs()[ri];
743 returnMap[ri] = std::make_pair(mapVal, currentVersion);
744 if (nextVersion <= maxStage)
745 setValueMapping(mapVal, newRes, nextVersion);
// Guard each return value with its predicate: when the predicate is false,
// select the iter_arg's version currentVersion instead.
754 unsigned ri = pair.index();
755 auto [mapVal, currentVersion] = returnMap[ri];
757 unsigned nextVersion = currentVersion + 1;
758 Value pred = predicates[currentVersion];
759 Value prevValue = valueMapping[mapVal][currentVersion];
760 auto selOp = rewriter.
create<arith::SelectOp>(loc, pred, pair.value(),
762 returnValues[ri] = selOp;
763 if (nextVersion <= maxStage)
764 setValueMapping(mapVal, selOp, nextVersion);
// Stores `el` as version `idx` of `key` in valueMapping. A missing entry for
// `key` is handled by the branch below, whose body is on lines not shown
// (presumably it creates the entry).
772 void LoopPipelinerInternal::setValueMapping(
Value key,
Value el, int64_t idx) {
773 auto it = valueMapping.find(key);
776 if (it == valueMapping.end())
781 it->second[idx] = el;
// Fragment of the pipelining driver (the enclosing definition's header is on
// lines not shown). The steps run in order:
//  1. validate the loop and schedule (initializeLoopInfo);
//  2. emit the prologue;
//  3. analyze cross-stage values;
//  4. create the kernel loop and fill its body;
//  5. emit the epilogue.
791 LoopPipelinerInternal pipeliner;
792 if (!pipeliner.initializeLoopInfo(forOp,
options))
799 if (failed(pipeliner.emitPrologue(rewriter)))
// Cross-stage analysis feeds both kernel construction steps.
806 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
807 crossStageValues = pipeliner.analyzeCrossStageValues();
815 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
818 if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
823 newForOp.getResults().take_front(forOp->getNumResults());
827 if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
831 if (forOp->getNumResults() > 0)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options)
Populate patterns for SCF software pipelining transformation.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Options to dictate how loops should be pipelined.
std::function< void(Operation *, PipelinerPart, unsigned)> AnnotationlFnType
Lambda called by the pipeliner to allow the user to annotate the IR while it is generated.
std::function< Operation *(RewriterBase &, Operation *, Value)> PredicateOpFn