21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/MathExtras.h"
25 #define DEBUG_TYPE "scf-loop-pipelining"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
35 struct LoopPipelinerInternal {
37 struct LiverangeInfo {
38 unsigned lastUseStage = 0;
39 unsigned defStage = 0;
44 unsigned maxStage = 0;
46 std::vector<Operation *> opOrder;
63 void setValueMapping(
Value key,
Value el, int64_t idx);
68 std::pair<Operation *, int64_t> getDefiningOpAndDistance(
Value value);
72 bool verifySchedule();
83 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
84 scf::ForOp createKernelLoop(
85 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
90 LogicalResult createKernel(
92 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
93 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
101 bool LoopPipelinerInternal::initializeLoopInfo(
103 LDBG(
"Start initializeLoopInfo");
105 ub = forOp.getUpperBound();
106 lb = forOp.getLowerBound();
107 step = forOp.getStep();
113 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
114 if (!
options.supportDynamicLoops) {
115 LDBG(
"--dynamic loop not supported -> BAIL");
119 int64_t ubImm = upperBoundCst.value();
120 int64_t lbImm = lowerBoundCst.value();
121 int64_t stepImm = stepCst.value();
123 LDBG(
"--invalid loop step -> BAIL");
126 int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
127 if (numIteration > maxStage) {
129 }
else if (!
options.supportDynamicLoops) {
130 LDBG(
"--fewer loop iterations than pipeline stages -> BAIL");
134 peelEpilogue =
options.peelEpilogue;
135 predicateFn =
options.predicateFn;
136 if ((!peelEpilogue || dynamicLoop) && predicateFn ==
nullptr) {
137 LDBG(
"--no epilogue or predicate set -> BAIL");
140 std::vector<std::pair<Operation *, unsigned>> schedule;
141 options.getScheduleFn(forOp, schedule);
142 if (schedule.empty()) {
143 LDBG(
"--empty schedule -> BAIL");
147 opOrder.reserve(schedule.size());
148 for (
auto &opSchedule : schedule) {
149 maxStage =
std::max(maxStage, opSchedule.second);
150 stages[opSchedule.first] = opSchedule.second;
151 opOrder.push_back(opSchedule.first);
155 for (
Operation &op : forOp.getBody()->without_terminator()) {
156 if (!stages.contains(&op)) {
157 op.emitOpError(
"not assigned a pipeline stage");
158 LDBG(
"--op not assigned a pipeline stage: " << op <<
" -> BAIL");
163 if (!verifySchedule()) {
164 LDBG(
"--invalid schedule: " << op <<
" -> BAIL");
171 for (
const auto &[op, stageNum] : stages) {
173 if (op == forOp.getBody()->getTerminator()) {
174 op->emitError(
"terminator should not be assigned a stage");
175 LDBG(
"--terminator should not be assigned stage: " << *op <<
" -> BAIL");
178 if (op->getBlock() != forOp.getBody()) {
179 op->emitOpError(
"the owning Block of all operations assigned a stage "
180 "should be the loop body block");
181 LDBG(
"--the owning Block of all operations assigned a stage "
182 "should be the loop body block: "
183 << *op <<
" -> BAIL");
192 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
193 [
this](
Value operand) {
194 Operation *def = operand.getDefiningOp();
196 (!stages.contains(def) && forOp->isAncestor(def));
198 LDBG(
"--only support loop carried dependency with a distance of 1 or "
199 "defined outside of the loop -> BAIL");
202 annotateFn =
options.annotateFn;
218 bool LoopPipelinerInternal::verifySchedule() {
219 int64_t numCylesPerIter = opOrder.size();
222 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
224 auto it = stages.find(def);
225 assert(it != stages.end());
226 int64_t stage = it->second;
227 unrolledCyles[def] = cycle + stage * numCylesPerIter;
230 int64_t consumerCycle = unrolledCyles[consumer];
231 for (
Value operand : getNestedOperands(consumer)) {
232 auto [producer, distance] = getDefiningOpAndDistance(operand);
235 auto it = unrolledCyles.find(producer);
237 if (it == unrolledCyles.end())
239 int64_t producerCycle = it->second;
240 if (consumerCycle < producerCycle - numCylesPerIter * distance) {
241 consumer->emitError(
"operation scheduled before its operands");
259 for (
OpOperand &operand : nested->getOpOperands()) {
260 Operation *def = operand.get().getDefiningOp();
261 if ((def && !
clone->
isAncestor(def)) || isa<BlockArgument>(operand.get()))
268 LogicalResult LoopPipelinerInternal::emitPrologue(
RewriterBase &rewriter) {
270 for (
auto [arg, operand] :
271 llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
272 setValueMapping(arg, operand.get(), 0);
274 auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
277 for (int64_t i = 0; i < maxStage; i++) {
279 Type t = ub.getType();
283 rewriter.
create<arith::MulIOp>(
285 rewriter.
create<arith::ConstantOp>(
287 predicates[i] = rewriter.
create<arith::CmpIOp>(
288 loc, arith::CmpIPredicate::slt, iv, ub);
293 Type t = lb.getType();
296 rewriter.
create<arith::MulIOp>(
298 rewriter.
create<arith::ConstantOp>(loc,
300 setValueMapping(forOp.getInductionVar(), iv, i);
305 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
306 auto it = valueMapping.find(newOperand->
get());
307 if (it != valueMapping.end()) {
308 Value replacement = it->second[i - stages[op]];
309 newOperand->set(replacement);
312 int predicateIdx = i - stages[op];
313 if (predicates[predicateIdx]) {
315 newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
316 if (newOp ==
nullptr)
321 for (
unsigned destId : llvm::seq(
unsigned(0), op->getNumResults())) {
324 for (
OpOperand &operand : yield->getOpOperands()) {
325 if (operand.get() != op->getResult(destId))
327 if (predicates[predicateIdx] &&
328 !forOp.getResult(operand.getOperandNumber()).use_empty()) {
331 Value prevValue = valueMapping
332 [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
334 source = rewriter.
create<arith::SelectOp>(
335 loc, predicates[predicateIdx], source, prevValue);
337 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
338 source, i - stages[op] + 1);
340 setValueMapping(op->getResult(destId), newOp->
getResult(destId),
348 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
349 LoopPipelinerInternal::analyzeCrossStageValues() {
350 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
352 unsigned stage = stages[op];
354 auto analyzeOperand = [&](
OpOperand &operand) {
355 auto [def, distance] = getDefiningOpAndDistance(operand.get());
358 auto defStage = stages.find(def);
359 if (defStage == stages.end() || defStage->second == stage ||
360 defStage->second == stage + distance)
362 assert(stage > defStage->second);
363 LiverangeInfo &info = crossStageValues[operand.get()];
364 info.defStage = defStage->second;
365 info.lastUseStage =
std::max(info.lastUseStage, stage);
368 for (
OpOperand &operand : op->getOpOperands())
369 analyzeOperand(operand);
371 analyzeOperand(*operand);
374 return crossStageValues;
377 std::pair<Operation *, int64_t>
378 LoopPipelinerInternal::getDefiningOpAndDistance(
Value value) {
379 int64_t distance = 0;
380 if (
auto arg = dyn_cast<BlockArgument>(value)) {
381 if (arg.getOwner() != forOp.getBody())
384 if (arg.getArgNumber() == 0)
388 forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
393 return {def, distance};
396 scf::ForOp LoopPipelinerInternal::createKernelLoop(
397 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
400 llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap) {
408 for (
const auto &retVal :
410 Operation *def = retVal.value().getDefiningOp();
411 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
413 auto defStage = stages.find(def);
414 if (defStage != stages.end()) {
416 valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
417 [maxStage - defStage->second];
418 assert(valueVersion);
419 newLoopArg.push_back(valueVersion);
421 newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
423 for (
auto escape : crossStageValues) {
424 LiverangeInfo &info = escape.second;
425 Value value = escape.first;
426 for (
unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
429 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
430 assert(valueVersion);
431 newLoopArg.push_back(valueVersion);
432 loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
433 stageIdx)] = newLoopArg.size() - 1;
440 Value newUb = forOp.getUpperBound();
442 Type t = ub.getType();
445 Value maxStageValue = rewriter.
create<arith::ConstantOp>(
447 Value maxStageByStep =
448 rewriter.
create<arith::MulIOp>(loc, step, maxStageValue);
449 newUb = rewriter.
create<arith::SubIOp>(loc, ub, maxStageByStep);
452 rewriter.
create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
453 forOp.getStep(), newLoopArg);
456 if (!newForOp.getBody()->empty())
457 rewriter.
eraseOp(newForOp.getBody()->getTerminator());
461 LogicalResult LoopPipelinerInternal::createKernel(
463 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
465 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
467 valueMapping.clear();
473 mapping.
map(forOp.getInductionVar(), newForOp.getInductionVar());
475 mapping.
map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
481 Type t = ub.getType();
482 for (
unsigned i = 0; i < maxStage; i++) {
486 rewriter.
create<arith::MulIOp>(
488 rewriter.
create<arith::ConstantOp>(
492 newForOp.getLoc(), arith::CmpIPredicate::slt,
493 newForOp.getInductionVar(), c);
494 predicates[i] = pred;
498 int64_t useStage = stages[op];
499 auto *newOp = rewriter.
clone(*op, mapping);
502 op->walk([&operands](
Operation *nestedOp) {
504 operands.push_back(&operand);
511 if (operand->get() == forOp.getInductionVar()) {
515 Type t = step.getType();
517 forOp.getLoc(), step,
518 rewriter.
create<arith::ConstantOp>(
522 forOp.getLoc(), newForOp.getInductionVar(), offset);
523 nestedNewOp->
setOperand(operand->getOperandNumber(), iv);
527 Value source = operand->get();
528 auto arg = dyn_cast<BlockArgument>(source);
529 if (arg && arg.getOwner() == forOp.getBody()) {
530 Value ret = forOp.getBody()->getTerminator()->getOperand(
531 arg.getArgNumber() - 1);
535 auto stageDep = stages.find(dep);
536 if (stageDep == stages.end() || stageDep->second == useStage)
540 if (stageDep->second == useStage + 1) {
541 nestedNewOp->
setOperand(operand->getOperandNumber(),
553 auto stageDef = stages.find(def);
554 if (stageDef == stages.end() || stageDef->second == useStage)
556 auto remap = loopArgMap.find(
557 std::make_pair(operand->get(), useStage - stageDef->second));
558 assert(remap != loopArgMap.end());
559 nestedNewOp->
setOperand(operand->getOperandNumber(),
560 newForOp.getRegionIterArgs()[remap->second]);
563 if (predicates[useStage]) {
565 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
569 for (
auto values : llvm::zip(op->getResults(), newOp->getResults()))
570 mapping.
map(std::get<0>(values), std::get<1>(values));
583 forOp.getBody()->getTerminator()->getOpOperands()) {
589 !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
590 Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
592 auto defStage = stages.find(def);
593 if (defStage != stages.end() && defStage->second < maxStage) {
594 Value pred = predicates[defStage->second];
595 source = rewriter.
create<arith::SelectOp>(
596 pred.
getLoc(), pred, source,
598 ->getArguments()[yieldOperand.getOperandNumber() + 1]);
602 yieldOperands.push_back(source);
605 for (
auto &it : crossStageValues) {
606 int64_t version = maxStage - it.second.lastUseStage + 1;
607 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
611 for (
unsigned i = 1; i < numVersionReturned; i++) {
612 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
614 yieldOperands.push_back(
615 newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
616 newForOp.getNumInductionVars()]);
618 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
623 for (
const auto &retVal :
625 Operation *def = retVal.value().getDefiningOp();
626 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
627 "defined outside the loop");
628 auto defStage = stages.find(def);
629 if (defStage == stages.end()) {
630 for (
unsigned int stage = 1; stage <= maxStage; stage++)
631 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
632 retVal.value(), stage);
633 }
else if (defStage->second > 0) {
634 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
635 newForOp->getResult(retVal.index()),
636 maxStage - defStage->second + 1);
639 rewriter.
create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
644 LoopPipelinerInternal::emitEpilogue(
RewriterBase &rewriter,
647 Type t = lb.getType();
653 return rewriter.
create<arith::ConstantOp>(loc,
662 Value stepLessZero = rewriter.
create<arith::CmpIOp>(
663 loc, arith::CmpIPredicate::slt, step, zero);
667 Value rangeDiff = rewriter.
create<arith::SubIOp>(loc, ub, lb);
668 Value rangeIncrStep = rewriter.
create<arith::AddIOp>(loc, rangeDiff, step);
670 rewriter.
create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
671 Value totalIterations = rewriter.
create<arith::DivSIOp>(loc, rangeDecr, step);
676 Value iterI = rewriter.
create<arith::SubIOp>(loc, totalIterations,
678 iterI = rewriter.
create<arith::MaxSIOp>(loc, zero, iterI);
683 for (int64_t i = 1; i <= maxStage; i++) {
685 Value newlastIter = rewriter.
create<arith::AddIOp>(
686 loc, lb, rewriter.
create<arith::MulIOp>(loc, step, iterI));
688 setValueMapping(forOp.getInductionVar(), newlastIter, i);
691 iterI = rewriter.
create<arith::AddIOp>(loc, iterI, one);
696 predicates[i] = rewriter.
create<arith::CmpIOp>(
697 loc, arith::CmpIPredicate::sge, totalIterations,
createConst(i));
703 for (int64_t i = 1; i <= maxStage; i++) {
708 unsigned currentVersion = maxStage - stages[op] + i;
709 unsigned nextVersion = currentVersion + 1;
711 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
712 auto it = valueMapping.find(newOperand->
get());
713 if (it != valueMapping.end()) {
714 Value replacement = it->second[currentVersion];
715 newOperand->set(replacement);
720 newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
727 for (
auto [opRes, newRes] :
728 llvm::zip(op->getResults(), newOp->
getResults())) {
729 setValueMapping(opRes, newRes, currentVersion);
734 forOp.getBody()->getTerminator()->getOpOperands()) {
735 if (operand.get() != opRes)
739 unsigned ri = operand.getOperandNumber();
740 returnValues[ri] = newRes;
741 Value mapVal = forOp.getRegionIterArgs()[ri];
742 returnMap[ri] = std::make_pair(mapVal, currentVersion);
743 if (nextVersion <= maxStage)
744 setValueMapping(mapVal, newRes, nextVersion);
753 unsigned ri = pair.index();
754 auto [mapVal, currentVersion] = returnMap[ri];
756 unsigned nextVersion = currentVersion + 1;
757 Value pred = predicates[currentVersion];
758 Value prevValue = valueMapping[mapVal][currentVersion];
759 auto selOp = rewriter.
create<arith::SelectOp>(loc, pred, pair.value(),
761 returnValues[ri] = selOp;
762 if (nextVersion <= maxStage)
763 setValueMapping(mapVal, selOp, nextVersion);
771 void LoopPipelinerInternal::setValueMapping(
Value key,
Value el, int64_t idx) {
772 auto it = valueMapping.find(key);
775 if (it == valueMapping.end())
780 it->second[idx] = el;
790 LoopPipelinerInternal pipeliner;
791 if (!pipeliner.initializeLoopInfo(forOp,
options))
798 if (failed(pipeliner.emitPrologue(rewriter)))
805 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
806 crossStageValues = pipeliner.analyzeCrossStageValues();
814 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
817 if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
822 newForOp.getResults().take_front(forOp->getNumResults());
826 if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
830 if (forOp->getNumResults() > 0)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options)
Populate patterns for SCF software pipelining transformation.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Options to dictate how loops should be pipelined.
std::function< void(Operation *, PipelinerPart, unsigned)> AnnotationlFnType
Lambda called by the pipeliner to allow the user to annotate the IR while it is generated.
std::function< Operation *(RewriterBase &, Operation *, Value)> PredicateOpFn