25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallPtrSet.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/MathExtras.h"
// Debug channel name consumed by LLVM_DEBUG in this file; LLVM's
// -debug-only= option filters debug output on this name.
35 #define DEBUG_TYPE "scf-utils"
// Starts a debug line on llvm::dbgs(), prefixed with "[scf-utils] ".
36 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
// Emits X as one prefixed, newline-terminated debug line, wrapped in
// LLVM_DEBUG so it only runs when debug output is enabled. X is spliced
// unparenthesized into a `<<` chain, so it may itself be a `<<` chain of
// values streamable to llvm::dbgs().
37 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
42 bool replaceIterOperandsUsesInLoop) {
48 assert(loopNest.size() <= 10 &&
49 "exceeded recursion limit when yielding value from loop nest");
81 if (loopNest.size() == 1) {
83 cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
84 rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
86 return {innerMostLoop};
96 innerNewBBArgs, newYieldValuesFn,
97 replaceIterOperandsUsesInLoop);
98 return llvm::to_vector(llvm::map_range(
99 newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
102 scf::ForOp outerMostLoop =
103 cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
104 rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
105 newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
122 func::CallOp *callOp) {
123 assert(!funcName.empty() &&
"funcName cannot be empty");
137 ValueRange outlinedValues(captures.getArrayRef());
144 outlinedFuncArgTypes.push_back(arg.getType());
145 outlinedFuncArgLocs.push_back(arg.getLoc());
147 for (
Value value : outlinedValues) {
148 outlinedFuncArgTypes.push_back(value.getType());
149 outlinedFuncArgLocs.push_back(value.getLoc());
151 FunctionType outlinedFuncType =
155 rewriter.
create<func::FuncOp>(loc, funcName, outlinedFuncType);
156 Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
161 auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
166 originalBlock, outlinedFuncBody,
167 outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
177 ®ion, region.
begin(),
178 TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
180 .take_front(numOriginalBlockArguments));
185 llvm::append_range(callValues, newBlock->
getArguments());
186 llvm::append_range(callValues, outlinedValues);
187 auto call = rewriter.
create<func::CallOp>(loc, outlinedFunc, callValues);
196 rewriter.
clone(*originalTerminator, bvm);
197 rewriter.
eraseOp(originalTerminator);
202 for (
auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
203 outlinedValues.size()))) {
204 Value orig = std::get<0>(it);
205 Value repl = std::get<1>(it);
215 return outlinedFunc->isProperAncestor(opOperand.
getOwner());
223 func::FuncOp *thenFn, StringRef thenFnName,
224 func::FuncOp *elseFn, StringRef elseFnName) {
227 FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
228 if (thenFn && !ifOp.getThenRegion().empty()) {
230 rewriter, loc, ifOp.getThenRegion(), thenFnName);
231 if (failed(outlinedFuncOpOrFailure))
233 *thenFn = *outlinedFuncOpOrFailure;
235 if (elseFn && !ifOp.getElseRegion().empty()) {
237 rewriter, loc, ifOp.getElseRegion(), elseFnName);
238 if (failed(outlinedFuncOpOrFailure))
240 *elseFn = *outlinedFuncOpOrFailure;
247 assert(rootOp !=
nullptr &&
"Root operation must not be a nullptr.");
248 bool rootEnclosesPloops =
false;
250 for (
Block &block : region.getBlocks()) {
253 rootEnclosesPloops |= enclosesPloops;
254 if (
auto ploop = dyn_cast<scf::ParallelOp>(op)) {
255 rootEnclosesPloops =
true;
259 result.push_back(ploop);
264 return rootEnclosesPloops;
272 assert(divisor > 0 &&
"expected positive divisor");
274 "expected integer or index-typed value");
276 Value divisorMinusOneCst = builder.
create<arith::ConstantOp>(
278 Value divisorCst = builder.
create<arith::ConstantOp>(
280 Value sum = builder.
create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
281 return builder.
create<arith::DivUIOp>(loc, sum, divisorCst);
291 "expected integer or index-typed value");
294 Value divisorMinusOne = builder.
create<arith::SubIOp>(loc, divisor, cstOne);
295 Value sum = builder.
create<arith::AddIOp>(loc, dividend, divisorMinusOne);
296 return builder.
create<arith::DivUIOp>(loc, sum, divisor);
306 if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
310 int64_t lbCst = lbCstOp.value();
311 int64_t ubCst = ubCstOp.value();
312 int64_t stepCst = stepCstOp.value();
313 assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
314 "expected positive loop bounds and step");
315 return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
323 Block *loopBodyBlock,
Value forOpIV, uint64_t unrollFactor,
341 for (
unsigned i = 1; i < unrollFactor; i++) {
345 operandMap.
map(iterArgs, lastYielded);
350 Value ivUnroll = ivRemapFn(i, forOpIV, builder);
351 operandMap.
map(forOpIV, ivUnroll);
355 for (
auto it = loopBodyBlock->
begin(); it != std::next(srcBlockEnd); it++) {
357 annotateFn(i, clonedOp, builder);
361 for (
unsigned i = 0, e = lastYielded.size(); i < e; i++)
362 lastYielded[i] = operandMap.
lookup(yieldedValues[i]);
367 for (
auto it = loopBodyBlock->
begin(); it != std::next(srcBlockEnd); it++)
368 annotateFn(0, &*it, builder);
376 scf::ForOp forOp, uint64_t unrollFactor,
378 assert(unrollFactor > 0 &&
"expected positive unroll factor");
381 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
388 auto loc = forOp.getLoc();
389 Value step = forOp.getStep();
390 Value upperBoundUnrolled;
392 bool generateEpilogueLoop =
true;
395 if (constTripCount) {
400 if (unrollFactor == 1) {
401 if (*constTripCount == 1 &&
402 failed(forOp.promoteIfSingleIteration(rewriter)))
407 int64_t tripCountEvenMultiple =
408 *constTripCount - (*constTripCount % unrollFactor);
409 int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
410 int64_t stepUnrolledCst = stepCst * unrollFactor;
413 generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
414 if (generateEpilogueLoop)
415 upperBoundUnrolled = boundsBuilder.
create<arith::ConstantOp>(
416 loc, boundsBuilder.
getIntegerAttr(forOp.getUpperBound().getType(),
417 upperBoundUnrolledCst));
419 upperBoundUnrolled = forOp.getUpperBound();
422 stepUnrolled = stepCst == stepUnrolledCst
424 : boundsBuilder.
create<arith::ConstantOp>(
426 step.
getType(), stepUnrolledCst));
431 auto lowerBound = forOp.getLowerBound();
432 auto upperBound = forOp.getUpperBound();
434 boundsBuilder.
create<arith::SubIOp>(loc, upperBound, lowerBound);
436 Value unrollFactorCst = boundsBuilder.
create<arith::ConstantOp>(
437 loc, boundsBuilder.
getIntegerAttr(tripCount.getType(), unrollFactor));
439 boundsBuilder.
create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
441 Value tripCountEvenMultiple =
442 boundsBuilder.
create<arith::SubIOp>(loc, tripCount, tripCountRem);
444 upperBoundUnrolled = boundsBuilder.
create<arith::AddIOp>(
446 boundsBuilder.
create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
449 boundsBuilder.
create<arith::MulIOp>(loc, step, unrollFactorCst);
453 if (generateEpilogueLoop) {
454 OpBuilder epilogueBuilder(forOp->getContext());
457 auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.
clone(*forOp));
458 epilogueForOp.setLowerBound(upperBoundUnrolled);
461 auto results = forOp.getResults();
462 auto epilogueResults = epilogueForOp.getResults();
464 for (
auto e : llvm::zip(results, epilogueResults)) {
465 std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
467 epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
468 epilogueForOp.getInitArgs().size(), results);
469 (void)epilogueForOp.promoteIfSingleIteration(rewriter);
473 forOp.setUpperBound(upperBoundUnrolled);
474 forOp.setStep(stepUnrolled);
476 auto iterArgs =
ValueRange(forOp.getRegionIterArgs());
477 auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
480 forOp.getBody(), forOp.getInductionVar(), unrollFactor,
483 auto stride = b.create<arith::MulIOp>(
485 b.create<arith::ConstantOp>(loc,
486 b.getIntegerAttr(iv.getType(), i)));
487 return b.create<arith::AddIOp>(loc, iv, stride);
489 annotateFn, iterArgs, yieldedValues);
491 (void)forOp.promoteIfSingleIteration(rewriter);
498 auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
499 if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
500 !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
501 !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
506 return !walkResult.wasInterrupted();
511 uint64_t unrollJamFactor) {
512 assert(unrollJamFactor > 0 &&
"unroll jam factor should be positive");
514 if (unrollJamFactor == 1)
520 LDBG(
"failed to unroll and jam: inner bounds are not invariant");
525 if (forOp->getNumResults() > 0) {
526 LDBG(
"failed to unroll and jam: unsupported loop with results");
533 if (!tripCount.has_value()) {
535 LDBG(
"failed to unroll and jam: trip count could not be determined");
538 if (unrollJamFactor > *tripCount) {
539 LDBG(
"unroll and jam factor is greater than trip count, set factor to trip "
541 unrollJamFactor = *tripCount;
542 }
else if (*tripCount % unrollJamFactor != 0) {
543 LDBG(
"failed to unroll and jam: unsupported trip count that is not a "
544 "multiple of unroll jam factor");
549 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
559 forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
570 for (scf::ForOp oldForOp : innerLoops) {
572 ValueRange oldIterOperands = oldForOp.getInits();
573 ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
575 cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
578 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
579 dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
580 dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
584 bool forOpReplaced = oldForOp == forOp;
585 scf::ForOp newForOp =
586 cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
587 rewriter, dupIterOperands,
false,
589 return dupYieldOperands;
591 newInnerLoops.push_back(newForOp);
596 ValueRange newIterArgs = newForOp.getRegionIterArgs();
597 unsigned oldNumIterArgs = oldIterArgs.size();
598 ValueRange newResults = newForOp.getResults();
599 unsigned oldNumResults = newResults.size() / unrollJamFactor;
600 assert(oldNumIterArgs == oldNumResults &&
601 "oldNumIterArgs must be the same as oldNumResults");
602 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
603 for (
unsigned j = 0;
j < oldNumIterArgs; ++
j) {
607 operandMaps[i - 1].map(newIterArgs[
j],
608 newIterArgs[i * oldNumIterArgs +
j]);
609 operandMaps[i - 1].map(newResults[
j],
610 newResults[i * oldNumResults +
j]);
617 int64_t step = forOp.getConstantStep()->getSExtValue();
619 forOp.getLoc(), forOp.getStep(),
621 forOp.getLoc(), rewriter.
getIndexAttr(unrollJamFactor)));
622 forOp.setStep(newStep);
623 auto forOpIV = forOp.getInductionVar();
626 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
627 for (
auto &subBlock : subBlocks) {
630 OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
639 builder.
createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
640 operandMaps[i - 1].map(forOpIV, ivUnroll);
643 for (
auto it = subBlock.first; it != std::next(subBlock.second); ++it)
644 builder.
clone(*it, operandMaps[i - 1]);
647 for (
auto newForOp : newInnerLoops) {
648 unsigned oldNumIterOperands =
649 newForOp.getNumRegionIterArgs() / unrollJamFactor;
650 unsigned numControlOperands = newForOp.getNumControlOperands();
651 auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
652 unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
653 assert(oldNumIterOperands == oldNumYieldOperands &&
654 "oldNumIterOperands must be the same as oldNumYieldOperands");
655 for (
unsigned j = 0;
j < oldNumIterOperands; ++
j) {
659 newForOp.setOperand(numControlOperands + i * oldNumIterOperands +
j,
660 operandMaps[i - 1].lookupOrDefault(
661 newForOp.getOperand(numControlOperands +
j)));
663 i * oldNumYieldOperands +
j,
664 operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(
j)));
670 (void)forOp.promoteIfSingleIteration(rewriter);
680 bool isZeroBased =
false;
682 isZeroBased = lbCst.value() == 0;
684 bool isStepOne =
false;
686 isStepOne = stepCst.value() == 1;
690 "expected matching types");
695 if (isZeroBased && isStepOne)
696 return {lb, ub, step};
706 newUpperBound = rewriter.
createOrFold<arith::CeilDivSIOp>(
714 return {newLowerBound, newUpperBound, newStep};
720 Value denormalizedIv;
725 Value scaled = normalizedIv;
727 Value origStepValue =
729 scaled = rewriter.
create<arith::MulIOp>(loc, normalizedIv, origStepValue);
732 denormalizedIv = scaled;
735 denormalizedIv = rewriter.
create<arith::AddIOp>(loc, scaled, origLbValue);
745 assert(!values.empty() &&
"unexpected empty list");
746 std::optional<Value> productOf;
747 for (
auto v : values) {
749 if (vOne && vOne.value() == 1)
753 rewriter.
create<arith::MulIOp>(loc, productOf.value(), v).getResult();
759 .
create<arith::ConstantOp>(
760 loc, rewriter.
getOneAttr(values.front().getType()))
763 return productOf.value();
780 llvm::BitVector isUbOne(ubs.size());
783 if (ubCst && ubCst.value() == 1)
788 unsigned numLeadingOneUbs = 0;
790 if (!isUbOne.test(index)) {
793 delinearizedIvs[index] = rewriter.
create<arith::ConstantOp>(
798 Value previous = linearizedIv;
799 for (
unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
800 unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
801 if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
802 previous = rewriter.
create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
803 preservedUsers.insert(previous.getDefiningOp());
807 if (!isUbOne.test(idx)) {
808 iv = rewriter.
create<arith::RemSIOp>(loc, previous, ubs[idx]);
811 iv = rewriter.
create<arith::ConstantOp>(
815 delinearizedIvs[idx] = iv;
817 return {delinearizedIvs, preservedUsers};
822 if (loops.size() < 2)
825 scf::ForOp innermost = loops.back();
826 scf::ForOp outermost = loops.front();
830 for (
auto loop : loops) {
833 Value lb = loop.getLowerBound();
834 Value ub = loop.getUpperBound();
835 Value step = loop.getStep();
841 newLoopRange.offset));
845 newLoopRange.stride));
849 loop.getInductionVar(), lb, step);
858 loops, [](
auto loop) {
return loop.getUpperBound(); });
860 outermost.setUpperBound(upperBound);
864 rewriter, loc, outermost.getInductionVar(), upperBounds);
868 for (
int i = loops.size() - 1; i > 0; --i) {
869 auto outerLoop = loops[i - 1];
870 auto innerLoop = loops[i];
872 Operation *innerTerminator = innerLoop.getBody()->getTerminator();
873 auto yieldedVals = llvm::to_vector(innerTerminator->
getOperands());
874 assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
875 for (
Value &yieldedVal : yieldedVals) {
878 auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
879 if (iter != innerLoop.getRegionIterArgs().end()) {
880 unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
882 assert(iterArgIndex < innerLoop.getInitArgs().size());
883 yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
886 rewriter.
eraseOp(innerTerminator);
889 innerBlockArgs.push_back(delinearizeIvs[i]);
890 llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
893 rewriter.
replaceOp(innerLoop, yieldedVals);
902 IRRewriter rewriter(loops.front().getContext());
907 LogicalResult result(failure());
917 for (
unsigned i = 0, e = loops.size(); i < e; ++i) {
918 operandsDefinedAbove[i] = i;
919 for (
unsigned j = 0;
j < i; ++
j) {
921 loops[i].getUpperBound(),
924 operandsDefinedAbove[i] =
j;
935 iterArgChainStart[0] = 0;
936 for (
unsigned i = 1, e = loops.size(); i < e; ++i) {
938 iterArgChainStart[i] = i;
939 auto outerloop = loops[i - 1];
940 auto innerLoop = loops[i];
941 if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
944 if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
947 auto outerloopTerminator = outerloop.getBody()->getTerminator();
948 if (!llvm::equal(outerloopTerminator->getOperands(),
949 innerLoop.getResults())) {
952 iterArgChainStart[i] = iterArgChainStart[i - 1];
958 for (
unsigned end = loops.size(); end > 0; --end) {
960 for (; start < end - 1; ++start) {
962 *std::max_element(std::next(operandsDefinedAbove.begin(), start),
963 std::next(operandsDefinedAbove.begin(), end));
966 if (iterArgChainStart[end - 1] > start)
975 if (start != end - 1)
983 ArrayRef<std::vector<unsigned>> combinedDimensions) {
989 auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
990 for (
auto &dims : sortedDimensions)
995 for (
unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
998 Value lb = loops.getLowerBound()[i];
999 Value ub = loops.getUpperBound()[i];
1000 Value step = loops.getStep()[i];
1003 rewriter, loops.getLoc(), newLoopRange.size));
1014 for (
auto &sortedDimension : sortedDimensions) {
1016 for (
auto idx : sortedDimension) {
1017 newUpperBound = rewriter.
create<arith::MulIOp>(
1018 loc, newUpperBound, normalizedUpperBounds[idx]);
1020 lowerBounds.push_back(cst0);
1021 steps.push_back(cst1);
1022 upperBounds.push_back(newUpperBound);
1031 auto newPloop = rewriter.
create<scf::ParallelOp>(
1032 loc, lowerBounds, upperBounds, steps,
1034 for (
unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1035 Value previous = ploopIVs[i];
1036 unsigned numberCombinedDimensions = combinedDimensions[i].size();
1038 for (
unsigned j = numberCombinedDimensions - 1;
j > 0; --
j) {
1039 unsigned idx = combinedDimensions[i][
j];
1042 Value iv = insideBuilder.create<arith::RemSIOp>(
1043 loc, previous, normalizedUpperBounds[idx]);
1049 previous = insideBuilder.create<arith::DivSIOp>(
1050 loc, previous, normalizedUpperBounds[idx]);
1054 unsigned idx = combinedDimensions[i][0];
1056 previous, loops.getRegion());
1061 loops.getBody()->back().
erase();
1062 newPloop.getBody()->getOperations().splice(
1064 loops.getBody()->getOperations());
1077 return op != inner.getOperation();
1080 LogicalResult status = success();
1082 for (
auto &op : outer.getBody()->without_terminator()) {
1084 if (&op == inner.getOperation())
1087 if (forwardSlice.count(&op) > 0) {
1092 if (isa<scf::ForOp>(op))
1105 toHoist.push_back(&op);
1107 auto *outerForOp = outer.getOperation();
1108 for (
auto *op : toHoist)
1118 LogicalResult status = success();
1119 const Loops &interTile = tileLoops.first;
1120 const Loops &intraTile = tileLoops.second;
1121 auto size = interTile.size();
1122 assert(size == intraTile.size());
1125 for (
unsigned s = 1; s < size; ++s)
1126 status = succeeded(status) ?
hoistOpsBetween(intraTile[0], intraTile[s])
1128 for (
unsigned s = 1; s < size; ++s)
1129 status = succeeded(status) ?
hoistOpsBetween(interTile[0], interTile[s])
1138 template <
typename T>
1142 for (
unsigned i = 0; i < maxLoops; ++i) {
1143 forOps.push_back(rootForOp);
1145 if (body.
begin() != std::prev(body.
end(), 2))
1148 rootForOp = dyn_cast<T>(&body.
front());
1156 auto originalStep = forOp.getStep();
1157 auto iv = forOp.getInductionVar();
1160 forOp.setStep(b.
create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));
1163 for (
auto t : targets) {
1165 auto begin = t.getBody()->begin();
1166 auto nOps = t.getBody()->getOperations().size();
1170 Value stepped = b.
create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
1172 b.
create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
1175 auto newForOp = b.
create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
1176 newForOp.getBody()->getOperations().splice(
1177 newForOp.getBody()->getOperations().begin(),
1178 t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1180 newForOp.getRegion());
1182 innerLoops.push_back(newForOp);
1190 template <
typename SizeType>
1192 scf::ForOp target) {
1198 assert(res.size() == 1 &&
"Expected 1 inner forOp");
1207 for (
auto it : llvm::zip(forOps, sizes)) {
1208 auto step =
stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1209 res.push_back(step);
1210 currentTargets = step;
1216 scf::ForOp target) {
1219 assert(loops.size() == 1);
1220 res.push_back(loops[0]);
1229 forOps.reserve(sizes.size());
1231 if (forOps.size() < sizes.size())
1232 sizes = sizes.take_front(forOps.size());
1247 forOps.reserve(sizes.size());
1249 if (forOps.size() < sizes.size())
1250 sizes = sizes.take_front(forOps.size());
1257 tileSizes.reserve(sizes.size());
1258 for (
unsigned i = 0, e = sizes.size(); i < e; ++i) {
1259 assert(sizes[i] > 0 &&
"expected strictly positive size for strip-mining");
1261 auto forOp = forOps[i];
1263 auto loc = forOp.getLoc();
1264 Value diff = builder.
create<arith::SubIOp>(loc, forOp.getUpperBound(),
1265 forOp.getLowerBound());
1267 Value iterationsPerBlock =
1269 tileSizes.push_back(iterationsPerBlock);
1273 auto intraTile =
tile(forOps, tileSizes, forOps.back());
1274 TileLoops tileLoops = std::make_pair(forOps, intraTile);
1285 scf::ForallOp source,
1287 unsigned numTargetOuts = target.getNumResults();
1288 unsigned numSourceOuts = source.getNumResults();
1292 llvm::append_range(fusedOuts, target.getOutputs());
1293 llvm::append_range(fusedOuts, source.getOutputs());
1297 scf::ForallOp fusedLoop = rewriter.
create<scf::ForallOp>(
1298 source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
1299 source.getMixedStep(), fusedOuts, source.getMapping());
1303 mapping.
map(target.getInductionVars(), fusedLoop.getInductionVars());
1304 mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1307 mapping.map(target.getRegionIterArgs(),
1308 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1309 mapping.map(source.getRegionIterArgs(),
1310 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1314 for (
Operation &op : target.getBody()->without_terminator())
1315 rewriter.
clone(op, mapping);
1316 for (
Operation &op : source.getBody()->without_terminator())
1317 rewriter.
clone(op, mapping);
1320 scf::InParallelOp targetTerm = target.getTerminator();
1321 scf::InParallelOp sourceTerm = source.getTerminator();
1322 scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1324 for (
Operation &op : targetTerm.getYieldingOps())
1325 rewriter.
clone(op, mapping);
1326 for (
Operation &op : sourceTerm.getYieldingOps())
1327 rewriter.
clone(op, mapping);
1330 rewriter.
replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1331 rewriter.
replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1339 unsigned numTargetOuts = target.getNumResults();
1340 unsigned numSourceOuts = source.getNumResults();
1344 llvm::append_range(fusedInitArgs, target.getInitArgs());
1345 llvm::append_range(fusedInitArgs, source.getInitArgs());
1350 scf::ForOp fusedLoop = rewriter.
create<scf::ForOp>(
1351 source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1352 source.getStep(), fusedInitArgs);
1356 mapping.
map(target.getInductionVar(), fusedLoop.getInductionVar());
1357 mapping.map(target.getRegionIterArgs(),
1358 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1359 mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1360 mapping.map(source.getRegionIterArgs(),
1361 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1365 for (
Operation &op : target.getBody()->without_terminator())
1366 rewriter.
clone(op, mapping);
1367 for (
Operation &op : source.getBody()->without_terminator())
1368 rewriter.
clone(op, mapping);
1372 for (
Value operand : target.getBody()->getTerminator()->getOperands())
1373 yieldResults.push_back(mapping.lookupOrDefault(operand));
1374 for (
Value operand : source.getBody()->getTerminator()->getOperands())
1375 yieldResults.push_back(mapping.lookupOrDefault(operand));
1376 if (!yieldResults.empty())
1377 rewriter.
create<scf::YieldOp>(source.getLoc(), yieldResults);
1380 rewriter.
replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1381 rewriter.
replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1387 scf::ForallOp forallOp) {
1400 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1401 Range normalizedLoopParams =
1403 newLbs.push_back(normalizedLoopParams.
offset);
1404 newUbs.push_back(normalizedLoopParams.
size);
1405 newSteps.push_back(normalizedLoopParams.
stride);
1408 auto normalizedForallOp = rewriter.
create<scf::ForallOp>(
1409 forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
1413 normalizedForallOp.getBodyRegion(),
1414 normalizedForallOp.getBodyRegion().begin());
static std::optional< int64_t > getConstantTripCount(scf::ForOp forOp)
Returns the trip count of forOp if its lower bound, upper bound and step are all constants, or std::nullopt otherwise; the trip count is ceilDiv(upperBound - lowerBound, step).
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOp, descending at most maxLoops levels.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
static void generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues)
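Generates unrolled copies of the scf::ForOp body block 'loopBodyBlock', whose induction variable is 'forOpIV', by 'unrollFactor', calling 'ivRemapFn' to remap 'forOpIV' in each unrolled copy and, if provided, 'annotateFn' to annotate the ops of every unrolled iteration.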
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
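For each original loop, the value of the induction variable can be obtained by dividing the induction variable of the linearized loop by the product of the upper bounds of the loops nested inside it, modulo the upper bound of this loop: iv_i = (iv_linear / (ub_{i+1} * ... * ub_{n-1})) mod ub_i.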
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType::iterator iterator
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
TypedAttr getOneAttr(Type type)
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
LogicalResult loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion)
Replace all uses of orig within the given region with replacement.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an affine.for to find a band to coalesce.
std::pair< Loops, Loops > TileLoops
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned >> combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region ®ion, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.