21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/MathExtras.h"
// Debug-logging helpers for this file.
// DEBUG_TYPE is set to "scf-loop-pipelining" and is spliced into the DBGS()
// prefix, so every message written through DBGS() to llvm::dbgs() starts with
// "[scf-loop-pipelining]: ".
25 #define DEBUG_TYPE "scf-loop-pipelining"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// LDBG(X) wraps a DBGS() write in LLVM_DEBUG and terminates it with "\n".
// X is pasted unparenthesized after `<<`, which is why callers can pass a
// chained stream expression, e.g. LDBG("--msg: " << op << " -> BAIL").
27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
35 struct LoopPipelinerInternal {
37 struct LiverangeInfo {
38 unsigned lastUseStage = 0;
39 unsigned defStage = 0;
44 unsigned maxStage = 0;
46 std::vector<Operation *> opOrder;
63 void setValueMapping(
Value key,
Value el, int64_t idx);
68 std::pair<Operation *, int64_t> getDefiningOpAndDistance(
Value value);
72 bool verifySchedule();
83 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
84 scf::ForOp createKernelLoop(
85 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
90 LogicalResult createKernel(
92 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
93 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
101 bool LoopPipelinerInternal::initializeLoopInfo(
103 LDBG(
"Start initializeLoopInfo");
105 ub = forOp.getUpperBound();
106 lb = forOp.getLowerBound();
107 step = forOp.getStep();
109 std::vector<std::pair<Operation *, unsigned>> schedule;
110 options.getScheduleFn(forOp, schedule);
111 if (schedule.empty()) {
112 LDBG(
"--empty schedule -> BAIL");
116 opOrder.reserve(schedule.size());
117 for (
auto &opSchedule : schedule) {
118 maxStage =
std::max(maxStage, opSchedule.second);
119 stages[opSchedule.first] = opSchedule.second;
120 opOrder.push_back(opSchedule.first);
127 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
128 if (!
options.supportDynamicLoops) {
129 LDBG(
"--dynamic loop not supported -> BAIL");
133 int64_t ubImm = upperBoundCst.value();
134 int64_t lbImm = lowerBoundCst.value();
135 int64_t stepImm = stepCst.value();
137 LDBG(
"--invalid loop step -> BAIL");
140 int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
141 if (numIteration >= maxStage) {
143 }
else if (!
options.supportDynamicLoops) {
144 LDBG(
"--fewer loop iterations than pipeline stages -> BAIL");
148 peelEpilogue =
options.peelEpilogue;
149 predicateFn =
options.predicateFn;
150 if ((!peelEpilogue || dynamicLoop) && predicateFn ==
nullptr) {
151 LDBG(
"--no epilogue or predicate set -> BAIL");
156 for (
Operation &op : forOp.getBody()->without_terminator()) {
157 if (!stages.contains(&op)) {
158 op.emitOpError(
"not assigned a pipeline stage");
159 LDBG(
"--op not assigned a pipeline stage: " << op <<
" -> BAIL");
164 if (!verifySchedule()) {
165 LDBG(
"--invalid schedule: " << op <<
" -> BAIL");
172 for (
const auto &[op, stageNum] : stages) {
174 if (op == forOp.getBody()->getTerminator()) {
175 op->emitError(
"terminator should not be assigned a stage");
176 LDBG(
"--terminator should not be assigned stage: " << *op <<
" -> BAIL");
179 if (op->getBlock() != forOp.getBody()) {
180 op->emitOpError(
"the owning Block of all operations assigned a stage "
181 "should be the loop body block");
182 LDBG(
"--the owning Block of all operations assigned a stage "
183 "should be the loop body block: "
184 << *op <<
" -> BAIL");
193 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
194 [
this](
Value operand) {
195 Operation *def = operand.getDefiningOp();
197 (!stages.contains(def) && forOp->isAncestor(def));
199 LDBG(
"--only support loop carried dependency with a distance of 1 or "
200 "defined outside of the loop -> BAIL");
203 annotateFn =
options.annotateFn;
219 bool LoopPipelinerInternal::verifySchedule() {
220 int64_t numCylesPerIter = opOrder.size();
223 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
225 auto it = stages.find(def);
226 assert(it != stages.end());
227 int64_t stage = it->second;
228 unrolledCyles[def] = cycle + stage * numCylesPerIter;
231 int64_t consumerCycle = unrolledCyles[consumer];
232 for (
Value operand : getNestedOperands(consumer)) {
233 auto [producer, distance] = getDefiningOpAndDistance(operand);
236 auto it = unrolledCyles.find(producer);
238 if (it == unrolledCyles.end())
240 int64_t producerCycle = it->second;
241 if (consumerCycle < producerCycle - numCylesPerIter * distance) {
242 consumer->emitError(
"operation scheduled before its operands");
260 for (
OpOperand &operand : nested->getOpOperands()) {
261 Operation *def = operand.get().getDefiningOp();
262 if ((def && !
clone->
isAncestor(def)) || isa<BlockArgument>(operand.get()))
269 LogicalResult LoopPipelinerInternal::emitPrologue(
RewriterBase &rewriter) {
271 for (
auto [arg, operand] :
272 llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
273 setValueMapping(arg, operand.get(), 0);
275 auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
278 for (int64_t i = 0; i < maxStage; i++) {
280 Type t = ub.getType();
284 rewriter.
create<arith::MulIOp>(
286 rewriter.
create<arith::ConstantOp>(
288 predicates[i] = rewriter.
create<arith::CmpIOp>(
289 loc, arith::CmpIPredicate::slt, iv, ub);
294 Type t = lb.getType();
297 rewriter.
create<arith::MulIOp>(
299 rewriter.
create<arith::ConstantOp>(loc,
301 setValueMapping(forOp.getInductionVar(), iv, i);
306 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
307 auto it = valueMapping.find(newOperand->
get());
308 if (it != valueMapping.end()) {
309 Value replacement = it->second[i - stages[op]];
310 newOperand->set(replacement);
313 int predicateIdx = i - stages[op];
314 if (predicates[predicateIdx]) {
316 newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
317 if (newOp ==
nullptr)
322 for (
unsigned destId : llvm::seq(
unsigned(0), op->getNumResults())) {
325 for (
OpOperand &operand : yield->getOpOperands()) {
326 if (operand.get() != op->getResult(destId))
328 if (predicates[predicateIdx] &&
329 !forOp.getResult(operand.getOperandNumber()).use_empty()) {
332 Value prevValue = valueMapping
333 [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
335 source = rewriter.
create<arith::SelectOp>(
336 loc, predicates[predicateIdx], source, prevValue);
338 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
339 source, i - stages[op] + 1);
341 setValueMapping(op->getResult(destId), newOp->
getResult(destId),
349 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
350 LoopPipelinerInternal::analyzeCrossStageValues() {
351 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
353 unsigned stage = stages[op];
355 auto analyzeOperand = [&](
OpOperand &operand) {
356 auto [def, distance] = getDefiningOpAndDistance(operand.get());
359 auto defStage = stages.find(def);
360 if (defStage == stages.end() || defStage->second == stage ||
361 defStage->second == stage + distance)
363 assert(stage > defStage->second);
364 LiverangeInfo &info = crossStageValues[operand.get()];
365 info.defStage = defStage->second;
366 info.lastUseStage =
std::max(info.lastUseStage, stage);
369 for (
OpOperand &operand : op->getOpOperands())
370 analyzeOperand(operand);
372 analyzeOperand(*operand);
375 return crossStageValues;
378 std::pair<Operation *, int64_t>
379 LoopPipelinerInternal::getDefiningOpAndDistance(
Value value) {
380 int64_t distance = 0;
381 if (
auto arg = dyn_cast<BlockArgument>(value)) {
382 if (arg.getOwner() != forOp.getBody())
385 if (arg.getArgNumber() == 0)
389 forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
394 return {def, distance};
397 scf::ForOp LoopPipelinerInternal::createKernelLoop(
398 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
401 llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap) {
409 for (
const auto &retVal :
411 Operation *def = retVal.value().getDefiningOp();
412 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
414 auto defStage = stages.find(def);
415 if (defStage != stages.end()) {
417 valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
418 [maxStage - defStage->second];
419 assert(valueVersion);
420 newLoopArg.push_back(valueVersion);
422 newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
425 for (
auto escape : crossStageValues) {
426 LiverangeInfo &info = escape.second;
427 Value value = escape.first;
428 for (
unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
431 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
432 assert(valueVersion);
433 newLoopArg.push_back(valueVersion);
434 loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
435 stageIdx)] = newLoopArg.size() - 1;
442 Value newUb = forOp.getUpperBound();
444 Type t = ub.getType();
447 Value maxStageValue = rewriter.
create<arith::ConstantOp>(
449 Value maxStageByStep =
450 rewriter.
create<arith::MulIOp>(loc, step, maxStageValue);
451 newUb = rewriter.
create<arith::SubIOp>(loc, ub, maxStageByStep);
454 rewriter.
create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
455 forOp.getStep(), newLoopArg);
458 if (!newForOp.getBody()->empty())
459 rewriter.
eraseOp(newForOp.getBody()->getTerminator());
463 LogicalResult LoopPipelinerInternal::createKernel(
465 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
467 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
469 valueMapping.clear();
475 mapping.
map(forOp.getInductionVar(), newForOp.getInductionVar());
477 mapping.
map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
483 Type t = ub.getType();
484 for (
unsigned i = 0; i < maxStage; i++) {
488 rewriter.
create<arith::MulIOp>(
490 rewriter.
create<arith::ConstantOp>(
494 newForOp.getLoc(), arith::CmpIPredicate::slt,
495 newForOp.getInductionVar(), c);
496 predicates[i] = pred;
500 int64_t useStage = stages[op];
501 auto *newOp = rewriter.
clone(*op, mapping);
504 op->walk([&operands](
Operation *nestedOp) {
506 operands.push_back(&operand);
513 if (operand->get() == forOp.getInductionVar()) {
517 Type t = step.getType();
519 forOp.getLoc(), step,
520 rewriter.
create<arith::ConstantOp>(
524 forOp.getLoc(), newForOp.getInductionVar(), offset);
525 nestedNewOp->
setOperand(operand->getOperandNumber(), iv);
529 Value source = operand->get();
530 auto arg = dyn_cast<BlockArgument>(source);
531 if (arg && arg.getOwner() == forOp.getBody()) {
532 Value ret = forOp.getBody()->getTerminator()->getOperand(
533 arg.getArgNumber() - 1);
537 auto stageDep = stages.find(dep);
538 if (stageDep == stages.end() || stageDep->second == useStage)
542 if (stageDep->second == useStage + 1) {
543 nestedNewOp->
setOperand(operand->getOperandNumber(),
555 auto stageDef = stages.find(def);
556 if (stageDef == stages.end() || stageDef->second == useStage)
558 auto remap = loopArgMap.find(
559 std::make_pair(operand->get(), useStage - stageDef->second));
560 assert(remap != loopArgMap.end());
561 nestedNewOp->
setOperand(operand->getOperandNumber(),
562 newForOp.getRegionIterArgs()[remap->second]);
565 if (predicates[useStage]) {
567 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
571 for (
auto values : llvm::zip(op->getResults(), newOp->getResults()))
572 mapping.
map(std::get<0>(values), std::get<1>(values));
585 forOp.getBody()->getTerminator()->getOpOperands()) {
591 !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
592 Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
594 auto defStage = stages.find(def);
595 if (defStage != stages.end() && defStage->second < maxStage) {
596 Value pred = predicates[defStage->second];
597 source = rewriter.
create<arith::SelectOp>(
598 pred.
getLoc(), pred, source,
600 ->getArguments()[yieldOperand.getOperandNumber() + 1]);
604 yieldOperands.push_back(source);
607 for (
auto &it : crossStageValues) {
608 int64_t version = maxStage - it.second.lastUseStage + 1;
609 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
613 for (
unsigned i = 1; i < numVersionReturned; i++) {
614 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
616 yieldOperands.push_back(
617 newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
618 newForOp.getNumInductionVars()]);
620 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
625 for (
const auto &retVal :
627 Operation *def = retVal.value().getDefiningOp();
628 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
629 "defined outside the loop");
630 auto defStage = stages.find(def);
631 if (defStage == stages.end()) {
632 for (
unsigned int stage = 1; stage <= maxStage; stage++)
633 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
634 retVal.value(), stage);
635 }
else if (defStage->second > 0) {
636 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
637 newForOp->getResult(retVal.index()),
638 maxStage - defStage->second + 1);
641 rewriter.
create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
646 LoopPipelinerInternal::emitEpilogue(
RewriterBase &rewriter,
649 Type t = lb.getType();
655 return rewriter.
create<arith::ConstantOp>(loc,
664 Value stepLessZero = rewriter.
create<arith::CmpIOp>(
665 loc, arith::CmpIPredicate::slt, step, zero);
669 Value rangeDiff = rewriter.
create<arith::SubIOp>(loc, ub, lb);
670 Value rangeIncrStep = rewriter.
create<arith::AddIOp>(loc, rangeDiff, step);
672 rewriter.
create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
673 Value totalIterations = rewriter.
create<arith::DivSIOp>(loc, rangeDecr, step);
678 Value iterI = rewriter.
create<arith::SubIOp>(loc, totalIterations,
680 iterI = rewriter.
create<arith::MaxSIOp>(loc, zero, iterI);
685 for (int64_t i = 1; i <= maxStage; i++) {
687 Value newlastIter = rewriter.
create<arith::AddIOp>(
688 loc, lb, rewriter.
create<arith::MulIOp>(loc, step, iterI));
690 setValueMapping(forOp.getInductionVar(), newlastIter, i);
693 iterI = rewriter.
create<arith::AddIOp>(loc, iterI, one);
698 predicates[i] = rewriter.
create<arith::CmpIOp>(
699 loc, arith::CmpIPredicate::sge, totalIterations,
createConst(i));
705 for (int64_t i = 1; i <= maxStage; i++) {
710 unsigned currentVersion = maxStage - stages[op] + i;
711 unsigned nextVersion = currentVersion + 1;
713 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
714 auto it = valueMapping.find(newOperand->
get());
715 if (it != valueMapping.end()) {
716 Value replacement = it->second[currentVersion];
717 newOperand->set(replacement);
722 newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
729 for (
auto [opRes, newRes] :
730 llvm::zip(op->getResults(), newOp->
getResults())) {
731 setValueMapping(opRes, newRes, currentVersion);
736 forOp.getBody()->getTerminator()->getOpOperands()) {
737 if (operand.get() != opRes)
741 unsigned ri = operand.getOperandNumber();
742 returnValues[ri] = newRes;
743 Value mapVal = forOp.getRegionIterArgs()[ri];
744 returnMap[ri] = std::make_pair(mapVal, currentVersion);
745 if (nextVersion <= maxStage)
746 setValueMapping(mapVal, newRes, nextVersion);
755 unsigned ri = pair.index();
756 auto [mapVal, currentVersion] = returnMap[ri];
758 unsigned nextVersion = currentVersion + 1;
759 Value pred = predicates[currentVersion];
760 Value prevValue = valueMapping[mapVal][currentVersion];
761 auto selOp = rewriter.
create<arith::SelectOp>(loc, pred, pair.value(),
763 returnValues[ri] = selOp;
764 if (nextVersion <= maxStage)
765 setValueMapping(mapVal, selOp, nextVersion);
773 void LoopPipelinerInternal::setValueMapping(
Value key,
Value el, int64_t idx) {
774 auto it = valueMapping.find(key);
777 if (it == valueMapping.end())
782 it->second[idx] = el;
792 LoopPipelinerInternal pipeliner;
793 if (!pipeliner.initializeLoopInfo(forOp,
options))
800 if (failed(pipeliner.emitPrologue(rewriter)))
807 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
808 crossStageValues = pipeliner.analyzeCrossStageValues();
816 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
819 if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
824 newForOp.getResults().take_front(forOp->getNumResults());
828 if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
832 if (forOp->getNumResults() > 0)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options)
Populate patterns for SCF software pipelining transformation.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void visitUsedValuesDefinedAbove(Region ®ion, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Options to dictate how loops should be pipelined.
std::function< void(Operation *, PipelinerPart, unsigned)> AnnotationlFnType
Lambda called by the pipeliner to allow the user to annotate the IR while it is generated.
std::function< Operation *(RewriterBase &, Operation *, Value)> PredicateOpFn