22 #include "llvm/ADT/MapVector.h"
23 #include "llvm/Support/Debug.h"
// Debug tag for this file; it is pasted into the prefix printed by DBGS().
25 #define DEBUG_TYPE "scf-loop-pipelining"
// Streams the "[" DEBUG_TYPE "]: " prefix to llvm::dbgs() and yields the
// stream so callers can keep appending with <<.
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Wraps one debug message in LLVM_DEBUG: the DBGS() prefix, then X, then a
// newline. X may itself be a chain of << operands.
27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
35 struct LoopPipelinerInternal {
// Stage range recorded per value by analyzeCrossStageValues(): the stage of
// the op defining the value and the latest stage of any op seen using it.
37 struct LiverangeInfo {
// Highest stage among the uses analyzed so far (updated with std::max).
38 unsigned lastUseStage = 0;
// Stage found in `stages` for the op that defines the value.
39 unsigned defStage = 0;
44 unsigned maxStage = 0;
46 std::vector<Operation *> opOrder;
63 void setValueMapping(
Value key,
Value el, int64_t idx);
68 std::pair<Operation *, int64_t> getDefiningOpAndDistance(
Value value);
72 bool verifySchedule();
83 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
84 scf::ForOp createKernelLoop(
85 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
92 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
93 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
101 bool LoopPipelinerInternal::initializeLoopInfo(
103 LDBG(
"Start initializeLoopInfo");
105 ub = forOp.getUpperBound();
106 lb = forOp.getLowerBound();
107 step = forOp.getStep();
113 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
114 if (!
options.supportDynamicLoops) {
115 LDBG(
"--dynamic loop not supported -> BAIL");
119 int64_t ubImm = upperBoundCst.value();
120 int64_t lbImm = lowerBoundCst.value();
121 int64_t stepImm = stepCst.value();
122 int64_t numIteration =
ceilDiv(ubImm - lbImm, stepImm);
123 if (numIteration > maxStage) {
125 }
else if (!
options.supportDynamicLoops) {
126 LDBG(
"--fewer loop iterations than pipeline stages -> BAIL");
130 peelEpilogue =
options.peelEpilogue;
131 predicateFn =
options.predicateFn;
132 if ((!peelEpilogue || dynamicLoop) && predicateFn ==
nullptr) {
133 LDBG(
"--no epilogue or predicate set -> BAIL");
136 if (dynamicLoop && peelEpilogue) {
137 LDBG(
"--dynamic loop doesn't support epilogue yet -> BAIL");
140 std::vector<std::pair<Operation *, unsigned>> schedule;
141 options.getScheduleFn(forOp, schedule);
142 if (schedule.empty()) {
143 LDBG(
"--empty schedule -> BAIL");
147 opOrder.reserve(schedule.size());
148 for (
auto &opSchedule : schedule) {
149 maxStage =
std::max(maxStage, opSchedule.second);
150 stages[opSchedule.first] = opSchedule.second;
151 opOrder.push_back(opSchedule.first);
155 for (
Operation &op : forOp.getBody()->without_terminator()) {
156 if (!stages.contains(&op)) {
158 LDBG(
"--op not assigned a pipeline stage: " << op <<
" -> BAIL");
163 if (!verifySchedule()) {
164 LDBG(
"--invalid schedule: " << op <<
" -> BAIL");
171 for (
const auto &[op, stageNum] : stages) {
173 if (op == forOp.getBody()->getTerminator()) {
174 op->
emitError(
"terminator should not be assigned a stage");
175 LDBG(
"--terminator should not be assigned stage: " << *op <<
" -> BAIL");
178 if (op->
getBlock() != forOp.getBody()) {
179 op->
emitOpError(
"the owning Block of all operations assigned a stage "
180 "should be the loop body block");
181 LDBG(
"--the owning Block of all operations assigned a stage "
182 "should be the loop body block: "
183 << *op <<
" -> BAIL");
192 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
193 [
this](
Value operand) {
194 Operation *def = operand.getDefiningOp();
196 (!stages.contains(def) && forOp->isAncestor(def));
198 LDBG(
"--only support loop carried dependency with a distance of 1 or "
199 "defined outside of the loop -> BAIL");
202 annotateFn =
options.annotateFn;
211 operands.insert(operand);
// Checks the schedule by assigning every op a cycle number on an unrolled
// timeline and comparing each consumer's cycle with its producers' cycles.
// An error is emitted on the consumer when it is placed too early.
// NOTE(review): the return statements and the definitions of `def` and
// `consumer` are on lines missing from this excerpt.
220 bool LoopPipelinerInternal::verifySchedule() {
// One cycle per entry of opOrder.
221 int64_t numCylesPerIter = opOrder.size();
224 for (int64_t cycle = 0; cycle < numCylesPerIter; cycle++) {
226 auto it = stages.find(def);
// `def` is expected to have an entry in `stages`.
227 assert(it != stages.end());
228 int64_t stage = it->second;
// Unrolled cycle = position within the iteration plus one full iteration
// length for each stage.
229 unrolledCyles[def] = cycle + stage * numCylesPerIter;
232 int64_t consumerCycle = unrolledCyles[consumer];
233 for (
Value operand : getNestedOperands(consumer)) {
234 auto [producer, distance] = getDefiningOpAndDistance(operand);
237 auto it = unrolledCyles.find(producer);
// Producers with no recorded cycle take this branch; its body is not visible
// here.
239 if (it == unrolledCyles.end())
241 int64_t producerCycle = it->second;
// The producer's cycle is moved back by numCylesPerIter for each unit of
// `distance`; a consumer earlier than that adjusted cycle is an error.
242 if (consumerCycle < producerCycle - numCylesPerIter * distance) {
243 consumer->emitError(
"operation scheduled before its operands");
261 for (
OpOperand &operand : nested->getOpOperands()) {
262 Operation *def = operand.get().getDefiningOp();
263 if ((def && !
clone->
isAncestor(def)) || isa<BlockArgument>(operand.get()))
270 void LoopPipelinerInternal::emitPrologue(
RewriterBase &rewriter) {
272 for (
auto [arg, operand] :
273 llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
274 setValueMapping(arg, operand.get(), 0);
276 auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
279 for (int64_t i = 0; i < maxStage; i++) {
281 Type t = ub.getType();
285 rewriter.
create<arith::MulIOp>(
287 rewriter.
create<arith::ConstantOp>(
289 predicates[i] = rewriter.
create<arith::CmpIOp>(
290 loc, arith::CmpIPredicate::slt, iv, ub);
295 Type t = lb.getType();
298 rewriter.
create<arith::MulIOp>(
300 rewriter.
create<arith::ConstantOp>(loc,
302 setValueMapping(forOp.getInductionVar(), iv, i);
307 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
308 auto it = valueMapping.find(newOperand->
get());
309 if (it != valueMapping.end()) {
310 Value replacement = it->second[i - stages[op]];
311 newOperand->set(replacement);
314 int predicateIdx = i - stages[op];
315 if (predicates[predicateIdx]) {
316 newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
317 assert(newOp &&
"failed to predicate op.");
322 for (
unsigned destId : llvm::seq(
unsigned(0), op->
getNumResults())) {
327 for (
OpOperand &operand : yield->getOpOperands()) {
328 if (operand.get() != op->
getResult(destId))
330 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
331 newOp->
getResult(destId), i - stages[op] + 1);
// Builds and returns a map from operand values to LiverangeInfo, recording
// for each value the stage of its defining op and the latest stage of an op
// that uses it.
// NOTE(review): the loops that supply `op` and the operands passed to
// analyzeOperand are on lines missing from this excerpt.
338 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
339 LoopPipelinerInternal::analyzeCrossStageValues() {
340 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
// Stage of the op whose operands are being analyzed.
342 unsigned stage = stages[op];
344 auto analyzeOperand = [&](
OpOperand &operand) {
345 auto [def, distance] = getDefiningOpAndDistance(operand.get());
348 auto defStage = stages.find(def);
// Three cases take the branch below: the defining op has no stage, it is in
// the same stage as the user, or its stage equals stage + distance.
// NOTE(review): the branch body is not visible; presumably the operand is
// skipped - confirm.
349 if (defStage == stages.end() || defStage->second == stage ||
350 defStage->second == stage + distance)
// Past this point the value must be defined in an earlier stage than its use.
352 assert(stage > defStage->second);
// operator[] creates a default LiverangeInfo the first time a value is seen.
353 LiverangeInfo &info = crossStageValues[operand.get()];
354 info.defStage = defStage->second;
// Keep the highest using stage across all uses of the value.
355 info.lastUseStage =
std::max(info.lastUseStage, stage);
359 analyzeOperand(operand);
361 analyzeOperand(*operand);
364 return crossStageValues;
// Returns the pair {def, distance} computed for `value`.
// NOTE(review): the statements that assign `def` and update `distance` are
// on lines missing from this excerpt. Presumably `distance` becomes non-zero
// when the value is a block argument resolved through the loop terminator -
// confirm.
367 std::pair<Operation *, int64_t>
368 LoopPipelinerInternal::getDefiningOpAndDistance(
Value value) {
369 int64_t distance = 0;
370 if (
auto arg = dyn_cast<BlockArgument>(value)) {
// Block arguments not owned by this loop's body take a branch whose body is
// not visible here.
371 if (arg.getOwner() != forOp.getBody())
// Argument number 0 is handled separately (branch body not visible).
// NOTE(review): presumably this is the induction variable - confirm.
374 if (arg.getArgNumber() == 0)
// Any other argument is looked up among the terminator's operands at index
// argNumber - 1.
378 forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
383 return {def, distance};
386 scf::ForOp LoopPipelinerInternal::createKernelLoop(
387 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
390 llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap) {
398 for (
const auto &retVal :
400 Operation *def = retVal.value().getDefiningOp();
401 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
403 auto defStage = stages.find(def);
404 if (defStage != stages.end()) {
406 valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
407 [maxStage - defStage->second];
408 assert(valueVersion);
409 newLoopArg.push_back(valueVersion);
411 newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
413 for (
auto escape : crossStageValues) {
414 LiverangeInfo &info = escape.second;
415 Value value = escape.first;
416 for (
unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
419 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
420 assert(valueVersion);
421 newLoopArg.push_back(valueVersion);
422 loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
423 stageIdx)] = newLoopArg.size() - 1;
430 Value newUb = forOp.getUpperBound();
432 Type t = ub.getType();
435 Value maxStageValue = rewriter.
create<arith::ConstantOp>(
437 Value maxStageByStep =
438 rewriter.
create<arith::MulIOp>(loc, step, maxStageValue);
439 newUb = rewriter.
create<arith::SubIOp>(loc, ub, maxStageByStep);
442 rewriter.
create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
443 forOp.getStep(), newLoopArg);
446 if (!newForOp.getBody()->empty())
447 rewriter.
eraseOp(newForOp.getBody()->getTerminator());
453 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
455 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
457 valueMapping.clear();
463 mapping.
map(forOp.getInductionVar(), newForOp.getInductionVar());
465 mapping.
map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
471 Type t = ub.getType();
472 for (
unsigned i = 0; i < maxStage; i++) {
476 rewriter.
create<arith::MulIOp>(
478 rewriter.
create<arith::ConstantOp>(
482 newForOp.getLoc(), arith::CmpIPredicate::slt,
483 newForOp.getInductionVar(), c);
484 predicates[i] = pred;
488 int64_t useStage = stages[op];
489 auto *newOp = rewriter.
clone(*op, mapping);
494 operands.push_back(&operand);
501 if (operand->get() == forOp.getInductionVar()) {
505 Type t = step.getType();
507 forOp.getLoc(), step,
508 rewriter.
create<arith::ConstantOp>(
512 forOp.getLoc(), newForOp.getInductionVar(), offset);
513 nestedNewOp->
setOperand(operand->getOperandNumber(), iv);
517 Value source = operand->get();
518 auto arg = dyn_cast<BlockArgument>(source);
519 if (arg && arg.getOwner() == forOp.getBody()) {
520 Value ret = forOp.getBody()->getTerminator()->getOperand(
521 arg.getArgNumber() - 1);
525 auto stageDep = stages.find(dep);
526 if (stageDep == stages.end() || stageDep->second == useStage)
530 if (stageDep->second == useStage + 1) {
531 nestedNewOp->
setOperand(operand->getOperandNumber(),
543 auto stageDef = stages.find(def);
544 if (stageDef == stages.end() || stageDef->second == useStage)
546 auto remap = loopArgMap.find(
547 std::make_pair(operand->get(), useStage - stageDef->second));
548 assert(remap != loopArgMap.end());
549 nestedNewOp->
setOperand(operand->getOperandNumber(),
550 newForOp.getRegionIterArgs()[remap->second]);
553 if (predicates[useStage]) {
554 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
558 for (
auto values : llvm::zip(op->
getResults(), newOp->getResults()))
559 mapping.
map(std::get<0>(values), std::get<1>(values));
573 forOp.getBody()->getTerminator()->getOpOperands()) {
579 !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
580 Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
582 auto defStage = stages.find(def);
583 if (defStage != stages.end() && defStage->second < maxStage) {
584 Value pred = predicates[defStage->second];
585 source = rewriter.
create<arith::SelectOp>(
586 pred.
getLoc(), pred, source,
588 ->getArguments()[yieldOperand.getOperandNumber() + 1]);
592 yieldOperands.push_back(source);
595 for (
auto &it : crossStageValues) {
596 int64_t version = maxStage - it.second.lastUseStage + 1;
597 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
601 for (
unsigned i = 1; i < numVersionReturned; i++) {
602 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
604 yieldOperands.push_back(
605 newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
606 newForOp.getNumInductionVars()]);
608 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
613 for (
const auto &retVal :
615 Operation *def = retVal.value().getDefiningOp();
616 assert(def &&
"Only support loop carried dependencies of distance of 1 or "
617 "defined outside the loop");
618 auto defStage = stages.find(def);
619 if (defStage == stages.end()) {
620 for (
unsigned int stage = 1; stage <= maxStage; stage++)
621 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
622 retVal.value(), stage);
623 }
else if (defStage->second > 0) {
624 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
625 newForOp->getResult(retVal.index()),
626 maxStage - defStage->second + 1);
629 rewriter.
create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
633 void LoopPipelinerInternal::emitEpilogue(
637 for (int64_t i = 0; i < maxStage; i++) {
639 Type t = lb.getType();
643 Value totalNumIteration = rewriter.
create<arith::DivUIOp>(
645 rewriter.
create<arith::SubIOp>(
646 loc, rewriter.
create<arith::AddIOp>(loc, ub, minusOne), lb),
651 Value newlastIter = rewriter.
create<arith::AddIOp>(
653 rewriter.
create<arith::MulIOp>(
655 rewriter.
create<arith::AddIOp>(loc, totalNumIteration, minusI)));
656 setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
660 for (int64_t i = 1; i <= maxStage; i++) {
665 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
666 auto it = valueMapping.find(newOperand->
get());
667 if (it != valueMapping.end()) {
668 Value replacement = it->second[maxStage - stages[op] + i];
669 newOperand->set(replacement);
674 for (
unsigned destId : llvm::seq(
unsigned(0), op->
getNumResults())) {
676 maxStage - stages[op] + i);
681 forOp.getBody()->getTerminator()->getOpOperands()) {
682 if (operand.get() != op->
getResult(destId))
684 unsigned version = maxStage - stages[op] + i + 1;
687 if (version > maxStage) {
688 returnValues[operand.getOperandNumber()] = newOp->
getResult(destId);
691 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
// Stores `el` at position `idx` of the entry that valueMapping holds for
// `key`.
699 void LoopPipelinerInternal::setValueMapping(
Value key,
Value el, int64_t idx) {
700 auto it = valueMapping.find(key);
// No entry for `key` yet: handled by statements not visible in this excerpt.
// NOTE(review): presumably an entry is inserted and `it` is reassigned, since
// `it` is dereferenced below - confirm.
703 if (it == valueMapping.end())
708 it->second[idx] = el;
718 LoopPipelinerInternal pipeliner;
719 if (!pipeliner.initializeLoopInfo(forOp,
options))
726 pipeliner.emitPrologue(rewriter);
732 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
733 crossStageValues = pipeliner.analyzeCrossStageValues();
741 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
744 if (
failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
749 newForOp.getResults().take_front(forOp->getNumResults());
753 pipeliner.emitEpilogue(rewriter, returnValues);
756 if (forOp->getNumResults() > 0)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
IntegerAttr getIntegerAttr(Type type, int64_t value)
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options)
Populate patterns for SCF software pipelining transformation.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options to dictate how loops should be pipelined.
std::function< void(Operation *, PipelinerPart, unsigned)> AnnotationlFnType
Lambda called by the pipeliner to allow the user to annotate the IR while it is generated.
std::function< Operation *(RewriterBase &, Operation *, Value)> PredicateOpFn