26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/MathExtras.h"
36 #define DEBUG_TYPE "scf-utils"
37 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
38 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
43 bool replaceIterOperandsUsesInLoop) {
49 assert(loopNest.size() <= 10 &&
50 "exceeded recursion limit when yielding value from loop nest");
82 if (loopNest.size() == 1) {
84 cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
85 rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
87 return {innerMostLoop};
97 innerNewBBArgs, newYieldValuesFn,
98 replaceIterOperandsUsesInLoop);
99 return llvm::to_vector(llvm::map_range(
100 newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
103 scf::ForOp outerMostLoop =
104 cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
105 rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
106 newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
123 func::CallOp *callOp) {
124 assert(!funcName.empty() &&
"funcName cannot be empty");
138 ValueRange outlinedValues(captures.getArrayRef());
145 outlinedFuncArgTypes.push_back(arg.getType());
146 outlinedFuncArgLocs.push_back(arg.getLoc());
148 for (
Value value : outlinedValues) {
149 outlinedFuncArgTypes.push_back(value.getType());
150 outlinedFuncArgLocs.push_back(value.getLoc());
152 FunctionType outlinedFuncType =
156 rewriter.
create<func::FuncOp>(loc, funcName, outlinedFuncType);
157 Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
162 auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
167 originalBlock, outlinedFuncBody,
168 outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
178 ®ion, region.
begin(),
179 TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
181 .take_front(numOriginalBlockArguments));
186 llvm::append_range(callValues, newBlock->
getArguments());
187 llvm::append_range(callValues, outlinedValues);
188 auto call = rewriter.
create<func::CallOp>(loc, outlinedFunc, callValues);
197 rewriter.
clone(*originalTerminator, bvm);
198 rewriter.
eraseOp(originalTerminator);
203 for (
auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
204 outlinedValues.size()))) {
205 Value orig = std::get<0>(it);
206 Value repl = std::get<1>(it);
216 return outlinedFunc->isProperAncestor(opOperand.
getOwner());
224 func::FuncOp *thenFn, StringRef thenFnName,
225 func::FuncOp *elseFn, StringRef elseFnName) {
228 FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
229 if (thenFn && !ifOp.getThenRegion().empty()) {
231 rewriter, loc, ifOp.getThenRegion(), thenFnName);
232 if (failed(outlinedFuncOpOrFailure))
234 *thenFn = *outlinedFuncOpOrFailure;
236 if (elseFn && !ifOp.getElseRegion().empty()) {
238 rewriter, loc, ifOp.getElseRegion(), elseFnName);
239 if (failed(outlinedFuncOpOrFailure))
241 *elseFn = *outlinedFuncOpOrFailure;
248 assert(rootOp !=
nullptr &&
"Root operation must not be a nullptr.");
249 bool rootEnclosesPloops =
false;
251 for (
Block &block : region.getBlocks()) {
254 rootEnclosesPloops |= enclosesPloops;
255 if (
auto ploop = dyn_cast<scf::ParallelOp>(op)) {
256 rootEnclosesPloops =
true;
260 result.push_back(ploop);
265 return rootEnclosesPloops;
273 assert(divisor > 0 &&
"expected positive divisor");
275 "expected integer or index-typed value");
277 Value divisorMinusOneCst = builder.
create<arith::ConstantOp>(
279 Value divisorCst = builder.
create<arith::ConstantOp>(
281 Value sum = builder.
create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
282 return builder.
create<arith::DivUIOp>(loc, sum, divisorCst);
292 "expected integer or index-typed value");
295 Value divisorMinusOne = builder.
create<arith::SubIOp>(loc, divisor, cstOne);
296 Value sum = builder.
create<arith::AddIOp>(loc, dividend, divisorMinusOne);
297 return builder.
create<arith::DivUIOp>(loc, sum, divisor);
307 if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
311 int64_t lbCst = lbCstOp.value();
312 int64_t ubCst = ubCstOp.value();
313 int64_t stepCst = stepCstOp.value();
314 assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
315 "expected positive loop bounds and step");
316 return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
324 Block *loopBodyBlock,
Value forOpIV, uint64_t unrollFactor,
334 annotateFn = defaultAnnotateFn;
343 for (
unsigned i = 1; i < unrollFactor; i++) {
347 operandMap.
map(iterArgs, lastYielded);
352 Value ivUnroll = ivRemapFn(i, forOpIV, builder);
353 operandMap.
map(forOpIV, ivUnroll);
357 for (
auto it = loopBodyBlock->
begin(); it != std::next(srcBlockEnd); it++) {
359 annotateFn(i, clonedOp, builder);
363 for (
unsigned i = 0, e = lastYielded.size(); i < e; i++)
369 for (
auto it = loopBodyBlock->
begin(); it != std::next(srcBlockEnd); it++)
370 annotateFn(0, &*it, builder);
379 scf::ForOp forOp, uint64_t unrollFactor,
381 assert(unrollFactor > 0 &&
"expected positive unroll factor");
384 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
391 auto loc = forOp.getLoc();
392 Value step = forOp.getStep();
393 Value upperBoundUnrolled;
395 bool generateEpilogueLoop =
true;
398 if (constTripCount) {
403 if (unrollFactor == 1) {
404 if (*constTripCount == 1 &&
405 failed(forOp.promoteIfSingleIteration(rewriter)))
410 int64_t tripCountEvenMultiple =
411 *constTripCount - (*constTripCount % unrollFactor);
412 int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
413 int64_t stepUnrolledCst = stepCst * unrollFactor;
416 generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
417 if (generateEpilogueLoop)
418 upperBoundUnrolled = boundsBuilder.
create<arith::ConstantOp>(
419 loc, boundsBuilder.
getIntegerAttr(forOp.getUpperBound().getType(),
420 upperBoundUnrolledCst));
422 upperBoundUnrolled = forOp.getUpperBound();
425 stepUnrolled = stepCst == stepUnrolledCst
427 : boundsBuilder.
create<arith::ConstantOp>(
429 step.
getType(), stepUnrolledCst));
434 auto lowerBound = forOp.getLowerBound();
435 auto upperBound = forOp.getUpperBound();
437 boundsBuilder.
create<arith::SubIOp>(loc, upperBound, lowerBound);
439 Value unrollFactorCst = boundsBuilder.
create<arith::ConstantOp>(
440 loc, boundsBuilder.
getIntegerAttr(tripCount.getType(), unrollFactor));
442 boundsBuilder.
create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
444 Value tripCountEvenMultiple =
445 boundsBuilder.
create<arith::SubIOp>(loc, tripCount, tripCountRem);
447 upperBoundUnrolled = boundsBuilder.
create<arith::AddIOp>(
449 boundsBuilder.
create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
452 boundsBuilder.
create<arith::MulIOp>(loc, step, unrollFactorCst);
458 if (generateEpilogueLoop) {
459 OpBuilder epilogueBuilder(forOp->getContext());
461 auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.
clone(*forOp));
462 epilogueForOp.setLowerBound(upperBoundUnrolled);
465 auto results = forOp.getResults();
466 auto epilogueResults = epilogueForOp.getResults();
468 for (
auto e : llvm::zip(results, epilogueResults)) {
469 std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
471 epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
472 epilogueForOp.getInitArgs().size(), results);
473 if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
478 forOp.setUpperBound(upperBoundUnrolled);
479 forOp.setStep(stepUnrolled);
481 auto iterArgs =
ValueRange(forOp.getRegionIterArgs());
482 auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
485 forOp.getBody(), forOp.getInductionVar(), unrollFactor,
488 auto stride = b.create<arith::MulIOp>(
490 b.create<arith::ConstantOp>(loc,
491 b.getIntegerAttr(iv.getType(), i)));
492 return b.create<arith::AddIOp>(loc, iv, stride);
494 annotateFn, iterArgs, yieldedValues);
496 if (forOp.promoteIfSingleIteration(rewriter).failed())
504 auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
505 if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
506 !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
507 !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
512 return !walkResult.wasInterrupted();
517 uint64_t unrollJamFactor) {
518 assert(unrollJamFactor > 0 &&
"unroll jam factor should be positive");
520 if (unrollJamFactor == 1)
526 LDBG(
"failed to unroll and jam: inner bounds are not invariant");
531 if (forOp->getNumResults() > 0) {
532 LDBG(
"failed to unroll and jam: unsupported loop with results");
539 if (!tripCount.has_value()) {
541 LDBG(
"failed to unroll and jam: trip count could not be determined");
544 if (unrollJamFactor > *tripCount) {
545 LDBG(
"unroll and jam factor is greater than trip count, set factor to trip "
547 unrollJamFactor = *tripCount;
548 }
else if (*tripCount % unrollJamFactor != 0) {
549 LDBG(
"failed to unroll and jam: unsupported trip count that is not a "
550 "multiple of unroll jam factor");
555 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
565 forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
576 for (scf::ForOp oldForOp : innerLoops) {
578 ValueRange oldIterOperands = oldForOp.getInits();
579 ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
581 cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
584 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
585 dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
586 dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
590 bool forOpReplaced = oldForOp == forOp;
591 scf::ForOp newForOp =
592 cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
593 rewriter, dupIterOperands,
false,
595 return dupYieldOperands;
597 newInnerLoops.push_back(newForOp);
602 ValueRange newIterArgs = newForOp.getRegionIterArgs();
603 unsigned oldNumIterArgs = oldIterArgs.size();
604 ValueRange newResults = newForOp.getResults();
605 unsigned oldNumResults = newResults.size() / unrollJamFactor;
606 assert(oldNumIterArgs == oldNumResults &&
607 "oldNumIterArgs must be the same as oldNumResults");
608 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
609 for (
unsigned j = 0;
j < oldNumIterArgs; ++
j) {
613 operandMaps[i - 1].map(newIterArgs[
j],
614 newIterArgs[i * oldNumIterArgs +
j]);
615 operandMaps[i - 1].map(newResults[
j],
616 newResults[i * oldNumResults +
j]);
623 int64_t step = forOp.getConstantStep()->getSExtValue();
625 forOp.getLoc(), forOp.getStep(),
627 forOp.getLoc(), rewriter.
getIndexAttr(unrollJamFactor)));
628 forOp.setStep(newStep);
629 auto forOpIV = forOp.getInductionVar();
632 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
633 for (
auto &subBlock : subBlocks) {
636 OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
645 builder.
createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
646 operandMaps[i - 1].map(forOpIV, ivUnroll);
649 for (
auto it = subBlock.first; it != std::next(subBlock.second); ++it)
650 builder.
clone(*it, operandMaps[i - 1]);
653 for (
auto newForOp : newInnerLoops) {
654 unsigned oldNumIterOperands =
655 newForOp.getNumRegionIterArgs() / unrollJamFactor;
656 unsigned numControlOperands = newForOp.getNumControlOperands();
657 auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
658 unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
659 assert(oldNumIterOperands == oldNumYieldOperands &&
660 "oldNumIterOperands must be the same as oldNumYieldOperands");
661 for (
unsigned j = 0;
j < oldNumIterOperands; ++
j) {
665 newForOp.setOperand(numControlOperands + i * oldNumIterOperands +
j,
666 operandMaps[i - 1].lookupOrDefault(
667 newForOp.getOperand(numControlOperands +
j)));
669 i * oldNumYieldOperands +
j,
670 operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(
j)));
676 (void)forOp.promoteIfSingleIteration(rewriter);
683 Range normalizedLoopBounds;
689 normalizedLoopBounds.
size =
691 return normalizedLoopBounds;
703 bool isZeroBased =
false;
705 isZeroBased = lbCst.value() == 0;
707 bool isStepOne =
false;
709 isStepOne = stepCst.value() == 1;
713 "expected matching types");
718 if (isZeroBased && isStepOne)
719 return {lb, ub, step};
729 newUpperBound = rewriter.
createOrFold<arith::CeilDivSIOp>(
737 return {newLowerBound, newUpperBound, newStep};
751 Value denormalizedIvVal =
758 if (
Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
759 preservedUses.insert(preservedUse);
768 if (
getType(origLb).isIndex()) {
772 Value denormalizedIv;
777 Value scaled = normalizedIv;
779 Value origStepValue =
781 scaled = rewriter.
create<arith::MulIOp>(loc, normalizedIv, origStepValue);
784 denormalizedIv = scaled;
787 denormalizedIv = rewriter.
create<arith::AddIOp>(loc, scaled, origLbValue);
796 assert(!values.empty() &&
"unexecpted empty array");
801 for (
auto v : values) {
811 assert(!values.empty() &&
"unexpected empty list");
817 std::optional<Value> productOf;
818 for (
auto v : values) {
820 if (vOne && vOne.value() == 1)
824 rewriter.
create<arith::MulIOp>(loc, productOf.value(), v).getResult();
830 .
create<arith::ConstantOp>(
834 return productOf.value();
851 rewriter.
create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
853 auto resultVals = llvm::map_to_vector(
861 llvm::BitVector isUbOne(ubs.size());
864 if (ubCst && ubCst.value() == 1)
869 unsigned numLeadingOneUbs = 0;
871 if (!isUbOne.test(index)) {
874 delinearizedIvs[index] = rewriter.
create<arith::ConstantOp>(
879 Value previous = linearizedIv;
880 for (
unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
881 unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
882 if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
883 previous = rewriter.
create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
884 preservedUsers.insert(previous.getDefiningOp());
888 if (!isUbOne.test(idx)) {
889 iv = rewriter.
create<arith::RemSIOp>(loc, previous, ubs[idx]);
892 iv = rewriter.
create<arith::ConstantOp>(
896 delinearizedIvs[idx] = iv;
898 return {delinearizedIvs, preservedUsers};
903 if (loops.size() < 2)
906 scf::ForOp innermost = loops.back();
907 scf::ForOp outermost = loops.front();
911 for (
auto loop : loops) {
914 Value lb = loop.getLowerBound();
915 Value ub = loop.getUpperBound();
916 Value step = loop.getStep();
922 newLoopRange.offset));
926 newLoopRange.stride));
930 loop.getInductionVar(), lb, step);
939 loops, [](
auto loop) {
return loop.getUpperBound(); });
941 outermost.setUpperBound(upperBound);
945 rewriter, loc, outermost.getInductionVar(), upperBounds);
949 for (
int i = loops.size() - 1; i > 0; --i) {
950 auto outerLoop = loops[i - 1];
951 auto innerLoop = loops[i];
953 Operation *innerTerminator = innerLoop.getBody()->getTerminator();
954 auto yieldedVals = llvm::to_vector(innerTerminator->
getOperands());
955 assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
956 for (
Value &yieldedVal : yieldedVals) {
959 auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
960 if (iter != innerLoop.getRegionIterArgs().end()) {
961 unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
963 assert(iterArgIndex < innerLoop.getInitArgs().size());
964 yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
967 rewriter.
eraseOp(innerTerminator);
970 innerBlockArgs.push_back(delinearizeIvs[i]);
971 llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
974 rewriter.
replaceOp(innerLoop, yieldedVals);
983 IRRewriter rewriter(loops.front().getContext());
988 LogicalResult result(failure());
998 for (
unsigned i = 0, e = loops.size(); i < e; ++i) {
999 operandsDefinedAbove[i] = i;
1000 for (
unsigned j = 0;
j < i; ++
j) {
1002 loops[i].getUpperBound(),
1003 loops[i].getStep()};
1005 operandsDefinedAbove[i] =
j;
1016 iterArgChainStart[0] = 0;
1017 for (
unsigned i = 1, e = loops.size(); i < e; ++i) {
1019 iterArgChainStart[i] = i;
1020 auto outerloop = loops[i - 1];
1021 auto innerLoop = loops[i];
1022 if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1025 if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1028 auto outerloopTerminator = outerloop.getBody()->getTerminator();
1029 if (!llvm::equal(outerloopTerminator->getOperands(),
1030 innerLoop.getResults())) {
1033 iterArgChainStart[i] = iterArgChainStart[i - 1];
1039 for (
unsigned end = loops.size(); end > 0; --end) {
1041 for (; start < end - 1; ++start) {
1043 *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1044 std::next(operandsDefinedAbove.begin(), end));
1047 if (iterArgChainStart[end - 1] > start)
1056 if (start != end - 1)
1064 ArrayRef<std::vector<unsigned>> combinedDimensions) {
1070 auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1071 for (
auto &dims : sortedDimensions)
1076 for (
unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1079 Value lb = loops.getLowerBound()[i];
1080 Value ub = loops.getUpperBound()[i];
1081 Value step = loops.getStep()[i];
1084 rewriter, loops.getLoc(), newLoopRange.size));
1095 for (
auto &sortedDimension : sortedDimensions) {
1097 for (
auto idx : sortedDimension) {
1098 newUpperBound = rewriter.
create<arith::MulIOp>(
1099 loc, newUpperBound, normalizedUpperBounds[idx]);
1101 lowerBounds.push_back(cst0);
1102 steps.push_back(cst1);
1103 upperBounds.push_back(newUpperBound);
1112 auto newPloop = rewriter.
create<scf::ParallelOp>(
1113 loc, lowerBounds, upperBounds, steps,
1115 for (
unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1116 Value previous = ploopIVs[i];
1117 unsigned numberCombinedDimensions = combinedDimensions[i].size();
1119 for (
unsigned j = numberCombinedDimensions - 1;
j > 0; --
j) {
1120 unsigned idx = combinedDimensions[i][
j];
1123 Value iv = insideBuilder.create<arith::RemSIOp>(
1124 loc, previous, normalizedUpperBounds[idx]);
1130 previous = insideBuilder.create<arith::DivSIOp>(
1131 loc, previous, normalizedUpperBounds[idx]);
1135 unsigned idx = combinedDimensions[i][0];
1137 previous, loops.getRegion());
1142 loops.getBody()->back().
erase();
1143 newPloop.getBody()->getOperations().splice(
1145 loops.getBody()->getOperations());
1158 return op != inner.getOperation();
1161 LogicalResult status = success();
1163 for (
auto &op : outer.getBody()->without_terminator()) {
1165 if (&op == inner.getOperation())
1168 if (forwardSlice.count(&op) > 0) {
1173 if (isa<scf::ForOp>(op))
1176 if (op.getNumRegions() > 0) {
1186 toHoist.push_back(&op);
1188 auto *outerForOp = outer.getOperation();
1189 for (
auto *op : toHoist)
1190 op->moveBefore(outerForOp);
1199 LogicalResult status = success();
1200 const Loops &interTile = tileLoops.first;
1201 const Loops &intraTile = tileLoops.second;
1202 auto size = interTile.size();
1203 assert(size == intraTile.size());
1206 for (
unsigned s = 1; s < size; ++s)
1207 status = succeeded(status) ?
hoistOpsBetween(intraTile[0], intraTile[s])
1209 for (
unsigned s = 1; s < size; ++s)
1210 status = succeeded(status) ?
hoistOpsBetween(interTile[0], interTile[s])
1219 template <
typename T>
1223 for (
unsigned i = 0; i < maxLoops; ++i) {
1224 forOps.push_back(rootForOp);
1226 if (body.
begin() != std::prev(body.
end(), 2))
1229 rootForOp = dyn_cast<T>(&body.
front());
1237 auto originalStep = forOp.getStep();
1238 auto iv = forOp.getInductionVar();
1241 forOp.setStep(b.
create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));
1244 for (
auto t : targets) {
1246 auto begin = t.getBody()->begin();
1247 auto nOps = t.getBody()->getOperations().size();
1251 Value stepped = b.
create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
1253 b.
create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
1256 auto newForOp = b.
create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
1257 newForOp.getBody()->getOperations().splice(
1258 newForOp.getBody()->getOperations().begin(),
1259 t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1261 newForOp.getRegion());
1263 innerLoops.push_back(newForOp);
1271 template <
typename SizeType>
1273 scf::ForOp target) {
1279 assert(res.size() == 1 &&
"Expected 1 inner forOp");
1288 for (
auto it : llvm::zip(forOps, sizes)) {
1289 auto step =
stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1290 res.push_back(step);
1291 currentTargets = step;
1297 scf::ForOp target) {
1300 assert(loops.size() == 1);
1301 res.push_back(loops[0]);
1310 forOps.reserve(sizes.size());
1312 if (forOps.size() < sizes.size())
1313 sizes = sizes.take_front(forOps.size());
1328 forOps.reserve(sizes.size());
1330 if (forOps.size() < sizes.size())
1331 sizes = sizes.take_front(forOps.size());
1338 tileSizes.reserve(sizes.size());
1339 for (
unsigned i = 0, e = sizes.size(); i < e; ++i) {
1340 assert(sizes[i] > 0 &&
"expected strictly positive size for strip-mining");
1342 auto forOp = forOps[i];
1344 auto loc = forOp.getLoc();
1345 Value diff = builder.
create<arith::SubIOp>(loc, forOp.getUpperBound(),
1346 forOp.getLowerBound());
1348 Value iterationsPerBlock =
1350 tileSizes.push_back(iterationsPerBlock);
1354 auto intraTile =
tile(forOps, tileSizes, forOps.back());
1355 TileLoops tileLoops = std::make_pair(forOps, intraTile);
1366 scf::ForallOp source,
1368 unsigned numTargetOuts = target.getNumResults();
1369 unsigned numSourceOuts = source.getNumResults();
1373 llvm::append_range(fusedOuts, target.getOutputs());
1374 llvm::append_range(fusedOuts, source.getOutputs());
1378 scf::ForallOp fusedLoop = rewriter.
create<scf::ForallOp>(
1379 source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
1380 source.getMixedStep(), fusedOuts, source.getMapping());
1384 mapping.
map(target.getInductionVars(), fusedLoop.getInductionVars());
1385 mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1388 mapping.map(target.getRegionIterArgs(),
1389 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1390 mapping.map(source.getRegionIterArgs(),
1391 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1395 for (
Operation &op : target.getBody()->without_terminator())
1396 rewriter.
clone(op, mapping);
1397 for (
Operation &op : source.getBody()->without_terminator())
1398 rewriter.
clone(op, mapping);
1401 scf::InParallelOp targetTerm = target.getTerminator();
1402 scf::InParallelOp sourceTerm = source.getTerminator();
1403 scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1405 for (
Operation &op : targetTerm.getYieldingOps())
1406 rewriter.
clone(op, mapping);
1407 for (
Operation &op : sourceTerm.getYieldingOps())
1408 rewriter.
clone(op, mapping);
1411 rewriter.
replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1412 rewriter.
replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1420 unsigned numTargetOuts = target.getNumResults();
1421 unsigned numSourceOuts = source.getNumResults();
1425 llvm::append_range(fusedInitArgs, target.getInitArgs());
1426 llvm::append_range(fusedInitArgs, source.getInitArgs());
1431 scf::ForOp fusedLoop = rewriter.
create<scf::ForOp>(
1432 source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1433 source.getStep(), fusedInitArgs);
1437 mapping.
map(target.getInductionVar(), fusedLoop.getInductionVar());
1438 mapping.map(target.getRegionIterArgs(),
1439 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1440 mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1441 mapping.map(source.getRegionIterArgs(),
1442 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1446 for (
Operation &op : target.getBody()->without_terminator())
1447 rewriter.
clone(op, mapping);
1448 for (
Operation &op : source.getBody()->without_terminator())
1449 rewriter.
clone(op, mapping);
1453 for (
Value operand : target.getBody()->getTerminator()->getOperands())
1454 yieldResults.push_back(mapping.lookupOrDefault(operand));
1455 for (
Value operand : source.getBody()->getTerminator()->getOperands())
1456 yieldResults.push_back(mapping.lookupOrDefault(operand));
1457 if (!yieldResults.empty())
1458 rewriter.
create<scf::YieldOp>(source.getLoc(), yieldResults);
1461 rewriter.
replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1462 rewriter.
replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1468 scf::ForallOp forallOp) {
1481 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1482 Range normalizedLoopParams =
1484 newLbs.push_back(normalizedLoopParams.
offset);
1485 newUbs.push_back(normalizedLoopParams.
size);
1486 newSteps.push_back(normalizedLoopParams.
stride);
1489 auto normalizedForallOp = rewriter.
create<scf::ForallOp>(
1490 forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
1494 normalizedForallOp.getBodyRegion(),
1495 normalizedForallOp.getBodyRegion().begin());
static std::optional< int64_t > getConstantTripCount(scf::ForOp forOp)
Returns the trip count of forOp if its low bound, high bound and step are constants,...
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > values)
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOp.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
static void generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues)
Generates unrolled copies of scf::ForOp 'loopBodyBlock', with associated 'forOpIV' by 'unrollFactor',...
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
static int64_t product(ArrayRef< int64_t > vals)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType::iterator iterator
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
TypedAttr getOneAttr(Type type)
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
std::pair< Loops, Loops > TileLoops
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned >> combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
FailureOr< UnrolledLoopInfo > loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including that operation.
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
std::optional< scf::ForOp > epilogueLoopOp
std::optional< scf::ForOp > mainLoopOp
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.