25 #include "llvm/ADT/APInt.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/Support/DebugLog.h"
33 #define DEBUG_TYPE "scf-utils"
38 bool replaceIterOperandsUsesInLoop) {
44 assert(loopNest.size() <= 10 &&
45 "exceeded recursion limit when yielding value from loop nest");
77 if (loopNest.size() == 1) {
79 cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
80 rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
82 return {innerMostLoop};
92 innerNewBBArgs, newYieldValuesFn,
93 replaceIterOperandsUsesInLoop);
94 return llvm::to_vector(llvm::map_range(
95 newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
98 scf::ForOp outerMostLoop =
99 cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
100 rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
101 newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
118 func::CallOp *callOp) {
119 assert(!funcName.empty() &&
"funcName cannot be empty");
133 ValueRange outlinedValues(captures.getArrayRef());
140 outlinedFuncArgTypes.push_back(arg.getType());
141 outlinedFuncArgLocs.push_back(arg.getLoc());
143 for (
Value value : outlinedValues) {
144 outlinedFuncArgTypes.push_back(value.getType());
145 outlinedFuncArgLocs.push_back(value.getLoc());
147 FunctionType outlinedFuncType =
151 func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType);
152 Block *outlinedFuncBody = outlinedFunc.addEntryBlock();
157 auto outlinedFuncBlockArgs = outlinedFuncBody->
getArguments();
162 originalBlock, outlinedFuncBody,
163 outlinedFuncBlockArgs.take_front(numOriginalBlockArguments));
166 func::ReturnOp::create(rewriter, loc, originalTerminator->
getResultTypes(),
173 ®ion, region.
begin(),
174 TypeRange{outlinedFuncArgTypes}.take_front(numOriginalBlockArguments),
176 .take_front(numOriginalBlockArguments));
181 llvm::append_range(callValues, newBlock->
getArguments());
182 llvm::append_range(callValues, outlinedValues);
183 auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues);
192 rewriter.
clone(*originalTerminator, bvm);
193 rewriter.
eraseOp(originalTerminator);
198 for (
auto it : llvm::zip(outlinedValues, outlinedFuncBlockArgs.take_back(
199 outlinedValues.size()))) {
200 Value orig = std::get<0>(it);
201 Value repl = std::get<1>(it);
210 return outlinedFunc->isProperAncestor(opOperand.
getOwner());
218 func::FuncOp *thenFn, StringRef thenFnName,
219 func::FuncOp *elseFn, StringRef elseFnName) {
222 FailureOr<func::FuncOp> outlinedFuncOpOrFailure;
223 if (thenFn && !ifOp.getThenRegion().empty()) {
225 rewriter, loc, ifOp.getThenRegion(), thenFnName);
226 if (
failed(outlinedFuncOpOrFailure))
228 *thenFn = *outlinedFuncOpOrFailure;
230 if (elseFn && !ifOp.getElseRegion().empty()) {
232 rewriter, loc, ifOp.getElseRegion(), elseFnName);
233 if (
failed(outlinedFuncOpOrFailure))
235 *elseFn = *outlinedFuncOpOrFailure;
242 assert(rootOp !=
nullptr &&
"Root operation must not be a nullptr.");
243 bool rootEnclosesPloops =
false;
245 for (
Block &block : region.getBlocks()) {
248 rootEnclosesPloops |= enclosesPloops;
249 if (
auto ploop = dyn_cast<scf::ParallelOp>(op)) {
250 rootEnclosesPloops =
true;
254 result.push_back(ploop);
259 return rootEnclosesPloops;
267 assert(divisor > 0 &&
"expected positive divisor");
269 "expected integer or index-typed value");
271 Value divisorMinusOneCst = arith::ConstantOp::create(
273 Value divisorCst = arith::ConstantOp::create(
275 Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst);
276 return arith::DivUIOp::create(builder, loc, sum, divisorCst);
286 "expected integer or index-typed value");
287 Value cstOne = arith::ConstantOp::create(
289 Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne);
290 Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne);
291 return arith::DivUIOp::create(builder, loc, sum, divisor);
299 Block *loopBodyBlock,
Value forOpIV, uint64_t unrollFactor,
309 annotateFn = defaultAnnotateFn;
318 for (
unsigned i = 1; i < unrollFactor; i++) {
322 operandMap.
map(iterArgs, lastYielded);
327 Value ivUnroll = ivRemapFn(i, forOpIV, builder);
328 operandMap.
map(forOpIV, ivUnroll);
332 for (
auto it = loopBodyBlock->
begin(); it != std::next(srcBlockEnd); it++) {
334 annotateFn(i, clonedOp, builder);
338 for (
unsigned i = 0, e = lastYielded.size(); i < e; i++)
344 for (
auto it = loopBodyBlock->
begin(); it != std::next(srcBlockEnd); it++)
345 annotateFn(0, &*it, builder);
354 scf::ForOp forOp, uint64_t unrollFactor,
356 assert(unrollFactor > 0 &&
"expected positive unroll factor");
359 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
366 auto loc = forOp.getLoc();
367 Value step = forOp.getStep();
368 Value upperBoundUnrolled;
370 bool generateEpilogueLoop =
true;
372 std::optional<APInt> constTripCount = forOp.getStaticTripCount();
373 if (constTripCount) {
378 if (unrollFactor == 1) {
379 if (*constTripCount == 1 &&
380 failed(forOp.promoteIfSingleIteration(rewriter)))
385 int64_t tripCountEvenMultiple =
386 constTripCount->getSExtValue() -
387 (constTripCount->getSExtValue() % unrollFactor);
388 int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
389 int64_t stepUnrolledCst = stepCst * unrollFactor;
392 generateEpilogueLoop = upperBoundUnrolledCst < ubCst;
393 if (generateEpilogueLoop)
394 upperBoundUnrolled = arith::ConstantOp::create(
397 upperBoundUnrolledCst));
399 upperBoundUnrolled = forOp.getUpperBound();
403 stepCst == stepUnrolledCst
405 : arith::ConstantOp::create(boundsBuilder, loc,
407 step.
getType(), stepUnrolledCst));
412 auto lowerBound = forOp.getLowerBound();
413 auto upperBound = forOp.getUpperBound();
415 arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound);
417 Value unrollFactorCst = arith::ConstantOp::create(
421 arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst);
423 Value tripCountEvenMultiple =
424 arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem);
426 upperBoundUnrolled = arith::AddIOp::create(
427 boundsBuilder, loc, lowerBound,
428 arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step));
431 arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst);
437 if (generateEpilogueLoop) {
438 OpBuilder epilogueBuilder(forOp->getContext());
440 auto epilogueForOp = cast<scf::ForOp>(epilogueBuilder.
clone(*forOp));
441 epilogueForOp.setLowerBound(upperBoundUnrolled);
444 auto results = forOp.getResults();
445 auto epilogueResults = epilogueForOp.getResults();
447 for (
auto e : llvm::zip(results, epilogueResults)) {
448 std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
450 epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
451 epilogueForOp.getInitArgs().size(), results);
452 if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
457 forOp.setUpperBound(upperBoundUnrolled);
458 forOp.setStep(stepUnrolled);
460 auto iterArgs =
ValueRange(forOp.getRegionIterArgs());
461 auto yieldedValues = forOp.getBody()->getTerminator()->getOperands();
464 forOp.getBody(), forOp.getInductionVar(), unrollFactor,
467 auto stride = arith::MulIOp::create(
469 arith::ConstantOp::create(b, loc,
470 b.getIntegerAttr(iv.getType(), i)));
471 return arith::AddIOp::create(b, loc, iv, stride);
473 annotateFn, iterArgs, yieldedValues);
475 if (forOp.promoteIfSingleIteration(rewriter).failed())
483 std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
484 if (!mayBeConstantTripCount.has_value())
486 const APInt &tripCount = *mayBeConstantTripCount;
487 if (tripCount.isZero())
489 if (tripCount.getSExtValue() == 1)
490 return forOp.promoteIfSingleIteration(rewriter);
497 auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
498 if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
499 !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
500 !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
505 return !walkResult.wasInterrupted();
510 uint64_t unrollJamFactor) {
511 assert(unrollJamFactor > 0 &&
"unroll jam factor should be positive");
513 if (unrollJamFactor == 1)
519 LDBG() <<
"failed to unroll and jam: inner bounds are not invariant";
524 if (forOp->getNumResults() > 0) {
525 LDBG() <<
"failed to unroll and jam: unsupported loop with results";
531 std::optional<APInt> tripCount = forOp.getStaticTripCount();
532 if (!tripCount.has_value()) {
534 LDBG() <<
"failed to unroll and jam: trip count could not be determined";
537 if (unrollJamFactor > tripCount->getZExtValue()) {
538 LDBG() <<
"unroll and jam factor is greater than trip count, set factor to "
541 unrollJamFactor = tripCount->getZExtValue();
542 }
else if (tripCount->getSExtValue() % unrollJamFactor != 0) {
543 LDBG() <<
"failed to unroll and jam: unsupported trip count that is not a "
544 "multiple of unroll jam factor";
549 if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
559 forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
570 for (scf::ForOp oldForOp : innerLoops) {
572 ValueRange oldIterOperands = oldForOp.getInits();
573 ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
575 cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
578 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
579 dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
580 dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
584 bool forOpReplaced = oldForOp == forOp;
585 scf::ForOp newForOp =
586 cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
587 rewriter, dupIterOperands,
false,
589 return dupYieldOperands;
591 newInnerLoops.push_back(newForOp);
596 ValueRange newIterArgs = newForOp.getRegionIterArgs();
597 unsigned oldNumIterArgs = oldIterArgs.size();
598 ValueRange newResults = newForOp.getResults();
599 unsigned oldNumResults = newResults.size() / unrollJamFactor;
600 assert(oldNumIterArgs == oldNumResults &&
601 "oldNumIterArgs must be the same as oldNumResults");
602 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
603 for (
unsigned j = 0;
j < oldNumIterArgs; ++
j) {
607 operandMaps[i - 1].map(newIterArgs[
j],
608 newIterArgs[i * oldNumIterArgs +
j]);
609 operandMaps[i - 1].map(newResults[
j],
610 newResults[i * oldNumResults +
j]);
617 int64_t step = forOp.getConstantStep()->getSExtValue();
619 forOp.getLoc(), forOp.getStep(),
621 forOp.getLoc(), rewriter.
getIndexAttr(unrollJamFactor)));
622 forOp.setStep(newStep);
623 auto forOpIV = forOp.getInductionVar();
626 for (
unsigned i = unrollJamFactor - 1; i >= 1; --i) {
627 for (
auto &subBlock : subBlocks) {
630 OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
639 builder.
createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
640 operandMaps[i - 1].map(forOpIV, ivUnroll);
643 for (
auto it = subBlock.first; it != std::next(subBlock.second); ++it)
644 builder.
clone(*it, operandMaps[i - 1]);
647 for (
auto newForOp : newInnerLoops) {
648 unsigned oldNumIterOperands =
649 newForOp.getNumRegionIterArgs() / unrollJamFactor;
650 unsigned numControlOperands = newForOp.getNumControlOperands();
651 auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
652 unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
653 assert(oldNumIterOperands == oldNumYieldOperands &&
654 "oldNumIterOperands must be the same as oldNumYieldOperands");
655 for (
unsigned j = 0;
j < oldNumIterOperands; ++
j) {
659 newForOp.setOperand(numControlOperands + i * oldNumIterOperands +
j,
660 operandMaps[i - 1].lookupOrDefault(
661 newForOp.getOperand(numControlOperands +
j)));
663 i * oldNumYieldOperands +
j,
664 operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(
j)));
670 (void)forOp.promoteIfSingleIteration(rewriter);
677 Range normalizedLoopBounds;
683 normalizedLoopBounds.
size =
685 return normalizedLoopBounds;
697 bool isZeroBased =
false;
699 isZeroBased = lbCst.value() == 0;
701 bool isStepOne =
false;
703 isStepOne = stepCst.value() == 1;
707 "expected matching types");
712 if (isZeroBased && isStepOne)
713 return {lb, ub, step};
723 newUpperBound = rewriter.
createOrFold<arith::CeilDivSIOp>(
731 return {newLowerBound, newUpperBound, newStep};
745 Value denormalizedIvVal =
752 if (
Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
753 preservedUses.insert(preservedUse);
762 if (
getType(origLb).isIndex()) {
766 Value denormalizedIv;
771 Value scaled = normalizedIv;
773 Value origStepValue =
775 scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue);
778 denormalizedIv = scaled;
781 denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue);
790 assert(!values.empty() &&
"unexecpted empty array");
795 for (
auto v : values) {
805 assert(!values.empty() &&
"unexpected empty list");
811 std::optional<Value> productOf;
812 for (
auto v : values) {
814 if (vOne && vOne.value() == 1)
817 productOf = arith::MulIOp::create(rewriter, loc, productOf.value(), v)
823 productOf = arith::ConstantOp::create(
827 return productOf.value();
843 Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create(
844 rewriter, loc, linearizedIv, ubs);
845 auto resultVals = llvm::map_to_vector(
853 llvm::BitVector isUbOne(ubs.size());
856 if (ubCst && ubCst.value() == 1)
861 unsigned numLeadingOneUbs = 0;
863 if (!isUbOne.test(index)) {
866 delinearizedIvs[index] = arith::ConstantOp::create(
867 rewriter, loc, rewriter.
getZeroAttr(ub.getType()));
871 Value previous = linearizedIv;
872 for (
unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) {
873 unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1;
874 if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) {
875 previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]);
880 if (!isUbOne.test(idx)) {
881 iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]);
884 iv = arith::ConstantOp::create(
885 rewriter, loc, rewriter.
getZeroAttr(ubs[idx].getType()));
888 delinearizedIvs[idx] = iv;
890 return {delinearizedIvs, preservedUsers};
895 if (loops.size() < 2)
898 scf::ForOp innermost = loops.back();
899 scf::ForOp outermost = loops.front();
903 for (
auto loop : loops) {
906 Value lb = loop.getLowerBound();
907 Value ub = loop.getUpperBound();
908 Value step = loop.getStep();
914 newLoopRange.offset));
918 newLoopRange.stride));
922 loop.getInductionVar(), lb, step);
931 loops, [](
auto loop) {
return loop.getUpperBound(); });
933 outermost.setUpperBound(upperBound);
937 rewriter, loc, outermost.getInductionVar(), upperBounds);
941 for (
int i = loops.size() - 1; i > 0; --i) {
942 auto outerLoop = loops[i - 1];
943 auto innerLoop = loops[i];
945 Operation *innerTerminator = innerLoop.getBody()->getTerminator();
946 auto yieldedVals = llvm::to_vector(innerTerminator->
getOperands());
947 assert(llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs()));
948 for (
Value &yieldedVal : yieldedVals) {
951 auto iter = llvm::find(innerLoop.getRegionIterArgs(), yieldedVal);
952 if (iter != innerLoop.getRegionIterArgs().end()) {
953 unsigned iterArgIndex = iter - innerLoop.getRegionIterArgs().begin();
955 assert(iterArgIndex < innerLoop.getInitArgs().size());
956 yieldedVal = innerLoop.getInitArgs()[iterArgIndex];
959 rewriter.
eraseOp(innerTerminator);
962 innerBlockArgs.push_back(delinearizeIvs[i]);
963 llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
966 rewriter.
replaceOp(innerLoop, yieldedVals);
975 IRRewriter rewriter(loops.front().getContext());
980 LogicalResult result(failure());
990 for (
unsigned i = 0, e = loops.size(); i < e; ++i) {
991 operandsDefinedAbove[i] = i;
992 for (
unsigned j = 0;
j < i; ++
j) {
994 loops[i].getUpperBound(),
997 operandsDefinedAbove[i] =
j;
1008 iterArgChainStart[0] = 0;
1009 for (
unsigned i = 1, e = loops.size(); i < e; ++i) {
1011 iterArgChainStart[i] = i;
1012 auto outerloop = loops[i - 1];
1013 auto innerLoop = loops[i];
1014 if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
1017 if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
1020 auto outerloopTerminator = outerloop.getBody()->getTerminator();
1021 if (!llvm::equal(outerloopTerminator->getOperands(),
1022 innerLoop.getResults())) {
1025 iterArgChainStart[i] = iterArgChainStart[i - 1];
1031 for (
unsigned end = loops.size(); end > 0; --end) {
1033 for (; start < end - 1; ++start) {
1035 *std::max_element(std::next(operandsDefinedAbove.begin(), start),
1036 std::next(operandsDefinedAbove.begin(), end));
1039 if (iterArgChainStart[end - 1] > start)
1048 if (start != end - 1)
1056 ArrayRef<std::vector<unsigned>> combinedDimensions) {
1062 auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
1063 for (
auto &dims : sortedDimensions)
1068 for (
unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
1071 Value lb = loops.getLowerBound()[i];
1072 Value ub = loops.getUpperBound()[i];
1073 Value step = loops.getStep()[i];
1076 rewriter, loops.getLoc(), newLoopRange.size));
1087 for (
auto &sortedDimension : sortedDimensions) {
1089 for (
auto idx : sortedDimension) {
1090 newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound,
1091 normalizedUpperBounds[idx]);
1093 lowerBounds.push_back(cst0);
1094 steps.push_back(cst1);
1095 upperBounds.push_back(newUpperBound);
1104 auto newPloop = scf::ParallelOp::create(
1105 rewriter, loc, lowerBounds, upperBounds, steps,
1107 for (
unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
1108 Value previous = ploopIVs[i];
1109 unsigned numberCombinedDimensions = combinedDimensions[i].size();
1111 for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
1112 unsigned idx = combinedDimensions[i][j];
1115 Value iv = arith::RemSIOp::create(insideBuilder, loc, previous,
1116 normalizedUpperBounds[idx]);
1117 replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv,
1122 previous = arith::DivSIOp::create(insideBuilder, loc, previous,
1123 normalizedUpperBounds[idx]);
1127 unsigned idx = combinedDimensions[i][0];
1129 previous, loops.getRegion());
1134 loops.getBody()->back().erase();
1135 newPloop.getBody()->getOperations().splice(
1137 loops.getBody()->getOperations());
1150 return op != inner.getOperation();
1153 LogicalResult status = success();
1155 for (
auto &op : outer.getBody()->without_terminator()) {
1157 if (&op == inner.getOperation())
1160 if (forwardSlice.count(&op) > 0) {
1165 if (isa<scf::ForOp>(op))
1168 if (op.getNumRegions() > 0) {
1178 toHoist.push_back(&op);
1180 auto *outerForOp = outer.getOperation();
1181 for (
auto *op : toHoist)
1182 op->moveBefore(outerForOp);
1191 LogicalResult status = success();
1192 const Loops &interTile = tileLoops.first;
1193 const Loops &intraTile = tileLoops.second;
1194 auto size = interTile.size();
1195 assert(size == intraTile.size());
1198 for (
unsigned s = 1; s < size; ++s)
1199 status = succeeded(status) ?
hoistOpsBetween(intraTile[0], intraTile[s])
1201 for (
unsigned s = 1; s < size; ++s)
1202 status = succeeded(status) ?
hoistOpsBetween(interTile[0], interTile[s])
1211 template <
typename T>
1215 for (
unsigned i = 0; i < maxLoops; ++i) {
1216 forOps.push_back(rootForOp);
1218 if (body.
begin() != std::prev(body.
end(), 2))
1221 rootForOp = dyn_cast<T>(&body.
front());
1229 assert(!forOp.getUnsignedCmp() &&
"unsigned loops are not supported");
1230 auto originalStep = forOp.getStep();
1231 auto iv = forOp.getInductionVar();
1234 forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor));
1237 for (
auto t : targets) {
1238 assert(!t.getUnsignedCmp() &&
"unsigned loops are not supported");
1241 auto begin = t.getBody()->begin();
1242 auto nOps = t.getBody()->getOperations().size();
1245 auto b = OpBuilder::atBlockTerminator((t.getBody()));
1246 Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep());
1248 arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped);
1251 auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep);
1252 newForOp.getBody()->getOperations().splice(
1253 newForOp.getBody()->getOperations().begin(),
1254 t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
1256 newForOp.getRegion());
1258 innerLoops.push_back(newForOp);
1266 template <
typename SizeType>
1268 scf::ForOp target) {
1274 assert(res.size() == 1 &&
"Expected 1 inner forOp");
1283 for (
auto it : llvm::zip(forOps, sizes)) {
1284 auto step =
stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
1285 res.push_back(step);
1286 currentTargets = step;
1292 scf::ForOp target) {
1295 res.push_back(llvm::getSingleElement(loops));
1303 forOps.reserve(sizes.size());
1305 if (forOps.size() < sizes.size())
1306 sizes = sizes.take_front(forOps.size());
1321 forOps.reserve(sizes.size());
1323 if (forOps.size() < sizes.size())
1324 sizes = sizes.take_front(forOps.size());
1331 tileSizes.reserve(sizes.size());
1332 for (
unsigned i = 0, e = sizes.size(); i < e; ++i) {
1333 assert(sizes[i] > 0 &&
"expected strictly positive size for strip-mining");
1335 auto forOp = forOps[i];
1337 auto loc = forOp.getLoc();
1338 Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(),
1339 forOp.getLowerBound());
1341 Value iterationsPerBlock =
1343 tileSizes.push_back(iterationsPerBlock);
1347 auto intraTile =
tile(forOps, tileSizes, forOps.back());
1348 TileLoops tileLoops = std::make_pair(forOps, intraTile);
1359 scf::ForallOp source,
1361 unsigned numTargetOuts = target.getNumResults();
1362 unsigned numSourceOuts = source.getNumResults();
1366 llvm::append_range(fusedOuts, target.getOutputs());
1367 llvm::append_range(fusedOuts, source.getOutputs());
1371 scf::ForallOp fusedLoop = scf::ForallOp::create(
1372 rewriter, source.getLoc(), source.getMixedLowerBound(),
1373 source.getMixedUpperBound(), source.getMixedStep(), fusedOuts,
1374 source.getMapping());
1378 mapping.
map(target.getInductionVars(), fusedLoop.getInductionVars());
1379 mapping.
map(source.getInductionVars(), fusedLoop.getInductionVars());
1382 mapping.
map(target.getRegionIterArgs(),
1383 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1384 mapping.
map(source.getRegionIterArgs(),
1385 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1389 for (
Operation &op : target.getBody()->without_terminator())
1390 rewriter.
clone(op, mapping);
1391 for (
Operation &op : source.getBody()->without_terminator())
1392 rewriter.
clone(op, mapping);
1395 scf::InParallelOp targetTerm = target.getTerminator();
1396 scf::InParallelOp sourceTerm = source.getTerminator();
1397 scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
1399 for (
Operation &op : targetTerm.getYieldingOps())
1400 rewriter.
clone(op, mapping);
1401 for (
Operation &op : sourceTerm.getYieldingOps())
1402 rewriter.
clone(op, mapping);
1405 rewriter.
replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1406 rewriter.
replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1414 assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
1415 "incompatible signedness");
1416 unsigned numTargetOuts = target.getNumResults();
1417 unsigned numSourceOuts = source.getNumResults();
1421 llvm::append_range(fusedInitArgs, target.getInitArgs());
1422 llvm::append_range(fusedInitArgs, source.getInitArgs());
1427 scf::ForOp fusedLoop = scf::ForOp::create(
1428 rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
1429 source.getStep(), fusedInitArgs,
nullptr,
1430 source.getUnsignedCmp());
1434 mapping.
map(target.getInductionVar(), fusedLoop.getInductionVar());
1435 mapping.
map(target.getRegionIterArgs(),
1436 fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
1437 mapping.
map(source.getInductionVar(), fusedLoop.getInductionVar());
1438 mapping.
map(source.getRegionIterArgs(),
1439 fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
1443 for (
Operation &op : target.getBody()->without_terminator())
1444 rewriter.
clone(op, mapping);
1445 for (
Operation &op : source.getBody()->without_terminator())
1446 rewriter.
clone(op, mapping);
1450 for (
Value operand : target.getBody()->getTerminator()->getOperands())
1452 for (
Value operand : source.getBody()->getTerminator()->getOperands())
1454 if (!yieldResults.empty())
1455 scf::YieldOp::create(rewriter, source.getLoc(), yieldResults);
1458 rewriter.
replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1459 rewriter.
replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
1465 scf::ForallOp forallOp) {
1470 if (forallOp.isNormalized())
1474 auto loc = forallOp.getLoc();
1477 for (
auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
1478 Range normalizedLoopParams =
1480 newUbs.push_back(normalizedLoopParams.
size);
1486 auto normalizedForallOp = scf::ForallOp::create(
1487 rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1491 normalizedForallOp.getBodyRegion(),
1492 normalizedForallOp.getBodyRegion().begin());
1494 rewriter.
eraseBlock(&normalizedForallOp.getBodyRegion().back());
1498 for (
auto [idx, iv] :
1505 rewriter.
replaceOp(forallOp, normalizedForallOp);
1506 return normalizedForallOp;
1511 assert(!loops.empty() &&
"unexpected empty loop nest");
1512 if (loops.size() == 1)
1513 return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1514 for (
auto [outerLoop, innerLoop] :
1515 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1516 auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1517 auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1518 if (!outerFor || !innerFor)
1520 auto outerBBArgs = outerFor.getRegionIterArgs();
1521 auto innerIterArgs = innerFor.getInitArgs();
1522 if (outerBBArgs.size() != innerIterArgs.size())
1525 for (
auto [outerBBArg, innerIterArg] :
1526 llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1527 if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1528 innerIterArg != outerBBArg)
1533 cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1534 ValueRange innerResults = innerFor.getResults();
1535 if (outerYields.size() != innerResults.size())
1537 for (
auto [outerYield, innerResult] :
1538 llvm::zip_equal(outerYields, innerResults)) {
1539 if (!llvm::hasSingleElement(innerResult.getUses()) ||
1540 outerYield != innerResult)
static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > values)
static LogicalResult tryIsolateBands(const TileLoops &tileLoops)
static void getPerfectlyNestedLoopsImpl(SmallVectorImpl< T > &forOps, T rootForOp, unsigned maxLoops=std::numeric_limits< unsigned >::max())
Collect perfectly nested loops starting from rootForOps.
static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner)
static void generateUnrolledLoop(Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor, function_ref< Value(unsigned, Value, OpBuilder)> ivRemapFn, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn, ValueRange iterArgs, ValueRange yieldedValues)
Generates unrolled copies of scf::ForOp 'loopBodyBlock', with associated 'forOpIV' by 'unrollFactor',...
static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef< scf::ForOp > targets)
static std::pair< SmallVector< Value >, SmallPtrSet< Operation *, 2 > > delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef< Value > ubs)
For each original loop, the value of the induction variable can be obtained by dividing the induction...
static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, int64_t divisor)
static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, ArrayRef< Value > values)
Helper function to multiply a sequence of values.
static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
static bool areInnerBoundsInvariant(scf::ForOp forOp)
Check if bounds of all inner loops are defined outside of forOp and return false if not.
static int64_t product(ArrayRef< int64_t > vals)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Base type for affine expression.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
OpListType::iterator iterator
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
TypedAttr getOneAttr(Type type)
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
static OpBuilder atBlockTerminator(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the block terminator.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
ParentT getParentOfType()
Find the first parent operation of the given type, or nullptr if there is no ancestor operation.
bool hasOneBlock()
Return true if this region has exactly one block.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndex() const
Return true if this is an integer (of any signedness) or an index type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
void replaceUsesWithIf(Value newValue, function_ref< bool(OpOperand &)> shouldReplace)
Replace all uses of 'this' value with 'newValue' if the given callback returns true.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< SmallVector< AffineForOp, 8 >, 8 > tile(ArrayRef< AffineForOp > forOps, ArrayRef< uint64_t > sizes, ArrayRef< AffineForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void getPerfectlyNestedLoops(SmallVectorImpl< scf::ForOp > &nestedLoops, scf::ForOp root)
Get perfectly nested sequence of loops starting at root of loop nest (the first op being another Affi...
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
LogicalResult outlineIfOp(RewriterBase &b, scf::IfOp ifOp, func::FuncOp *thenFn, StringRef thenFnName, func::FuncOp *elseFn, StringRef elseFnName)
Outline the then and/or else regions of ifOp as follows:
void replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region)
Replace all uses of orig within the given region with replacement.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(RewriterBase &rewriter, MutableArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op)
Walk an scf.for to find a band to coalesce.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
std::pair< Loops, Loops > TileLoops
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult loopUnrollFull(scf::ForOp forOp)
Unrolls this loop completely.
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops, ArrayRef< std::vector< unsigned >> combinedDimensions)
Take the ParallelLoop and for each set of dimension indices, combine them into a single dimension.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef< Value > sizes)
Tile a nest of scf::ForOp loops rooted at rootForOp with the given (parametric) sizes.
FailureOr< UnrolledLoopInfo > loopUnrollByFactor(scf::ForOp forOp, uint64_t unrollFactor, function_ref< void(unsigned, Operation *, OpBuilder)> annotateFn=nullptr)
Unrolls this for operation by the specified unroll factor.
LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor)
Unrolls and jams this scf.for operation by the specified unroll factor.
bool getInnermostParallelLoops(Operation *rootOp, SmallVectorImpl< scf::ParallelOp > &result)
Get a list of innermost parallel loops contained in rootOp.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void getUsedValuesDefinedAbove(Region &region, Region &limit, SetVector< Value > &values)
Fill values with a list of values defined at the ancestors of the limit region and used within region...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
FailureOr< func::FuncOp > outlineSingleBlockRegion(RewriterBase &rewriter, Location loc, Region &region, StringRef funcName, func::CallOp *callOp=nullptr)
Outline a region with a single block into a new FuncOp.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool areValuesDefinedAbove(Range values, Region &limit)
Check if all values in the provided range are defined above the limit region.
void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value normalizedIv, OpFoldResult origLb, OpFoldResult origStep)
Get back the original induction variable values after loop normalization.
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForallOp source, RewriterBase &rewriter)
Given two scf.forall loops, target and source, fuses target into source.
LogicalResult coalesceLoops(MutableArrayRef< scf::ForOp > loops)
Replace a perfect nest of "for" loops with a single linearized loop.
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter)
Given two scf.for loops, target and source, fuses target into source.
TileLoops extractFixedOuterLoops(scf::ForOp rootFOrOp, ArrayRef< int64_t > sizes)
Range emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, OpFoldResult lb, OpFoldResult ub, OpFoldResult step)
Materialize bounds and step of a zero-based and unit-step loop derived by normalizing the specified b...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
LogicalResult foldDynamicIndexList(SmallVectorImpl< OpFoldResult > &ofrs, bool onlyNonNegative=false, bool onlyNonZero=false)
Returns "success" when any of the elements in ofrs is a constant value.
FailureOr< scf::ForallOp > normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp)
Normalize an scf.forall operation.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e. all the transitive uses of op), without including that operation.
SmallVector< std::pair< Block::iterator, Block::iterator > > subBlocks
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
std::optional< scf::ForOp > epilogueLoopOp
std::optional< scf::ForOp > mainLoopOp
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.