22 #include "llvm/ADT/MapVector.h"
23 #include "llvm/Support/Debug.h"
25 #define DEBUG_TYPE "scf-loop-pipelining"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
35 struct LoopPipelinerInternal {
37 struct LiverangeInfo {
38 unsigned lastUseStage = 0;
39 unsigned defStage = 0;
44 unsigned maxStage = 0;
46 std::vector<Operation *> opOrder;
62 void setValueMapping(
Value key,
Value el, int64_t idx);
73 llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
74 scf::ForOp createKernelLoop(
75 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
82 const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
83 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
90 bool LoopPipelinerInternal::initializeLoopInfo(
92 LDBG(
"Start initializeLoopInfo");
95 forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
97 forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
98 auto stepCst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
99 if (!upperBoundCst || !lowerBoundCst || !stepCst) {
100 LDBG(
"--no constant bounds or step -> BAIL");
103 ub = upperBoundCst.value();
104 lb = lowerBoundCst.value();
105 step = stepCst.value();
106 peelEpilogue =
options.peelEpilogue;
107 predicateFn =
options.predicateFn;
108 if (!peelEpilogue && predicateFn ==
nullptr) {
109 LDBG(
"--no epilogue or predicate set -> BAIL");
112 int64_t numIteration =
ceilDiv(ub - lb, step);
113 std::vector<std::pair<Operation *, unsigned>> schedule;
114 options.getScheduleFn(forOp, schedule);
115 if (schedule.empty()) {
116 LDBG(
"--empty schedule -> BAIL");
120 opOrder.reserve(schedule.size());
121 for (
auto &opSchedule : schedule) {
122 maxStage =
std::max(maxStage, opSchedule.second);
123 stages[opSchedule.first] = opSchedule.second;
124 opOrder.push_back(opSchedule.first);
126 if (numIteration <= maxStage) {
127 LDBG(
"--fewer loop iterations than pipeline stages -> BAIL");
132 for (
Operation &op : forOp.getBody()->without_terminator()) {
133 if (!stages.contains(&op)) {
135 LDBG(
"--op not assigned a pipeline stage: " << op <<
" -> BAIL");
143 for (
const auto &[op, stageNum] : stages) {
145 if (op == forOp.getBody()->getTerminator()) {
146 op->
emitError(
"terminator should not be assigned a stage");
147 LDBG(
"--terminator should not be assigned stage: " << *op <<
" -> BAIL");
150 if (op->
getBlock() != forOp.getBody()) {
151 op->
emitOpError(
"the owning Block of all operations assigned a stage "
152 "should be the loop body block");
153 LDBG(
"--the owning Block of all operations assigned a stage "
154 "should be the loop body block: "
155 << *op <<
" -> BAIL");
163 if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
164 [
this](
Value operand) {
165 Operation *def = operand.getDefiningOp();
166 return !def || !stages.contains(def);
168 LDBG(
"--only support loop carried dependency with a distance of 1 -> BAIL");
171 annotateFn =
options.annotateFn;
187 Operation *def = operand.get().getDefiningOp();
188 if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
195 void LoopPipelinerInternal::emitPrologue(
RewriterBase &rewriter) {
197 for (
auto [arg, operand] :
198 llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
199 setValueMapping(arg, operand.get(), 0);
201 auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
202 for (int64_t i = 0; i < maxStage; i++) {
205 rewriter.
create<arith::ConstantIndexOp>(forOp.getLoc(), lb + i * step);
206 setValueMapping(forOp.getInductionVar(), iv, i);
211 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
212 auto it = valueMapping.find(newOperand->
get());
213 if (it != valueMapping.end()) {
214 Value replacement = it->second[i - stages[op]];
215 newOperand->set(replacement);
220 for (
unsigned destId : llvm::seq(
unsigned(0), op->
getNumResults())) {
225 for (
OpOperand &operand : yield->getOpOperands()) {
226 if (operand.get() != op->
getResult(destId))
228 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
229 newOp->
getResult(destId), i - stages[op] + 1);
236 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
237 LoopPipelinerInternal::analyzeCrossStageValues() {
238 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
240 unsigned stage = stages[op];
242 auto analyzeOperand = [&](
OpOperand &operand) {
243 Operation *def = operand.get().getDefiningOp();
246 auto defStage = stages.find(def);
247 if (defStage == stages.end() || defStage->second == stage)
249 assert(stage > defStage->second);
250 LiverangeInfo &info = crossStageValues[operand.get()];
251 info.defStage = defStage->second;
252 info.lastUseStage =
std::max(info.lastUseStage, stage);
256 analyzeOperand(operand);
258 analyzeOperand(*operand);
261 return crossStageValues;
264 scf::ForOp LoopPipelinerInternal::createKernelLoop(
265 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
268 llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap) {
276 for (
const auto &retVal :
278 Operation *def = retVal.value().getDefiningOp();
279 assert(def &&
"Only support loop carried dependencies of distance 1");
280 unsigned defStage = stages[def];
281 Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
282 [maxStage - defStage];
283 assert(valueVersion);
284 newLoopArg.push_back(valueVersion);
286 for (
auto escape : crossStageValues) {
287 LiverangeInfo &info = escape.second;
288 Value value = escape.first;
289 for (
unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
292 valueMapping[value][maxStage - info.lastUseStage + stageIdx];
293 assert(valueVersion);
294 newLoopArg.push_back(valueVersion);
295 loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
296 stageIdx)] = newLoopArg.size() - 1;
303 Value newUb = forOp.getUpperBound();
305 newUb = rewriter.
create<arith::ConstantIndexOp>(forOp.getLoc(),
306 ub - maxStage * step);
308 rewriter.
create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
309 forOp.getStep(), newLoopArg);
312 if (!newForOp.getBody()->empty())
313 rewriter.
eraseOp(newForOp.getBody()->getTerminator());
319 const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
321 const llvm::DenseMap<std::pair<Value, unsigned>,
unsigned> &loopArgMap,
323 valueMapping.clear();
329 mapping.
map(forOp.getInductionVar(), newForOp.getInductionVar());
331 mapping.
map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
336 for (
unsigned i = 0; i < maxStage; i++) {
338 newForOp.getLoc(), ub - (maxStage - i) * step);
340 newForOp.getLoc(), arith::CmpIPredicate::slt,
341 newForOp.getInductionVar(), c);
342 predicates[i] = pred;
346 int64_t useStage = stages[op];
347 auto *newOp = rewriter.
clone(*op, mapping);
352 operands.push_back(&operand);
359 if (operand->get() == forOp.getInductionVar()) {
361 Value offset = rewriter.
create<arith::ConstantIndexOp>(
362 forOp.getLoc(), (maxStage - stages[op]) * step);
364 forOp.getLoc(), newForOp.getInductionVar(), offset);
365 nestedNewOp->
setOperand(operand->getOperandNumber(), iv);
369 auto arg = dyn_cast<BlockArgument>(operand->get());
370 if (arg && arg.getOwner() == forOp.getBody()) {
373 Value ret = forOp.getBody()->getTerminator()->getOperand(
374 arg.getArgNumber() - 1);
378 auto stageDep = stages.find(dep);
379 if (stageDep == stages.end() || stageDep->second == useStage)
381 assert(stageDep->second == useStage + 1);
382 nestedNewOp->
setOperand(operand->getOperandNumber(),
389 Operation *def = operand->get().getDefiningOp();
392 auto stageDef = stages.find(def);
393 if (stageDef == stages.end() || stageDef->second == useStage)
395 auto remap = loopArgMap.find(
396 std::make_pair(operand->get(), useStage - stageDef->second));
397 assert(remap != loopArgMap.end());
398 nestedNewOp->
setOperand(operand->getOperandNumber(),
399 newForOp.getRegionIterArgs()[remap->second]);
402 if (predicates[useStage]) {
403 newOp = predicateFn(rewriter, newOp, predicates[useStage]);
407 for (
auto values : llvm::zip(op->
getResults(), newOp->getResults()))
408 mapping.
map(std::get<0>(values), std::get<1>(values));
421 for (
Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
424 for (
auto &it : crossStageValues) {
425 int64_t version = maxStage - it.second.lastUseStage + 1;
426 unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
430 for (
unsigned i = 1; i < numVersionReturned; i++) {
431 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
433 yieldOperands.push_back(
434 newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
435 newForOp.getNumInductionVars()]);
437 setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
442 for (
const auto &retVal :
444 Operation *def = retVal.value().getDefiningOp();
445 assert(def &&
"Only support loop carried dependencies of distance 1");
446 unsigned defStage = stages[def];
447 setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
448 newForOp->getResult(retVal.index()),
449 maxStage - defStage + 1);
451 rewriter.
create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
456 LoopPipelinerInternal::emitEpilogue(
RewriterBase &rewriter) {
460 for (int64_t i = 0; i < maxStage; i++) {
461 Value newlastIter = rewriter.
create<arith::ConstantIndexOp>(
462 forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i));
463 setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
467 for (int64_t i = 1; i <= maxStage; i++) {
472 cloneAndUpdateOperands(rewriter, op, [&](
OpOperand *newOperand) {
473 auto it = valueMapping.find(newOperand->
get());
474 if (it != valueMapping.end()) {
475 Value replacement = it->second[maxStage - stages[op] + i];
476 newOperand->set(replacement);
481 for (
unsigned destId : llvm::seq(
unsigned(0), op->
getNumResults())) {
483 maxStage - stages[op] + i);
488 forOp.getBody()->getTerminator()->getOpOperands()) {
489 if (operand.get() != op->
getResult(destId))
491 unsigned version = maxStage - stages[op] + i + 1;
494 if (version > maxStage) {
495 returnValues[operand.getOperandNumber()] = newOp->
getResult(destId);
498 setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
507 void LoopPipelinerInternal::setValueMapping(
Value key,
Value el, int64_t idx) {
508 auto it = valueMapping.find(key);
511 if (it == valueMapping.end())
516 it->second[idx] = el;
526 LoopPipelinerInternal pipeliner;
527 if (!pipeliner.initializeLoopInfo(forOp,
options))
534 pipeliner.emitPrologue(rewriter);
540 llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
541 crossStageValues = pipeliner.analyzeCrossStageValues();
549 pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
552 if (
failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
557 newForOp.getResults().take_front(forOp->getNumResults());
561 returnValues = pipeliner.emitEpilogue(rewriter);
564 if (forOp->getNumResults() > 0)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, const PipeliningOption &options)
Populate patterns for SCF software pipelining transformation.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void visitUsedValuesDefinedAbove(Region &region, Region &limit, function_ref< void(OpOperand *)> callback)
Calls callback for each use of a value within region or its descendants that was defined at the ances...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options to dictate how loops should be pipelined.
std::function< void(Operation *, PipelinerPart, unsigned)> AnnotationlFnType
Lambda called by the pipeliner to allow the user to annotate the IR while it is generated.
std::function< Operation *(RewriterBase &, Operation *, Value)> PredicateOpFn