26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SetVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/MathExtras.h"
36 #define DEBUG_TYPE "scf-utils"
37 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
38 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
43 bool replaceIterOperandsUsesInLoop) {
49 assert(loopNest.size() <= 10 &&
50 "exceeded recursion limit when yielding value from loop nest");
82 if (loopNest.size() == 1) {
84 cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
85 rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
87 return {innerMostLoop};
97 innerNewBBArgs, newYieldValuesFn,
98 replaceIterOperandsUsesInLoop);
99 return llvm::to_vector(llvm::map_range(
100 newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
103 scf::ForOp outerMostLoop =
104 cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
105 rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
106 newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
123 func::CallOp *callOp) {
124 assert(!funcName.empty() &&
"funcName cannot be empty");
138 ValueRange outlinedValues(captures.getArrayRef());
145 outlinedFuncArgTypes.push_back(arg.getType());
146 outlinedFuncArgLocs.push_back(arg.getLoc());
148 for (
Value value : outlinedValues) {
149 outlinedFuncArgTypes.push_back(value.getType());
150 outlinedFuncArgLocs.push_back(value.getLoc());
152 FunctionType outlinedFuncType =
156 rewriter.
create<func::FuncOp>(loc, funcName, outlinedFuncType);
157 Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
162 auto outlinedFuncBlockArgs = outlinedFuncBody->getArguments();
167 originalBlock, outlinedFuncBody,
168 outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
178 ®ion, region.
begin(),
179 TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
181 .take_front(numOriginalBlockArguments));
186 llvm::append_range(callValues, newBlock->
getArguments());
187 llvm::append_range(callValues, outlinedValues);
188 auto call = rewriter.
create<func::CallOp>(loc, outlinedFunc, callValues);
197 rewriter.
clone(*originalTerminator, bvm);
198 rewriter.
eraseOp(originalTerminator);
203 for (
auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
204 outlinedValues.size()))) {
205 Value orig = std::get<0>(it);
206 Value repl = std::get<1>(it);
216 return outlinedFunc->isProperAncestor(opOperand.
getOwner());
224 func::FuncOp *thenFn, StringRef thenFnName,
225 func::FuncOp *elseFn, StringRef elseFnName) {
228 FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
229 if (thenFn && !ifOp.getThenRegion().empty()) {
231 rewriter, loc, ifOp.getThenRegion(), thenFnName);
232 if (failed(outlinedFuncOpOrFailure))
234 *thenFn = *outlinedFuncOpOrFailure;
236 if (elseFn && !ifOp.getElseRegion().empty()) {
238 rewriter, loc, ifOp.getElseRegion(), elseFnName);
239 if (failed(outlinedFuncOpOrFailure))
241 *elseFn = *outlinedFuncOpOrFailure;
248 assert(rootOp !=
nullptr &&
"Root operation must not be a nullptr.");
249 bool rootEnclosesPloops =
false;
251 for (
Block &block : region.getBlocks()) {
254 rootEnclosesPloops |= enclosesPloops;
255 if (
auto ploop = dyn_cast<scf::ParallelOp>(op)) {
256 rootEnclosesPloops =
true;
260 result.push_back(ploop);
265 return rootEnclosesPloops;
273 assert(divisor > 0 &&
"expected positive divisor");
275 "expected integer or index-typed value");
277 Value divisorMinusOneCst = builder.
create<arith::ConstantOp>(
279 Value divisorCst = builder.
create<arith::ConstantOp>(
281 Value sum = builder.
create<arith::AddIOp>(loc, dividend, divisorMinusOneCst);
282 return builder.
create<arith::DivUIOp>(loc, sum, divisorCst);
292 "expected integer or index-typed value");
295 Value divisorMinusOne = builder.
create<arith::SubIOp>(loc, divisor, cstOne);
296 Value sum = builder.
create<arith::AddIOp>(loc, dividend, divisorMinusOne);
297 return builder.
create<arith::DivUIOp>(loc, sum, divisor);
307 if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
311 int64_t lbCst = lbCstOp.value();
312 int64_t ubCst = ubCstOp.value();
313 int64_t stepCst = stepCstOp.value();
314 assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
315 "expected positive loop bounds and step");
316 return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
324 Block *loopBodyBlock,
Value forOpIV, uint64_t unrollFactor,
342 for (
unsigned i = 1; i < unrollFactor; i++) {
346 operandMap.
map(iterArgs, lastYielded);
351 Value ivUnroll = ivRemapFn(i, forOpIV, builder);
352 operandMap.
map(forOpIV, ivUnroll);
356 for (
auto it = loopBodyBlock->
begin(); it != std::next(srcBlockEnd); it++) {
358 annotateFn(i, clonedOp, builder);
362 for (
unsigned i = 0, e = lastYielded.size(); i < e; i++)
363 lastYielded[i] = operandMap.
lookup(yieldedValues[i]);
368 for (
auto it = loopBodyBlock->
begin(); it != std::next(srcBlockEnd); it++)
369 annotateFn(0, &*it, builder);
378 scf::ForOp forOp, uint64_t unrollFactor,
380 assert(unrollFactor > 0 &&
"expected positive unroll factor");
383 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
390 auto loc = forOp.getLoc();
391 Value step = forOp.getStep();
392 Value upperBoundUnrolled;
394 bool generateEpilogueLoop =
true;
397 if (constTripCount) {
402 if (unrollFactor == 1) {
403 if (*constTripCount == 1 &&
404 failed(forOp.promoteIfSingleIteration(rewriter)))
409 int64_t tripCountEvenMultiple =
410 *constTripCount - (*constTripCount % unrollFactor);
411 int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
412 int64_t stepUnrolledCst = stepCst * unrollFactor;
415 generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
416 if (generateEpilogueLoop)
417 upperBoundUnrolled = boundsBuilder.
create<arith::ConstantOp>(
418 loc, boundsBuilder.
getIntegerAttr(forOp.getUpperBound().getType(),
419 upperBoundUnrolledCst));
421 upperBoundUnrolled = forOp.getUpperBound();
424 stepUnrolled = stepCst == stepUnrolledCst
426 : boundsBuilder.
create<arith::ConstantOp>(
428 step.
getType(), stepUnrolledCst));
433 auto lowerBound = forOp.getLowerBound();
434 auto upperBound = forOp.getUpperBound();
436 boundsBuilder.
create<arith::SubIOp>(loc, upperBound, lowerBound);
438 Value unrollFactorCst = boundsBuilder.
create<arith::ConstantOp>(
439 loc, boundsBuilder.
getIntegerAttr(tripCount.getType(), unrollFactor));
441 boundsBuilder.
create<arith::RemSIOp>(loc, tripCount, unrollFactorCst);
443 Value tripCountEvenMultiple =
444 boundsBuilder.
create<arith::SubIOp>(loc, tripCount, tripCountRem);
446 upperBoundUnrolled = boundsBuilder.
create<arith::AddIOp>(
448 boundsBuilder.
create<arith::MulIOp>(loc, tripCountEvenMultiple, step));
451 boundsBuilder.
create<arith::MulIOp>(loc, step, unrollFactorCst);
457 if (generateEpilogueLoop) {
458 OpBuilder epilogueBuilder(forOp->getContext());
460 auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.
clone(*forOp));
461 epilogueForOp.setLowerBound(upperBoundUnrolled);
464 auto results = forOp.getResults();
465 auto epilogueResults = epilogueForOp.getResults();
467 for (
auto e : llvm::zip(results, epilogueResults)) {
468 std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
470 epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
471 epilogueForOp.getInitArgs().size(), results);
472 if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
477 forOp.setUpperBound(upperBoundUnrolled);
478 forOp.setStep(stepUnrolled);
480 auto iterArgs =
ValueRange(forOp.getRegionIterArgs());
481 auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
484 forOp.getBody(), forOp.getInductionVar(), unrollFactor,
487 auto stride = b.create<arith::MulIOp>(
489 b.create<arith::ConstantOp>(loc,
490 b.getIntegerAttr(iv.getType(), i)));
491 return b.create<arith::AddIOp>(loc, iv, stride);
493 annotateFn, iterArgs, yieldedValues);
495 if (forOp.promoteIfSingleIteration(rewriter).failed())
503 auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
504 if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
505 !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
506 !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
511 return !walkResult.wasInterrupted();
516 uint64_t unrollJamFactor) {
517 assert(unrollJamFactor > 0 &&
"unroll jam factor should be positive");
519 if (unrollJamFactor == 1)
525 LDBG(
"failed to unroll and jam: inner bounds are not invariant");
530 if (forOp->getNumResults() > 0) {
531 LDBG(
"failed to unroll and jam: unsupported loop with results");
538 if (!tripCount.has_value()) {
540 LDBG(
"failed to unroll and jam: trip count could not be determined");
543 if (unrollJamFactor > *tripCount) {
544 LDBG(
"unroll and jam factor is greater than trip count, set factor to trip "
546 unrollJamFactor = *tripCount;
547 }
else if (*tripCount % unrollJamFactor != 0) {
548 LDBG(
"failed to unroll and jam: unsupported trip count that is not a "
549 "multiple of unroll jam factor");
554 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
564 forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
575 for (scf::ForOp oldForOp : innerLoops) {
577 ValueRange oldIterOperands = oldForOp.getInits();
578 ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
580 cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
583 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
584 dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
585 dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
589 bool forOpReplaced = oldForOp == forOp;
590 scf::ForOp newForOp =
591 cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
592 rewriter, dupIterOperands,
false,
594 return dupYieldOperands;
596 newInnerLoops.push_back(newForOp);
601 ValueRange newIterArgs = newForOp.getRegionIterArgs();
602 unsigned oldNumIterArgs = oldIterArgs.size();
603 ValueRange newResults = newForOp.getResults();
604 unsigned oldNumResults = newResults.size() / unrollJamFactor;
605 assert(oldNumIterArgs == oldNumResults &&
606 "oldNumIterArgs must be the same as oldNumResults");
607 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
608 for (
unsigned j = 0;
j < oldNumIterArgs; ++
j) {
612 operandMaps[i - 1].map(newIterArgs[
j],
613 newIterArgs[i * oldNumIterArgs +
j]);
614 operandMaps[i - 1].map(newResults[
j],
615 newResults[i * oldNumResults +
j]);
622 int64_t step = forOp.getConstantStep()->getSExtValue();
624 forOp.getLoc(), forOp.getStep(),
626 forOp.getLoc(), rewriter.
getIndexAttr(unrollJamFactor)));
627 forOp.setStep(newStep);
628 auto forOpIV = forOp.getInductionVar();
631 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
632 for (
auto &subBlock : subBlocks) {
635 OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
644 builder.
createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
645 operandMaps[i - 1].map(forOpIV, ivUnroll);
648 for (
auto it = subBlock.first; it != std::next(subBlock.second); ++it)
649 builder.
clone(*it, operandMaps[i - 1]);
652 for (
auto newForOp : newInnerLoops) {
653 unsigned oldNumIterOperands =
654 newForOp.getNumRegionIterArgs() / unrollJamFactor;
655 unsigned numControlOperands = newForOp.getNumControlOperands();
656 auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
657 unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
658 assert(oldNumIterOperands == oldNumYieldOperands &&
659 "oldNumIterOperands must be the same as oldNumYieldOperands");
660 for (
unsigned j = 0;
j < oldNumIterOperands; ++
j) {
664 newForOp.setOperand(numControlOperands + i * oldNumIterOperands +
j,
665 operandMaps[i - 1].lookupOrDefault(
666 newForOp.getOperand(numControlOperands +
j)));
668 i * oldNumYieldOperands +
j,
669 operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(
j)));
675 (void)forOp.promoteIfSingleIteration(rewriter);
682 Range normalizedLoopBounds;
688 normalizedLoopBounds.
size =
690 return normalizedLoopBounds;
702 bool isZeroBased =
false;
704 isZeroBased = lbCst.value() == 0;
706 bool isStepOne =
false;
708 isStepOne = stepCst.value() == 1;
712 "expected matching types");
717 if (isZeroBased && isStepOne)
718 return {lb, ub, step};
728 newUpperBound = rewriter.
createOrFold<arith::CeilDivSIOp>(
736 return {newLowerBound, newUpperBound, newStep};
750 Value denormalizedIvVal =
757 if (
Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
758 preservedUses.insert(preservedUse);
767 if (
getType(origLb).isIndex()) {
771 Value denormalizedIv;
776 Value scaled = normalizedIv;
778 Value origStepValue =
780 scaled = rewriter.
create<arith::MulIOp>(loc, normalizedIv, origStepValue);
783 denormalizedIv = scaled;
786 denormalizedIv = rewriter.
create<arith::AddIOp>(loc, scaled, origLbValue);
795 assert(!values.empty() &&
"unexecpted empty array");
800 for (
auto v : values) {
810 assert(!values.empty() &&
"unexpected empty list");
816 std::optional<Value> productOf;
817 for (
auto v : values) {
819 if (vOne && vOne.value() == 1)
823 rewriter.
create<arith::MulIOp>(loc, productOf.value(), v).getResult();
829 .
create<arith::ConstantOp>(
833 return productOf.value();
850 rewriter.
create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
852 auto resultVals = llvm::map_to_vector(
860 llvm::BitVector isUbOne(ubs.size());
863 if (ubCst && ubCst.value() == 1)
868 unsigned numLeadingOneUbs = 0;
870 if (!isUbOne.test(index)) {
873 delinearizedIvs[index] = rewriter.
create<arith::ConstantOp>(
878 Value previous = linearizedIv;
879 for (
unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
880 unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
881 if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
882 previous = rewriter.
create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
883 preservedUsers.insert(previous.getDefiningOp());
887 if (!isUbOne.test(idx)) {
888 iv = rewriter.
create<arith::RemSIOp>(loc, previous, ubs[idx]);
891 iv = rewriter.
create<arith::ConstantOp>(
895 delinearizedIvs[idx] = iv;
897 return {delinearizedIvs, preservedUsers};
902 if (loops.size() < 2)
905 scf::ForOp innermost = loops.back();
906 scf::ForOp outermost = loops.front();
910 for (
auto loop : loops) {
913 Value lb = loop.getLowerBound();
914 Value ub = loop.getUpperBound();
915 Value step = loop.getStep();
921 newLoopRange.offset));
925 newLoopRange.stride));
929 loop.getInductionVar(), lb, step);
938 loops, [](
auto loop) {
return loop.getUpperBound(); });
940 outermost.setUpperBound(upperBound);
944 rewriter, loc, outermost.getInductionVar(), upperBounds);
948 for (
int i = loops.size() - 1; i > 0; --i) {
949 auto outerLoop = loops[i - 1];
950 auto innerLoop = loops[i];
952 Operation *innerTerminator = innerLoop.getBody()->getTerminator();
953 auto yieldedVals = llvm::to_vector(innerTerminator->
getOperands());
954 assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
955 for (
Value &yieldedVal : yieldedVals) {
958 auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
959 if (iter != innerLoop.getRegionIterArgs().end()) {
960 unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
962 assert(iterArgIndex < innerLoop.getInitArgs().size());
963 yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
966 rewriter.
eraseOp(innerTerminator);
969 innerBlockArgs.push_back(delinearizeIvs[i]);
970 llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
973 rewriter.
replaceOp(innerLoop, yieldedVals);
982 IRRewriter rewriter(loops.front().getContext());
987 LogicalResult result(failure());
997 for (
unsigned i = 0, e = loops.size(); i < e; ++i) {
998 operandsDefinedAbove[i] = i;
999 for (
unsigned j = 0;
j < i; ++
j) {
1001 loops[i].getUpperBound(),
1002 loops[i].getStep()};
1004 operandsDefinedAbove[i] =
j;
1015 iterArgChainStart[0] = 0;
1016 for (
unsigned i = 1, e = loops.size(); i < e; ++i) {
1018 iterArgChainStart[i] = i;
1019 auto outerloop = loops[i - 1];
1020 auto innerLoop = loops[i];
1021 if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1024 if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1027 auto outerloopTerminator = outerloop.getBody()->getTerminator();
1028 if (!llvm::equal(outerloopTerminator->getOperands(),
1029 innerLoop.getResults())) {
1032 iterArgChainStart[i] = iterArgChainStart[i - 1];
1038 for (
unsigned end = loops.size(); end > 0; --end) {
1040 for (; start < end - 1; ++start) {
1042 *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1043 std::next(operandsDefinedAbove.begin(), end));
1046 if (iterArgChainStart[end - 1] > start)
1055 if (start != end - 1)
1063 ArrayRef<std::vector<unsigned>> combinedDimensions) {
1069 auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1070 for (
auto &dims : sortedDimensions)
1075 for (
unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1078 Value lb = loops.getLowerBound()[i];
1079 Value ub = loops.getUpperBound()[i];
1080 Value step = loops.getStep()[i];
1083 rewriter, loops.getLoc(), newLoopRange.size));
1094 for (
auto &sortedDimension : sortedDimensions) {
1096 for (
auto idx : sortedDimension) {
1097 newUpperBound = rewriter.
create<arith::MulIOp>(
1098 loc, newUpperBound, normalizedUpperBounds[idx]);
1100 lowerBounds.push_back(cst0);
1101 steps.push_back(cst1);
1102 upperBounds.push_back(newUpperBound);
1111 auto newPloop = rewriter.
create<scf::ParallelOp>(
1112 loc, lowerBounds, upperBounds, steps,
1114 for (
unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1115 Value previous = ploopIVs[i];
1116 unsigned numberCombinedDimensions = combinedDimensions[i].size();
1118 for (
unsigned j = numberCombinedDimensions - 1;
j > 0; --
j) {
1119 unsigned idx = combinedDimensions[i][
j];
1122 Value iv = insideBuilder.create<arith::RemSIOp>(
1123 loc, previous, normalizedUpperBounds[idx]);
1129 previous = insideBuilder.create<arith::DivSIOp>(
1130 loc, previous, normalizedUpperBounds[idx]);
1134 unsigned idx = combinedDimensions[i][0];
1136 previous, loops.getRegion());
1141 loops.getBody()->back().
erase();
1142 newPloop.getBody()->getOperations().splice(
1144 loops.getBody()->getOperations());
1157 return op != inner.getOperation();
1160 LogicalResult status = success();
1162 for (
auto &op : outer.getBody()->without_terminator()) {
1164 if (&op == inner.getOperation())
1167 if (forwardSlice.count(&op) > 0) {
1172 if (isa<scf::ForOp>(op))
1175 if (op.getNumRegions() > 0) {
1185 toHoist.push_back(&op);
1187 auto *outerForOp = outer.getOperation();
1188 for (
auto *op : toHoist)
1189 op->moveBefore(outerForOp);
1198 LogicalResult status = success();
1199 const Loops &interTile = tileLoops.first;
1200 const Loops &intraTile = tileLoops.second;
1201 auto size = interTile.size();
1202 assert(size == intraTile.size());
1205 for (
unsigned s = 1; s < size; ++s)
1206 status = succeeded(status) ?
hoistOpsBetween(intraTile[0], intraTile[s])
1208 for (
unsigned s = 1; s < size; ++s)
1209 status = succeeded(status) ?
hoistOpsBetween(interTile[0], interTile[s])
1218 template <
typename T>
1222 for (
unsigned i = 0; i < maxLoops; ++i) {
1223 forOps.push_back(rootForOp);
1225 if (body.
begin() != std::prev(body.
end(), 2))
1228 rootForOp = dyn_cast<T>(&body.
front());
1236 auto originalStep = forOp.getStep();
1237 auto iv = forOp.getInductionVar();
1240 forOp.setStep(b.
create<arith::MulIOp>(forOp.getLoc(), originalStep, factor));
1243 for (
auto t : targets) {
1245 auto begin = t.getBody()->begin();
1246 auto nOps = t.getBody()->getOperations().size();
1250 Value stepped = b.
create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep());
1252 b.
create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped);
1255 auto newForOp = b.
create<scf::ForOp>(t.getLoc(), iv, ub, originalStep);
1256 newForOp.getBody()->getOperations().splice(
1257 newForOp.getBody()->getOperations().begin(),
1258 t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1260 newForOp.getRegion());
1262 innerLoops.push_back(newForOp);
1270 template <
typename SizeType>
1272 scf::ForOp target) {
1278 assert(res.size() == 1 &&
"Expected 1 inner forOp");
1287 for (
auto it : llvm::zip(forOps, sizes)) {
1288 auto step =
stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1289 res.push_back(step);
1290 currentTargets = step;
1296 scf::ForOp target) {
1299 assert(loops.size() == 1);
1300 res.push_back(loops[0]);
1309 forOps.reserve(sizes.size());
1311 if (forOps.size() < sizes.size())
1312 sizes = sizes.take_front(forOps.size());
1327 forOps.reserve(sizes.size());
1329 if (forOps.size() < sizes.size())
1330 sizes = sizes.take_front(forOps.size());
1337 tileSizes.reserve(sizes.size());
1338 for (
unsigned i = 0, e = sizes.size(); i < e; ++i) {
1339 assert(sizes[i] > 0 &&
"expected strictly positive size for strip-mining");
1341 auto forOp = forOps[i];
1343 auto loc = forOp.getLoc();
1344 Value diff = builder.
create<arith::SubIOp>(loc, forOp.getUpperBound(),
1345 forOp.getLowerBound());
1347 Value iterationsPerBlock =
1349 tileSizes.push_back(iterationsPerBlock);
1353 auto intraTile =
tile(forOps, tileSizes, forOps.back());
1354 TileLoops tileLoops = std::make_pair(forOps, intraTile);
1365 scf::ForallOp source,
1367 unsigned numTargetOuts = target.getNumResults();
1368 unsigned numSourceOuts = source.getNumResults();
1372 llvm::append_range(fusedOuts, target.getOutputs());
1373 llvm::append_range(fusedOuts, source.getOutputs());
1377 scf::ForallOp fusedLoop = rewriter.
create<scf::ForallOp>(
1378 source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
1379 source.getMixedStep(), fusedOuts, source.getMapping());
1383 mapping.
map(target.getInductionVars(), fusedLoop.getInductionVars());
1384 mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
1387 mapping.map(target.getRegionIterArgs(),
1388 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1389 mapping.map(source.getRegionIterArgs(),
1390 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1394 for (
Operation &op : target.getBody()->without_terminator())
1395 rewriter.
clone(op, mapping);
1396 for (
Operation &op : source.getBody()->without_terminator())
1397 rewriter.
clone(op, mapping);
1400 scf::InParallelOp targetTerm = target.getTerminator();
1401 scf::InParallelOp sourceTerm = source.getTerminator();
1402 scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1404 for (
Operation &op : targetTerm.getYieldingOps())
1405 rewriter.
clone(op, mapping);
1406 for (
Operation &op : sourceTerm.getYieldingOps())
1407 rewriter.
clone(op, mapping);
1410 rewriter.
replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1411 rewriter.
replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1419 unsigned numTargetOuts = target.getNumResults();
1420 unsigned numSourceOuts = source.getNumResults();
1424 llvm::append_range(fusedInitArgs, target.getInitArgs());
1425 llvm::append_range(fusedInitArgs, source.getInitArgs());
1430 scf::ForOp fusedLoop = rewriter.
create<scf::ForOp>(
1431 source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1432 source.getStep(), fusedInitArgs);
1436 mapping.
map(target.getInductionVar(), fusedLoop.getInductionVar());
1437 mapping.map(target.getRegionIterArgs(),
1438 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1439 mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
1440 mapping.map(source.getRegionIterArgs(),
1441 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1445 for (
Operation &op : target.getBody()->without_terminator())
1446 rewriter.
clone(op, mapping);
1447 for (
Operation &op : source.getBody()->without_terminator())
1448 rewriter.
clone(op, mapping);
1452 for (
Value operand : target.getBody()->getTerminator()->getOperands())
1453 yieldResults.push_back(mapping.lookupOrDefault(operand));
1454 for (
Value operand : source.getBody()->getTerminator()->getOperands())
1455 yieldResults.push_back(mapping.lookupOrDefault(operand));
1456 if (!yieldResults.empty())
1457 rewriter.
create<scf::YieldOp>(source.getLoc(), yieldResults);
1460 rewriter.
replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1461 rewriter.
replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1467 scf::ForallOp forallOp) {
1480 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1481 Range normalizedLoopParams =
1483 newLbs.push_back(normalizedLoopParams.
offset);
1484 newUbs.push_back(normalizedLoopParams.
size);
1485 newSteps.push_back(normalizedLoopParams.
stride);
1488 auto normalizedForallOp = rewriter.
create<scf::ForallOp>(
1489 forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
1493 normalizedForallOp.getBodyRegion(),
1494 normalizedForallOp.getBodyRegion().begin());
static std::optional< int64_t > getConstantTripCount(scf::ForOp forOp)
Returns the trip count of forOp if its lower bound, upper bound and step are constants,...
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > values)
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOps.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
static void generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues)
Generates unrolled copies of scf::ForOp 'loopBodyBlock', with associated 'forOpIV' by 'unrollFactor',...
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
static int64_t product(ArrayRef< int64_t > vals)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType::iterator iterator
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
TypedAttr getOneAttr(Type type)
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
void replaceAllOpUsesWith(Operation *from, ValueRange to)
Find uses of from and replace them with to.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region ®ion)
Replace all uses of orig within the given region with replacement.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an affine.for to find a band to coalesce.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
std::pair< Loops, Loops > TileLoops
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned >> combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
FailureOr< UnrolledLoopInfo > loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void getUsedValuesDefinedAbove(Region ®ion, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region ®ion, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
std::optional< scf::ForOp > epilogueLoopOp
std::optional< scf::ForOp > mainLoopOp
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.