21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/MathExtras.h"
// Debug category for this file's LLVM_DEBUG output; DBGS() below also
// embeds it in the prefix of every log line.
25 #define DEBUG_TYPE "scf-loop-pipelining"
// Returns llvm::dbgs() after writing the "[scf-loop-pipelining]: " prefix.
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Writes X followed by a newline to DBGS(), wrapped in LLVM_DEBUG.
// X is spliced in unparenthesized on purpose, so callers can pass a
// `<<`-chained expression, e.g. LDBG("--msg: " << op << " -> BAIL").
27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
35 struct LoopPipelinerInternal {
37 struct LiverangeInfo {
38 unsigned lastUseStage = 0;
39 unsigned defStage = 0;
44 unsigned maxStage = 0;
46 std::vector<Operation *> opOrder;
63 void setValueMapping(
Value key,
Value el, int64_t idx);
68 std::pair<Operation *, int64_t> getDefiningOpAndDistance(
Value value);
72 bool verifySchedule();
83 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
84 scf::ForOp createKernelLoop(
85 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
90 LogicalResult createKernel(
92 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
93 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
101 bool LoopPipelinerInternal::initializeLoopInfo(
103 LDBG(
"Start initializeLoopInfo");
105 ub = forOp.getUpperBound();
106 lb = forOp.getLowerBound();
107 step = forOp.getStep();
113 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
114 if (!
options.supportDynamicLoops) {
115 LDBG(
"--dynamic loop not supported -> BAIL");
119 int64_t ubImm = upperBoundCst.value();
120 int64_t lbImm = lowerBoundCst.value();
121 int64_t stepImm = stepCst.value();
122 int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
123 if (numIteration > maxStage) {
125 }
else if (!
options.supportDynamicLoops) {
126 LDBG(
"--fewer loop iterations than pipeline stages -> BAIL");
130 peelEpilogue =
options.peelEpilogue;
131 predicateFn =
options.predicateFn;
132 if ((!peelEpilogue || dynamicLoop) && predicateFn ==
nullptr) {
133 LDBG(
"--no epilogue or predicate set -> BAIL");
136 std::vector<std::pair<Operation *, unsigned>> schedule;
137 options.getScheduleFn(forOp, schedule);
138 if (schedule.empty()) {
139 LDBG(
"--empty schedule -> BAIL");
143 opOrder.reserve(schedule.size());
144 for (
auto &opSchedule : schedule) {
145 maxStage =
std::max(maxStage, opSchedule.second);
146 stages[opSchedule.first] = opSchedule.second;
147 opOrder.push_back(opSchedule.first);
151 for (
Operation &op : forOp.getBody()->without_terminator()) {
152 if (!stages.contains(&op)) {
154 LDBG(
"--op not assigned a pipeline stage: " << op <<
" -> BAIL");
159 if (!verifySchedule()) {
160 LDBG(
"--invalid schedule: " << op <<
" -> BAIL");
167 for (
const auto &[op, stageNum] : stages) {
169 if (op == forOp.getBody()->getTerminator()) {
170 op->
emitError(
"terminator should not be assigned a stage");
171 LDBG(
"--terminator should not be assigned stage: " << *op <<
" -> BAIL");
174 if (op->
getBlock() != forOp.getBody()) {
175 op->
emitOpError(
"the owning Block of all operations assigned a stage "
176 "should be the loop body block");
177 LDBG(
"--the owning Block of all operations assigned a stage "
178 "should be the loop body block: "
179 << *op <<
" -> BAIL");
188 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
189 [
this](
Value operand) {
190 Operation *def = operand.getDefiningOp();
192 (!stages.contains(def) && forOp->isAncestor(def));
194 LDBG(
"--only support loop carried dependency with a distance of 1 or "
195 "defined outside of the loop -> BAIL");
198 annotateFn =
options.annotateFn;
207 operands.insert(operand);
216 bool LoopPipelinerInternal::verifySchedule() {
217 int64_t numCylesPerIter = opOrder.size();
220 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
222 auto it = stages.find(def);
223 assert(it != stages.end());
224 int64_t stage = it->second;
225 unrolledCyles[def] = cycle + stage * numCylesPerIter;
228 int64_t consumerCycle = unrolledCyles[consumer];
229 for (
Value operand : getNestedOperands(consumer)) {
230 auto [producer, distance] = getDefiningOpAndDistance(operand);
233 auto it = unrolledCyles.find(producer);
235 if (it == unrolledCyles.end())
237 int64_t producerCycle = it->second;
238 if (consumerCycle < producerCycle - numCylesPerIter * distance) {
239 consumer->emitError(
"operation scheduled before its operands");
257 for (
OpOperand &operand : nested->getOpOperands()) {
258 Operation *def = operand.get().getDefiningOp();
259 if ((def && !
clone->
isAncestor(def)) || isa<BlockArgument>(operand.get()))
266 LogicalResult LoopPipelinerInternal::emitPrologue(
RewriterBase &rewriter) {
268 for (
auto [arg, operand] :
269 llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
270 setValueMapping(arg, operand.get(), 0);
272 auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
275 for (int64_t i = 0; i < maxStage; i++) {
277 Type t = ub.getType();
281 rewriter.
create<arith::MulIOp>(
283 rewriter.
create<arith::ConstantOp>(
285 predicates[i] = rewriter.
create<arith::CmpIOp>(
286 loc, arith::CmpIPredicate::slt, iv, ub);
291 Type t = lb.getType();
294 rewriter.
create<arith::MulIOp>(
296 rewriter.
create<arith::ConstantOp>(loc,
298 setValueMapping(forOp.getInductionVar(), iv, i);
303 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
304 auto it = valueMapping.find(newOperand->
get());
305 if (it != valueMapping.end()) {
306 Value replacement = it->second[i - stages[op]];
307 newOperand->set(replacement);
310 int predicateIdx = i - stages[op];
311 if (predicates[predicateIdx]) {
313 newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
314 if (newOp ==
nullptr)
319 for (
unsigned destId : llvm::seq(
unsigned(0), op->
getNumResults())) {
322 for (
OpOperand &operand : yield->getOpOperands()) {
323 if (operand.get() != op->
getResult(destId))
325 if (predicates[predicateIdx] &&
326 !forOp.getResult(operand.getOperandNumber()).use_empty()) {
329 Value prevValue = valueMapping
330 [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
332 source = rewriter.
create<arith::SelectOp>(
333 loc, predicates[predicateIdx], source, prevValue);
335 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
336 source, i - stages[op] + 1);
346 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
347 LoopPipelinerInternal::analyzeCrossStageValues() {
348 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
350 unsigned stage = stages[op];
352 auto analyzeOperand = [&](
OpOperand &operand) {
353 auto [def, distance] = getDefiningOpAndDistance(operand.get());
356 auto defStage = stages.find(def);
357 if (defStage == stages.end() || defStage->second == stage ||
358 defStage->second == stage + distance)
360 assert(stage > defStage->second);
361 LiverangeInfo &info = crossStageValues[operand.get()];
362 info.defStage = defStage->second;
363 info.lastUseStage =
std::max(info.lastUseStage, stage);
367 analyzeOperand(operand);
369 analyzeOperand(*operand);
372 return crossStageValues;
375 std::pair<Operation *, int64_t>
376 LoopPipelinerInternal::getDefiningOpAndDistance(
Value value) {
377 int64_t distance = 0;
378 if (
auto arg = dyn_cast<BlockArgument>(value)) {
379 if (arg.getOwner() != forOp.getBody())
382 if (arg.getArgNumber() == 0)
386 forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
391 return {def, distance};
394 scf::ForOp LoopPipelinerInternal::createKernelLoop(
395 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
398 llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap) {
406 for (
const auto &retVal :
408 Operation *def = retVal.value().getDefiningOp();
409 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
411 auto defStage = stages.find(def);
412 if (defStage != stages.end()) {
414 valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
415 [maxStage - defStage->second];
416 assert(valueVersion);
417 newLoopArg.push_back(valueVersion);
419 newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
421 for (
auto escape : crossStageValues) {
422 LiverangeInfo &info = escape.second;
423 Value value = escape.first;
424 for (
unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
427 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
428 assert(valueVersion);
429 newLoopArg.push_back(valueVersion);
430 loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
431 stageIdx)] = newLoopArg.size() - 1;
438 Value newUb = forOp.getUpperBound();
440 Type t = ub.getType();
443 Value maxStageValue = rewriter.
create<arith::ConstantOp>(
445 Value maxStageByStep =
446 rewriter.
create<arith::MulIOp>(loc, step, maxStageValue);
447 newUb = rewriter.
create<arith::SubIOp>(loc, ub, maxStageByStep);
450 rewriter.
create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
451 forOp.getStep(), newLoopArg);
454 if (!newForOp.getBody()->empty())
455 rewriter.
eraseOp(newForOp.getBody()->getTerminator());
459 LogicalResult LoopPipelinerInternal::createKernel(
461 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
463 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
465 valueMapping.clear();
471 mapping.
map(forOp.getInductionVar(), newForOp.getInductionVar());
473 mapping.
map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
479 Type t = ub.getType();
480 for (
unsigned i = 0; i < maxStage; i++) {
484 rewriter.
create<arith::MulIOp>(
486 rewriter.
create<arith::ConstantOp>(
490 newForOp.getLoc(), arith::CmpIPredicate::slt,
491 newForOp.getInductionVar(), c);
492 predicates[i] = pred;
496 int64_t useStage = stages[op];
497 auto *newOp = rewriter.
clone(*op, mapping);
502 operands.push_back(&operand);
509 if (operand->get() == forOp.getInductionVar()) {
513 Type t = step.getType();
515 forOp.getLoc(), step,
516 rewriter.
create<arith::ConstantOp>(
520 forOp.getLoc(), newForOp.getInductionVar(), offset);
521 nestedNewOp->
setOperand(operand->getOperandNumber(), iv);
525 Value source = operand->get();
526 auto arg = dyn_cast<BlockArgument>(source);
527 if (arg && arg.getOwner() == forOp.getBody()) {
528 Value ret = forOp.getBody()->getTerminator()->getOperand(
529 arg.getArgNumber() - 1);
533 auto stageDep = stages.find(dep);
534 if (stageDep == stages.end() || stageDep->second == useStage)
538 if (stageDep->second == useStage + 1) {
539 nestedNewOp->
setOperand(operand->getOperandNumber(),
551 auto stageDef = stages.find(def);
552 if (stageDef == stages.end() || stageDef->second == useStage)
554 auto remap = loopArgMap.find(
555 std::make_pair(operand->get(), useStage - stageDef->second));
556 assert(remap != loopArgMap.end());
557 nestedNewOp->
setOperand(operand->getOperandNumber(),
558 newForOp.getRegionIterArgs()[remap->second]);
561 if (predicates[useStage]) {
563 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
567 for (
auto values : llvm::zip(op->
getResults(), newOp->getResults()))
568 mapping.
map(std::get<0>(values), std::get<1>(values));
581 forOp.getBody()->getTerminator()->getOpOperands()) {
587 !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
588 Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
590 auto defStage = stages.find(def);
591 if (defStage != stages.end() && defStage->second < maxStage) {
592 Value pred = predicates[defStage->second];
593 source = rewriter.
create<arith::SelectOp>(
594 pred.
getLoc(), pred, source,
596 ->getArguments()[yieldOperand.getOperandNumber() + 1]);
600 yieldOperands.push_back(source);
603 for (
auto &it : crossStageValues) {
604 int64_t version = maxStage - it.second.lastUseStage + 1;
605 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
609 for (
unsigned i = 1; i < numVersionReturned; i++) {
610 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
612 yieldOperands.push_back(
613 newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
614 newForOp.getNumInductionVars()]);
616 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
621 for (
const auto &retVal :
623 Operation *def = retVal.value().getDefiningOp();
624 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
625 "defined outside the loop");
626 auto defStage = stages.find(def);
627 if (defStage == stages.end()) {
628 for (
unsigned int stage = 1; stage <= maxStage; stage++)
629 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
630 retVal.value(), stage);
631 }
else if (defStage->second > 0) {
632 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
633 newForOp->getResult(retVal.index()),
634 maxStage - defStage->second + 1);
637 rewriter.
create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
642 LoopPipelinerInternal::emitEpilogue(
RewriterBase &rewriter,
650 Type t = lb.getType();
653 Value boundsRange = rewriter.
create<arith::SubIOp>(loc, ub, lb);
654 Value rangeIncr = rewriter.
create<arith::AddIOp>(loc, boundsRange, step);
655 Value rangeDecr = rewriter.
create<arith::AddIOp>(loc, rangeIncr, minus1);
656 Value totalIterations = rewriter.
create<arith::DivUIOp>(loc, rangeDecr, step);
659 for (int64_t i = 0; i < maxStage; i++) {
665 loc, rewriter.
create<arith::AddIOp>(loc, totalIterations, minus1),
668 Value newlastIter = rewriter.
create<arith::AddIOp>(
669 loc, lb, rewriter.
create<arith::MulIOp>(loc, step, iterI));
671 setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
675 predicates[i + 1] = rewriter.
create<arith::CmpIOp>(
676 loc, arith::CmpIPredicate::sge, iterI, lb);
682 for (int64_t i = 1; i <= maxStage; i++) {
687 unsigned currentVersion = maxStage - stages[op] + i;
688 unsigned nextVersion = currentVersion + 1;
690 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
691 auto it = valueMapping.find(newOperand->
get());
692 if (it != valueMapping.end()) {
693 Value replacement = it->second[currentVersion];
694 newOperand->set(replacement);
699 newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
706 for (
auto [opRes, newRes] :
708 setValueMapping(opRes, newRes, currentVersion);
713 forOp.getBody()->getTerminator()->getOpOperands()) {
714 if (operand.get() != opRes)
718 unsigned ri = operand.getOperandNumber();
719 returnValues[ri] = newRes;
720 Value mapVal = forOp.getRegionIterArgs()[ri];
721 returnMap[ri] = std::make_pair(mapVal, currentVersion);
722 if (nextVersion <= maxStage)
723 setValueMapping(mapVal, newRes, nextVersion);
732 unsigned ri = pair.index();
733 auto [mapVal, currentVersion] = returnMap[ri];
735 unsigned nextVersion = currentVersion + 1;
736 Value pred = predicates[currentVersion];
737 Value prevValue = valueMapping[mapVal][currentVersion];
738 auto selOp = rewriter.
create<arith::SelectOp>(loc, pred, pair.value(),
740 returnValues[ri] = selOp;
741 if (nextVersion <= maxStage)
742 setValueMapping(mapVal, selOp, nextVersion);
750 void LoopPipelinerInternal::setValueMapping(
Value key,
Value el, int64_t idx) {
751 auto it = valueMapping.find(key);
754 if (it == valueMapping.end())
759 it->second[idx] = el;
769 LoopPipelinerInternal pipeliner;
770 if (!pipeliner.initializeLoopInfo(forOp,
options))
777 if (failed(pipeliner.emitPrologue(rewriter)))
784 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
785 crossStageValues = pipeliner.analyzeCrossStageValues();
793 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
796 if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
801 newForOp.getResults().take_front(forOp->getNumResults());
805 if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
809 if (forOp->getNumResults() > 0)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options)
Populate patterns for SCF software pipelining transformation.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
Options to dictate how loops should be pipelined.
std::function< void(Operation *, PipelinerPart, unsigned)> AnnotationlFnType
Lambda called by the pipeliner to allow the user to annotate the IR while it is generated.
std::function< Operation *(RewriterBase &, Operation *, Value)> PredicateOpFn