28#include "llvm/ADT/ScopeExit.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/Support/Debug.h"
33#define DEBUG_TYPE "tile-using-interface"
37scf::SCFTilingOptions &
39 assert(!tileSizeComputationFunction &&
"tile sizes already set");
40 auto tileSizes = llvm::to_vector(ts);
47scf::SCFTilingOptions &
49 assert(!numThreadsComputationFunction &&
"num tiles already set");
50 auto numThreads = llvm::to_vector(nt);
61 size_t iterationDomainSize) {
63 if (filledVector.size() < iterationDomainSize) {
64 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
65 filledVector.append(range.begin(), range.end());
67 if (filledVector.size() > iterationDomainSize)
68 filledVector.resize(iterationDomainSize);
78 const scf::SCFTilingOptions &
options) {
80 if (
options.numThreadsComputationFunction &&
81 options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
83 loc,
"number of threads can only by specified when loop type is "
84 "set to use `scf.forall`");
88 if (!
options.interchangeVector.empty()) {
91 loc,
"invalid interchange vector, not a permutation of the entire "
103 const scf::SCFTilingOptions &
options) {
106 size_t numLoops = iterationDomain.size();
109 if (
options.numThreadsComputationFunction) {
110 numThreads =
options.numThreadsComputationFunction(rewriter, op);
111 numThreads.resize(numLoops, zero);
114 if (
options.tileSizeComputationFunction) {
115 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
116 tileSizes.resize(numLoops, zero);
117 return {tileSizes, numThreads};
129 tileSizes.resize(numLoops, zero);
130 for (
auto [
index, range, nt] :
131 llvm::enumerate(iterationDomain, numThreads)) {
136 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
138 tileSizes.resize(numLoops, zero);
139 return {tileSizes, numThreads};
146 assert(
options.tileSizeComputationFunction &&
147 "expected tile sizes to be specified");
148 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
149 tileSizes.resize(numLoops, zero);
151 return {tileSizes, numThreads};
156 scf::SCFTilingOptions::LoopType loopType,
160 auto iterators = op.getLoopIteratorTypes();
161 assert(iterators.size() == givenTileSizes.size() &&
162 "expected as many tile size values as number of loops");
163 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
164 "when specified, expected number of threads to use for each loop");
166 bool isParallelTiling =
false;
167 for (
auto [
index, iterator, givenTileSize] :
168 llvm::enumerate(iterators, givenTileSizes)) {
170 isParallelTiling |= iterator == utils::IteratorType::parallel;
173 if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
177 if (!numThreads.empty()) {
178 if (std::optional<int64_t> constNumThreads =
180 if (constNumThreads.value() > 1 &&
181 iterator != utils::IteratorType::parallel) {
182 op.emitWarning() <<
"tiling is not thread safe at axis #" <<
index;
188 if (std::optional<int64_t> constTileSize =
190 if (constTileSize.value() > 0 &&
191 iterator != utils::IteratorType::parallel) {
192 op.emitWarning() <<
"tiling is not thread safe at axis #" <<
index;
199 if (isParallelTiling) {
200 return op->emitOpError(
"tiling parallel dimensions is not supported with "
201 "partial reduction tiling strategies");
211 const scf::SCFTilingOptions &
options) {
213 for (
auto dim :
options.reductionDims) {
216 reductionDims.insert(dim);
218 return reductionDims;
232 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
241 if (ts && ts.value() == 1)
242 return givenTileSize;
246 return givenTileSize;
268 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
270 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
281 int materializedLoopNum = 0;
282 for (
auto [givenTileSize, loopRange] :
283 llvm::zip_equal(givenTileSizes, iterationDomain)) {
288 offsets.push_back(loopRange.offset);
289 sizes.push_back(loopRange.size);
293 Value iv = ivs[materializedLoopNum++];
295 offsets.push_back(offset);
298 sizes.push_back(size);
300 return {offsets, sizes};
309 for (
auto [loopRange, givenTileSize] :
310 llvm::zip_equal(loopRanges, givenTileSizes)) {
314 lbs.push_back(loopRange.offset);
315 ubs.push_back(loopRange.size);
316 steps.push_back(givenTileSize);
318 return {lbs, ubs, steps};
372 if (newDestArgs.empty())
374 if (
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
375 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
392 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
393 assert(loopRanges.size() == givenTileSizes.size() &&
394 "expected as many tile sizes as loop ranges");
398 std::tie(lbs, ubs, steps) =
409 ValueRange innerDestinationTensors(outerDestinationTensors);
410 for (
auto [lb,
ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
412 scf::ForOp::create(rewriter, loc, lb,
ub, step, innerDestinationTensors,
415 loops.push_back(loop);
416 ivs.push_back(loop.getInductionVar());
418 innerDestinationTensors = loop.getRegionIterArgs();
425 std::tie(offsets, sizes) =
430 if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
431 innerDestinationTensors, tiledResults, resultOffsets,
434 loc,
"failed to generate inner tile loop body");
439 assert(tiledResults.size() == innerDestinationTensors.size() &&
440 "Number of results of body should be equal to number of iter args");
444 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
445 llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
449 auto insertSlice = tensor::InsertSliceOp::create(
450 rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize,
452 yieldedValues.push_back(insertSlice);
454 scf::YieldOp::create(rewriter, loc, yieldedValues);
457 for (
auto [outerLoop, innerLoop] :
461 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
462 scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
479 if (numThreads.empty()) {
485 int materializedLoopNum = 0;
491 offsetExpr = d0 + d1 * s0;
492 residualTileSizeExpr = s1 - (d0 + d1 * s0);
494 for (
auto [
index, nt, givenTileSize, loopRange] :
495 llvm::enumerate(numThreads, givenTileSizes, iterationDomain)) {
500 offsets.push_back(loopRange.offset);
501 sizes.push_back(loopRange.size);
505 Value iv = ivs[materializedLoopNum++];
507 rewriter, loc, offsetExpr,
510 rewriter, loc, residualTileSizeExpr,
511 {loopRange.offset, nt, givenTileSize, loopRange.size});
517 {offset, loopRange.size});
521 {sizeMinusOffsetPerThread, givenTileSize});
537 rewriter, loc, maxMap, {rewriter.
getIndexAttr(0), size});
540 offsets.push_back(offset);
541 sizes.push_back(size);
543 return {offsets, sizes};
556static FailureOr<SmallVector<LoopLikeOpInterface>>
564 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
565 assert(loopRanges.size() == givenTileSizes.size() &&
566 "expected as many tile sizes as loop ranges");
569 std::optional<ArrayAttr> mappingAttr;
570 if (!mappingVector.empty())
573 scf::ForallOp forallOp;
574 bool useNumThreads = !numThreads.empty();
580 for (
auto nt : numThreads) {
583 nonZeroNumThreads.push_back(nt);
585 forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads,
586 outerDestinationTensors, mappingAttr);
589 std::tie(lbs, ubs, steps) =
591 forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
592 outerDestinationTensors, mappingAttr);
594 loops.push_back(forallOp);
597 ValueRange innerDestinationTensors = forallOp.getRegionOutArgs();
603 rewriter, loc, ivs, loopRanges, givenTileSizes, numThreads);
607 if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
608 innerDestinationTensors, tiledResults, resultOffsets,
613 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
614 llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
619 tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue,
620 destinationTensor, resultOffset,
621 resultSize, resultStride);
636static FailureOr<SmallVector<LoopLikeOpInterface>>
640 const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn,
641 const scf::SCFTilingOptions::GenerateLoopTerminatorFn
642 &generateLoopTerminatorFn,
644 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
645 assert(loopRanges.size() == givenTileSizes.size() &&
646 "expected as many tile sizes as loop ranges");
647 assert(generateLoopHeaderFn && generateLoopTerminatorFn &&
648 "expected loop header/terminator generation function");
651 FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> loopHeaderInfo =
652 generateLoopHeaderFn(rewriter, loc, loopRanges, givenTileSizes,
653 outerDestinationTensors);
654 if (failed(loopHeaderInfo)) {
661 if (failed(tiledBodyFn(rewriter, loc, ivs, loopHeaderInfo->tileOffset,
662 loopHeaderInfo->tileSizes,
663 loopHeaderInfo->destinationTensors, tiledResults,
664 resultOffsets, resultSizes))) {
668 if (failed(generateLoopTerminatorFn(rewriter, loc, loopHeaderInfo->loops,
669 tiledResults, resultOffsets, resultSizes,
670 loopHeaderInfo->destinationTensors))) {
674 return loopHeaderInfo->loops;
698 llvm::map_to_vector(loopRanges, [](
Range r) {
return r.
offset; });
700 llvm::map_to_vector(loopRanges, [](
Range r) {
return r.
size; });
701 if (failed(tiledBodyFn(rewriter, loc,
ValueRange{}, tileOffsets, tileSizes,
702 destinationTensors, tiledResults, resultOffsets,
708 if (
options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
710 destinationTensors, tiledBodyFn);
712 if (
options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
714 rewriter, loc, loopRanges, givenTileSizes, numThreads,
715 options.mappingVector, destinationTensors, tiledBodyFn);
717 if (
options.loopType == scf::SCFTilingOptions::LoopType::CustomOp) {
719 rewriter, loc, loopRanges, givenTileSizes, destinationTensors,
739 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
741 return op->emitOpError(
742 "PartialReductionOuterReduction tiling strategy is only supported for "
743 "operations implementing PartialReductionOpInterface");
748 AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
750 for (
auto [
index, domain, tileSize] :
751 llvm::enumerate(iterationDomain, givenTileSizes)) {
752 if (!numThreads.empty()) {
756 rewriter, op.getLoc(), sizeExpr,
757 {domain.size, domain.offset, domain.stride});
767 rewriter, op.getLoc(), sizeExpr,
768 {domain.size, domain.offset, domain.stride});
772 if (reductionStrategy ==
774 sizes[
index] = tileSize;
778 assert(reductionStrategy ==
781 rewriter, op.getLoc(), sizeExpr,
782 {domain.size, domain.offset, domain.stride});
784 rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
786 return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
800 splitReductionIvs.resize(reductionDims.size(), rewriter.
getIndexAttr(0));
805 if (reductionStrategy ==
807 for (
auto [
index, reductionDim] : llvm::enumerate(reductionDims)) {
808 if (!numThreads.empty()) {
809 splitReductionIvs[
index] = ivs[ivIndex++];
813 rewriter, loc, divExpr,
817 return splitReductionIvs;
820static FailureOr<TilingResult>
829 return op.getTiledImplementation(rewriter, offsets, sizes);
832 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
835 op,
"PartialReductionOuterReduction tiling strategy is only "
836 "supported for operations "
837 "implementing PartialReductionOpInterface");
842 numThreads, givenTileSizes, reductionDims);
843 return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
844 regionIterArg, offsets, sizes,
845 reductionDims, splitReductionIvs);
859 return op.getResultTilePosition(rewriter,
index, offsets, sizes,
860 resultOffset, resultSize);
862 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
865 op,
"PartialReductionOuterReduction tiling strategy is only supported"
866 "for operations implementing PartialReductionOpInterface");
870 numThreads, givenTileSizes, reductionDims);
871 return redOp.getPartialResultTilePosition(
872 rewriter,
index, reductionStrategy, offsets, sizes, reductionDims,
873 splitReductionIvs, resultOffset, resultSize);
876static FailureOr<MergeResult>
882 "expected merge to be called for only partial reduction cases");
884 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
887 op,
"PartialReductionOuterReduction tiling strategy is only "
888 "supported for operations "
889 "implementing PartialReductionOpInterface");
891 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
902template <
typename LoopType>
903FailureOr<LoopLikeOpInterface>
919 auto inits = llvm::to_vector(loopOp.getInitArgs());
920 inits.append(newInitOperands.begin(), newInitOperands.end());
921 auto newLoop = scf::ForOp::create(
922 rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
924 loopOp.getUnsignedCmp());
927 Block *loopBody = loopOp.getBody();
928 Block *newLoopBody = newLoop.getBody();
930 loopBody, newLoopBody,
933 auto yieldOp = cast<scf::YieldOp>(newLoopBody->
getTerminator());
939 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
940 if (
failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
941 newRegionIterArgs, tiledValues, resultOffsets,
948 for (
auto [tiledValue, regionIterArg, resultOffset, resultSize] :
949 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
953 Value insert = tensor::InsertSliceOp::create(
954 rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset,
955 resultSize, resultStride);
956 newYieldValues.push_back(insert);
961 newLoop->getResults().take_front(loopOp.getNumResults()));
962 return cast<LoopLikeOpInterface>(newLoop.getOperation());
973 auto inits = llvm::to_vector(loopOp.getOutputs());
974 inits.append(newInitOperands.begin(), newInitOperands.end());
975 auto newLoop = scf::ForallOp::create(
976 rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
977 loopOp.getMixedStep(), inits, loopOp.getMapping(),
981 Block *loopBody = loopOp.getBody();
982 Block *newLoopBody = newLoop.getBody();
984 loopBody, newLoopBody,
987 auto terminator = cast<scf::InParallelOp>(newLoopBody->
getTerminator());
992 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
993 if (
failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
994 regionIterArgs, tiledValues, resultOffsets,
998 "failed to get yielded tiled values");
1004 for (
auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
1005 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
1008 tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(),
1009 tiledValue, iterArg, resultOffset,
1010 resultSize, resultStride);
1014 newLoop->getResults().take_front(loopOp.getNumResults()));
1015 return cast<LoopLikeOpInterface>(newLoop.getOperation());
1022 LoopLikeOpInterface loopLikeOp,
RewriterBase &rewriter,
1025 loopLikeOp.getOperation())
1026 .Case<scf::ForOp, scf::ForallOp>(
1027 [&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1029 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
1031 .Default([&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1050 for (
auto &loop : loops.drop_back()) {
1054 auto forLoop = cast<scf::ForOp>(loop.getOperation());
1058 newInits.append(newInitValues.begin(), newInitValues.end());
1059 auto newLoop = scf::ForOp::create(
1060 rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
1061 forLoop.getUpperBound(), forLoop.getStep(), newInits,
1063 forLoop.getUnsignedCmp());
1067 sourceBlockArgs.push_back(newLoop.getInductionVar());
1068 auto newRegionIterArgs = newLoop.getRegionIterArgs();
1069 sourceBlockArgs.append(
1070 newRegionIterArgs.begin(),
1071 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
1072 rewriter.
mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
1074 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
1076 ivs.push_back(newLoop.getInductionVar());
1077 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
1081 LoopLikeOpInterface innerMostLoop = loops.back();
1082 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
1084 getNewTiledYieldsFn);
1086 if (failed(newInnerMostLoop))
1087 return innerMostLoop.emitOpError(
"failed to return additional yields");
1088 loops.back() = newInnerMostLoop.value();
1092 for (
auto [outerLoop, innerLoop] :
1093 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1095 auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
1096 auto outerLoopYield =
1097 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
1099 llvm::to_vector(outerLoopYield.getOperands());
1101 innerLoop->getResults().take_back(newInitValues.size());
1102 newYields.append(additionalYields.begin(), additionalYields.end());
1111FailureOr<scf::SCFTilingResult>
1112mlir::scf::tileUsingSCF(
RewriterBase &rewriter, TilingInterface op,
1113 const scf::SCFTilingOptions &
options) {
1126 std::tie(givenTileSizes, numThreads) =
1132 givenTileSizes, numThreads))) {
1143 if (!
options.interchangeVector.empty()) {
1145 iterationDomain.size());
1147 "expected interchange vector to be a permutation");
1151 if (!numThreads.empty())
1155 FailureOr<TilingResult> tilingResult;
1169 if (!interchangeVector.empty()) {
1178 auto clonedOp = cast<TilingInterface>(
1185 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1195 regionIterArgs, tileOffsetsVec, tileSizesVec,
1196 ivs, numThreads, givenTileSizes, reductionDims);
1197 if (
failed(tilingResult)) {
1199 return op.emitOpError(
"faild to tile operation");
1207 for (
auto [
index, tiledValue] :
1208 llvm::enumerate(tilingResult->tiledValues)) {
1209 tiledResults.push_back(tiledValue);
1212 rewriter,
options.reductionStrategy,
index, tiledValue, op,
1213 tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes,
1214 reductionDims, resultOffset, resultSize))) {
1215 for (
auto op : tilingResult->tiledOps) {
1219 op,
"failed to get slice of result produced");
1221 resultOffsets.emplace_back(std::move(resultOffset));
1222 resultSizes.emplace_back(std::move(resultSize));
1230 rewriter, op,
options.reductionStrategy, iterationDomain, numThreads,
1231 givenTileSizes, reductionDims);
1232 if (
failed(maybeInits)) {
1234 op,
"unable to create initial tensors for tiling");
1242 rewriter, op.getLoc(),
options, iterationDomain, givenTileSizes,
1243 numThreads, initTensors, innerYieldTiledValuesFn);
1245 return op.emitOpError(
"failed to generate tiling loops");
1246 assert(succeeded(tilingResult) &&
1247 "expected tiling result to be computed after loop generation");
1248 std::swap(loops, loopsOr.value());
1251 if (loops.empty()) {
1254 return scf::SCFTilingResult{tilingResult->tiledOps,
1257 tilingResult->tiledValues,
1258 tilingResult->generatedSlices,
1262 auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1267 return scf::SCFTilingResult{
1268 tilingResult->tiledOps, initTensors, loops, loopResults,
1269 tilingResult->generatedSlices, {}};
1274 rewriter, op,
options.reductionStrategy, reductionDims, loopResults);
1275 if (
failed(mergeResult)) {
1277 op,
"Failed to merge partial results from tiling");
1279 return scf::SCFTilingResult{tilingResult->tiledOps,
1282 mergeResult->replacements,
1283 tilingResult->generatedSlices,
1284 mergeResult->mergeOps};
1287FailureOr<scf::SCFTilingResult>
1289 PartialReductionOpInterface op,
1291 scf::SCFTilingOptions
options;
1292 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
1293 options.setReductionTilingStrategy(
1295 options.setTileSizes(tileSize);
1297 for (
auto [
index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
1298 if (iteratorType == utils::IteratorType::reduction)
1299 reductionDims.push_back(
index);
1300 options.setReductionDims(reductionDims);
1301 return tileUsingSCF(
b, op,
options);
1314static std::tuple<OpResult, std::optional<OpOperand *>>
1317 std::optional<OpOperand *> destinationIterArg;
1318 assert(!loops.empty() &&
"expected non empty loops container");
1319 auto loopIt = loops.rbegin();
1320 while (loopIt != loops.rend() && isa<BlockArgument>(source->
get())) {
1321 auto iterArg = cast<BlockArgument>(source->
get());
1322 auto loop = *loopIt;
1323 if (iterArg.getOwner()->getParentOp() != loop)
1325 source = loop.getTiedLoopInit(iterArg);
1328 if (loopIt == loops.rend())
1329 destinationIterArg = source;
1331 auto result = dyn_cast<OpResult>(source->
get());
1334 Operation *innermostLoop = loops.back();
1341 return {
result, destinationIterArg};
1346std::optional<scf::SCFFuseProducerOfSliceResult>
1347mlir::scf::tileAndFuseProducerOfSlice(
1348 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1352 auto [fusableProducer, destinationInitArg] =
1355 if (!fusableProducer)
1356 return std::nullopt;
1357 unsigned resultNumber = fusableProducer.getResultNumber();
1365 Operation *fusableProducerOp = fusableProducer.getOwner();
1366 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1368 rewriter, fusableProducerOp->
getLoc(), fusableProducerOp,
1369 origDestinationTensors)))
1370 return std::nullopt;
1372 clonedOpDestinationTensors = origDestinationTensors;
1373 if (destinationInitArg &&
1374 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1378 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1382 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1387 llvm::to_vector(candidateSliceOp->getOperands());
1388 candidateSliceOpOperands[0] = clonedProducerOp->
getResult(resultNumber);
1389 tensor::ExtractSliceOp clonedCandidateSliceOp =
1391 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1394 FailureOr<TilingResult> tileAndFuseResult =
1396 rewriter, clonedCandidateSliceOp,
1397 clonedProducerOp->
getResult(resultNumber));
1398 if (
failed(tileAndFuseResult))
1399 return std::nullopt;
1403 tileAndFuseResult->tiledValues[0]);
1404 rewriter.
eraseOp(clonedCandidateSliceOp);
1405 rewriter.
eraseOp(clonedProducerOp);
1450 if (destinationInitArg &&
1451 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1453 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1454 .set(origDestinationTensors[resultNumber]);
1456 return scf::SCFFuseProducerOfSliceResult{
1457 fusableProducer, tileAndFuseResult->tiledValues[0],
1458 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1462FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1463 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1464 scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1470 Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1471 *tiledOwner = fusedProducerInfo.tiledOps[0];
1476 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1478 :
llvm::to_vector(yieldResultNumber);
1480 for (
const auto &resultNumber : initNumberList) {
1482 rewriter, loc, originalOwner->
getResult(resultNumber));
1483 if (succeeded(initValue)) {
1484 initValueList.push_back(initValue.value());
1500 sliceSizes = sliceOp.getMixedSizes();
1503 if (!llvm::all_of(sliceOp.getMixedStrides(),
isOneInteger))
1506 unsigned sliceResultNumber =
1507 fusedProducerInfo.origProducer.getResultNumber();
1509 auto tilableOp = cast<TilingInterface>(originalOwner);
1513 if (tilableOp->getNumResults() > 1 &&
1514 failed(tilableOp.getIterationDomainTileFromResultTile(
1515 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1516 iterDomainOffset, iterDomainSizes))) {
1531 for (
const auto &resultNumber : initNumberList) {
1532 if (resultNumber == sliceResultNumber) {
1533 offsetList.push_back(sliceOffset);
1534 sizesList.push_back(sliceSizes);
1536 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1539 if (
failed(tilableOp.getResultTilePosition(
1540 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1544 offsetList.push_back(offset);
1545 sizesList.push_back(sizes);
1551 if (
auto tiledDestStyleOp =
1552 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1554 for (
const auto &&[
index, newRegionArg] :
1555 llvm::enumerate(newRegionIterArgs)) {
1556 auto destSlice = tensor::ExtractSliceOp::create(
1557 rewriter, loc, newRegionArg, offsetList[
index], sizesList[
index],
1560 generatedSlices.push_back(destSlice);
1561 unsigned resultNumber = initNumberList[
index];
1563 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1572 for (
const auto &&[
index, resultNumber] : llvm::enumerate(initNumberList)) {
1573 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1574 tiledOffset.emplace_back(offsetList[
index]);
1575 tiledSizes.emplace_back(sizesList[
index]);
1581 newYieldValuesFn))) {
1584 return generatedSlices;
1598 explicit SliceTrackingListener(
1599 std::optional<FrozenRewritePatternSet> patterns);
1600 SliceTrackingListener() =
default;
1606 LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1609 void notifyOperationInserted(Operation *op,
1610 OpBuilder::InsertPoint previous)
override;
1613 void removeOp(Operation *op);
1616 void notifyOperationErased(Operation *op)
override;
1623 std::deque<tensor::ExtractSliceOp> worklist;
1628 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1631SliceTrackingListener::SliceTrackingListener(
1632 std::optional<FrozenRewritePatternSet> p) {
1637SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1638 for (Operation *op : ops) {
1639 if (
auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1640 worklist.push_back(slice);
1648 GreedyRewriteConfig().setListener(
this).setStrictness(
1649 GreedyRewriteStrictness::ExistingAndNewOps));
1652void SliceTrackingListener::notifyOperationInserted(
1653 Operation *op, OpBuilder::InsertPoint previous) {
1654 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1657 worklist.push_back(slice);
1663void SliceTrackingListener::removeOp(Operation *op) {
1664 if (!isa<tensor::ExtractSliceOp>(op))
1666 auto iter = worklist.begin();
1667 while (iter != worklist.end()) {
1672 if (iter == worklist.end())
1675 worklist.erase(iter);
1678void SliceTrackingListener::notifyOperationErased(Operation *op) {
1682void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1694class ReplacementListener :
public RewriterBase::ForwardingListener {
1696 ReplacementListener(DenseMap<Value, Value> &replacements,
1697 OpBuilder::Listener *listener)
1698 : ForwardingListener(listener), replacements(replacements) {}
1700 void updateReplacementValues(
ValueRange origValues,
1704 for (
auto &[key, val] : replacements) {
1705 for (
auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1713 void notifyOperationReplaced(Operation *op, Operation *newOp)
override {
1714 ForwardingListener::notifyOperationReplaced(op, newOp);
1718 void notifyOperationReplaced(Operation *op,
ValueRange values)
override {
1719 ForwardingListener::notifyOperationReplaced(op, values);
1720 updateReplacementValues(op->
getResults(), values);
1724 DenseMap<Value, Value> &replacements;
1730FailureOr<scf::SCFTileAndFuseResult>
1731mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1732 RewriterBase &rewriter, TilingInterface consumer,
1733 const scf::SCFTileAndFuseOptions &
options) {
1736 if (!consumer->getNumResults()) {
1738 consumer,
"invalid pattern for op with no results");
1744 FailureOr<scf::SCFTilingResult> tilingResult =
1745 tileUsingSCF(rewriter, consumer,
options.tilingOptions);
1747 if (
failed(tilingResult))
1749 tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1751 DenseMap<Value, Value> replacements;
1753 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1758 auto &loops = tilingResult->loops;
1759 if (loops.empty()) {
1760 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1767 OpBuilder::Listener *previousListener = rewriter.
getListener();
1768 llvm::scope_exit resetListener(
1769 [&]() { rewriter.
setListener(previousListener); });
1770 ReplacementListener replaceListener(replacements, previousListener);
1780 struct WorklistItem {
1781 tensor::ExtractSliceOp candidateSlice;
1782 SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1785 SliceTrackingListener sliceTracker =
1786 SliceTrackingListener(
options.cleanupPatterns);
1789 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1792 OpBuilder::InsertionGuard g(rewriter);
1793 while (!sliceTracker.worklist.empty()) {
1794 auto candidateSlice = sliceTracker.worklist.front();
1795 sliceTracker.worklist.pop_front();
1797 auto [fusableProducer, destinationInitArg] =
1800 if (!fusableProducer)
1803 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1804 options.fusionControlFn(candidateSlice, fusableProducer,
1805 destinationInitArg.has_value());
1806 if (!controlFnResult)
1809 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1814 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1815 tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1820 SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1822 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1827 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1828 FailureOr<SmallVector<Operation *>> newSlices =
1829 yieldReplacementForFusedProducer(rewriter,
1830 worklistItem.candidateSlice,
1831 fusedResult.value(), loops);
1834 fusableProducerOp,
"failed to replacement value for this "
1835 "operation from within the tiled loop");
1837 worklistCandidates.append(newSlices.value());
1838 for (
auto [index,
result] :
1839 llvm::enumerate(fusableProducerOp->
getResults())) {
1840 replacements[
result] = loops.front()->getResult(
1841 loops.front()->getNumResults() -
1845 if (Operation *tiledAndFusedOp =
1846 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1847 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1848 tiledAndFusedOps.insert(tiledAndFusedOp);
1851 if (
failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1856 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1870 if (!llvm::hasSingleElement(uses)) {
1871 LLVM_DEBUG(llvm::dbgs() <<
"Too many uses of the candidate slice op\n");
1874 OpOperand &operandUse = (*uses.begin());
1876 if (!isa<scf::YieldOp>(userOp)) {
1877 LLVM_DEBUG(llvm::dbgs()
1878 <<
"Expected scf.yield to be the only user, but got -> "
1883 LLVM_DEBUG(llvm::dbgs() <<
"Expected tensor.insert_slice and scf.yield to "
1884 "be in the same block\n");
1893 if (!isa<LoopLikeOpInterface>(loopOp))
1912 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1913 userOp = userOp->getParentOfType<scf::InParallelOp>();
1915 if (loopOp->
getBlock() != userOp->getBlock())
1919 firstUserOfLoop = userOp;
1921 return firstUserOfLoop;
1962static FailureOr<llvm::SetVector<Operation *>>
1964 bool reorderOperations) {
1966 if (failed(firstUserOfLoop))
1972 options.omitBlockArguments =
true;
1973 bool includeLoopOp =
false;
1976 includeLoopOp =
true;
1986 assert(
result.succeeded() &&
"expected a backward slice");
1990 if (!slice.empty()) {
2000 if (includeLoopOp || !reorderOperations)
2012 unsigned resultNumber) {
2013 if (!isa<LoopLikeOpInterface>(loopOp))
2018 Operation *consumerOp = opOperand.getOwner();
2020 if (!isa<TilingInterface>(consumerOp) ||
2021 !isa<DestinationStyleOpInterface>(consumerOp)) {
2028 if (loopBlock != consumerOp->
getBlock())
2035 FailureOr<llvm::SetVector<Operation *>> slice =
2041 if (!slice->empty()) {
2044 assert(succeeded(firstUserOfLoop) &&
"First user of loop is not found");
2045 for (
auto op : *slice) {
2060static FailureOr<OpOperand *>
2062 tensor::InsertSliceOp candidateSliceOp,
2064 assert(!loops.empty() &&
"unexpected loops to be empty");
2067 if (containingOp != loops.back()) {
2070 "expected slice to be within body of inner-most loop");
2076 candidateSliceOp,
"expected passed loops to be perfectly nested.");
2081 Value sliceResult = candidateSliceOp.getResult();
2087 scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
2094static FailureOr<OpOperand *>
2096 tensor::ParallelInsertSliceOp candidateSliceOp,
2098 assert(!loops.empty() &&
"unexpected loops to be empty");
2100 if (loops.size() != 1) {
2102 candidateSliceOp,
"expected single surrounding scf.forall");
2104 auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
2107 candidateSliceOp,
"expected single surrounding scf.forall");
2111 Value sliceDest = candidateSliceOp.getDest();
2112 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
2115 if (iterArg.getOwner()->getParentOp() != forallOp)
2118 unsigned resultNumber =
2119 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
2130 assert(!loops.empty() &&
"unexpected empty loops");
2131 assert(!sliceOps.empty() &&
"unexpected empty list of candidate slices");
2133 for (
auto sliceOp : sliceOps) {
2134 FailureOr<OpOperand *> fusedOperand =
2136 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2143 if (failed(fusedOperand)) {
2146 if (!fusedOperands.empty() &&
2147 fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
2149 fusedOperand.value()->getOwner(),
2150 "all candidate slices must be to the same consumer");
2152 fusedOperands.push_back(fusedOperand.value());
2154 return fusedOperands;
2157template <
typename InsertSliceOpTy>
2159 InsertSliceOpTy sliceOp);
2162tensor::InsertSliceOp
2164 tensor::InsertSliceOp insertSliceOp) {
2165 return cast<tensor::InsertSliceOp>(
2166 rewriter.
clone(*insertSliceOp.getOperation()));
2171 RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
2172 return tensor::InsertSliceOp::create(
2173 rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(),
2174 insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
2175 insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2178static SmallVector<tensor::InsertSliceOp>
2181 assert(!candidateSlices.empty() &&
2182 "unexpected empty list of slices to clone");
2184 for (
auto sliceOp : candidateSlices) {
2186 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2189 clonedSlices.push_back(clonedOp);
2192 .DefaultUnreachable(
2193 "unexpected slice type while cloning as insert slice");
2195 return clonedSlices;
2198static FailureOr<scf::SCFFuseConsumerOfSliceResult>
2203 assert(!loops.empty() &&
"expected loops to be not empty");
2208 loops.front(),
"the first user of loop should not dominate any define "
2209 "of consumer operand(s)");
2212 LoopLikeOpInterface outerMostLoop = loops.front();
2213 LoopLikeOpInterface innerMostLoop = loops.back();
2217 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2220 "consumer op is not DPS operation");
2221 if (llvm::any_of(consumerOpOperands, [&](
OpOperand *opOperand) {
2222 return dstOp.isDpsInit(opOperand);
2226 "consumer op taking the result of scf.for as init is not supported");
2233 if (failed(firstUserOfLoop)) {
2235 outerMostLoop,
"could not find the first user of outer most loop");
2237 rewriter.
moveOpBefore(outerMostLoop, *firstUserOfLoop);
2244 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
2245 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2255 auto clonedConsumerOp = cast<TilingInterface>(rewriter.
clone(*consumerOp));
2257 llvm::map_to_vector(consumerOpOperands, [](
OpOperand *opOperand) {
2261 llvm::map_to_vector(operandNumbers, [&](
unsigned operandNum) {
2262 return &clonedConsumerOp->getOpOperand(operandNum);
2268 for (
auto [operandToReplace, clonedSliceOp] :
2269 llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
2270 operandToReplace->set(clonedSliceOp.getResult());
2276 FailureOr<TilingResult> tileAndFuseResult =
2278 clonedOpFusedOperandsList);
2279 if (failed(tileAndFuseResult)) {
2283 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2284 for (
auto [operandNum, clonedSliceOp] :
2285 llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
2287 clonedSliceOp.getSource());
2301 for (
auto candidateSliceOp : clonedInsertSlices) {
2309 candidateSliceOp,
"containingOp's result yield with stride");
2312 allOffsets.emplace_back(std::move(offsets));
2313 allSizes.emplace_back(std::move(sizes));
2323 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
2324 rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
2325 iterDomainSizes))) {
2328 "can't get iter domain position from input position");
2334 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2336 totalNumResultsOfConsumer);
2338 totalNumResultsOfConsumer);
2339 for (
auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2340 if (failed(tiledConsumerOp.getResultTilePosition(
2341 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2342 resultOffsets[idx], resultSizes[idx]))) {
2345 "can't get result domain position from iter domain position");
2351 if (
auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2352 tiledConsumerOp.getOperation())) {
2354 for (
const auto &&[
index, newRegionArg] :
2355 llvm::enumerate(newRegionIterArgs)) {
2356 auto destSlice = tensor::ExtractSliceOp::create(
2357 rewriter, loc, newRegionArg, resultOffsets[
index],
2363 auto dstNumber =
index;
2365 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2375 llvm::enumerate(tiledConsumerOp->getResults())) {
2376 tiledResult.push_back(
result);
2377 tiledOffset.emplace_back(resultOffsets[
index]);
2378 tiledSizes.emplace_back(resultSizes[
index]);
2384 newYieldValuesFn))) {
2386 "unable to add new inits to nest loop");
2392 for (
auto &&[oldResult, newResult] :
2394 loops.front()->getResults().take_back(newInits.size()))) {
2399 rewriter.
eraseOp(clonedConsumerOp);
2402 llvm::map_to_vector(operandNumbers, [&](
unsigned operandNum) {
2403 return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
2405 auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
2406 return scf::SCFFuseConsumerOfSliceResult{
2407 std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
2408 std::move(tileAndFuseResult->tiledOps)};
2413FailureOr<scf::SCFFuseConsumerOfSliceResult>
2414mlir::scf::tileAndFuseConsumerOfSlices(
2415 RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2416 MutableArrayRef<LoopLikeOpInterface> loops) {
2417 if (candidateSlices.empty()) {
2420 "no candidate slices provided for consumer fusion");
2424 if (loops.empty()) {
2426 candidateSlices.front(),
2427 "cannot call tile and fuse consumer with an empty loop nest");
2430 if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2431 llvm::all_of(candidateSlices,
2432 llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2434 candidateSlices.front(),
2435 "candidates slices need to be all `tensor.extract_slice`s or "
2436 "`tensor.parallel_insert_slice`s");
2441 FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
2443 if (
failed(maybeConsumerOpOperands)) {
2445 "could not fetch consumer to fuse");
2447 Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();
2450 maybeConsumerOpOperands.value(),
2451 candidateSlices, loops);
2457static std::optional<Operation *>
2459 if (
result.getOwner() != forallOp)
2460 return std::nullopt;
2465 if (combiningOps.size() != 1)
2466 return std::nullopt;
2467 return combiningOps[0];
2472static std::optional<Operation *>
2475 assert(!loops.empty() &&
"Expected loops to be not empty");
2476 LoopLikeOpInterface outerMostLoop = loops.front();
2477 if (
auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) {
2478 assert(loops.size() == 1 &&
2479 "expected only a single loop when tiling using scf.forall");
2485 while (loops.size() != 1) {
2486 LoopLikeOpInterface loop = loops.front();
2487 if (
result.getOwner() != loop)
2488 return std::nullopt;
2489 auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
2491 return std::nullopt;
2492 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2493 auto innerForResult =
2494 dyn_cast<OpResult>(yieldOp.getOperand(
result.getResultNumber()));
2495 if (!innerForResult)
2496 return std::nullopt;
2498 loops = loops.drop_front();
2500 LoopLikeOpInterface loop = loops.front();
2501 if (
result.getOwner() != loop)
2502 return std::nullopt;
2503 auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
2505 return std::nullopt;
2506 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2507 auto insertSliceOp = yieldOp.getOperand(
result.getResultNumber())
2508 .getDefiningOp<tensor::InsertSliceOp>();
2510 return std::nullopt;
2511 return insertSliceOp;
2514FailureOr<scf::SCFFuseConsumerOfSliceResult>
2515mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
2516 MutableArrayRef<LoopLikeOpInterface> loops) {
2517 if (!isa<TilingInterface>(consumer)) {
2519 consumer,
"unhandled consumer that does not implement TilingInterface");
2524 if (loops.empty()) {
2526 consumer,
"cannot call tile and fuse consumer with an empty loop nest");
2529 LoopLikeOpInterface outermostLoop = loops.front();
2533 SmallVector<OpOperand *> consumerFusableOperands;
2535 if (opOperand.get().getDefiningOp() == outermostLoop) {
2536 consumerFusableOperands.push_back(&opOperand);
2541 if (consumerFusableOperands.empty()) {
2542 return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
2543 SmallVector<OpOperand *>{},
2544 SmallVector<Operation *>{}};
2549 SmallVector<Operation *> candidateSlices;
2550 candidateSlices.reserve(consumerFusableOperands.size());
2551 for (OpOperand *opOperand : consumerFusableOperands) {
2552 std::optional<Operation *> slice =
2557 "couldnt find producing insert-slice like operation for operand");
2559 candidateSlices.push_back(slice.value());
2562 rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
2569FailureOr<SmallVector<scf::ForOp>>
2570mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
2571 TilingInterface op) {
2573 if (op->getNumResults() > 0) {
2575 op,
"unable to lower to loops operations with return values");
2578 SmallVector<Range> domain = op.getIterationDomain(rewriter);
2579 SmallVector<Value> ivs;
2580 SmallVector<scf::ForOp> loops;
2581 Location loc = op.getLoc();
2582 for (
auto loopRange : domain) {
2589 auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal,
2591 loops.push_back(loop);
2592 ivs.push_back(loop.getInductionVar());
2595 if (
failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
Returns failure if copies could not be generated due to yet-unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies, respectively, should be inserted. The output argument nBegin is set to its replacement (set to `begin` if no invalidation happens), since outgoing copies could have been inserted at `end`.
static llvm::ManagedStatic< PassManagerOptions > options
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset, tileSize * (numThreads - 1), is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, const SetVector< unsigned > &reductionDims, ValueRange partialResults)
static std::optional< Operation * > getProducingInsertSliceLikeOp(OpResult result, ArrayRef< LoopLikeOpInterface > loops)
For a given result of the loop nest that is a tiled loop nest, return the insert slice-like op that i...
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop< scf::ForallOp >(scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Implementation of yieldTiledValuesAndReplaceLoop for scf.forall
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
std::function< LogicalResult( RewriterBase &rewriter, Location Loc, ValueRange ivs, ArrayRef< OpFoldResult > tileOffsets, ArrayRef< OpFoldResult > tileSizes, ValueRange outerDestinationTensors, SmallVector< Value > &tiledResults, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> GenerateTiledBodyFn
Typedef for function that implements the body of a tiled loop.
static LogicalResult checkTileSizes(TilingInterface op, scf::SCFTilingOptions::LoopType loopType, ReductionTilingStrategy reductionStrategy, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, ArrayRef< OpOperand * > consumerOpOperands, ArrayRef< Operation * > candidateSlices, MutableArrayRef< LoopLikeOpInterface > loops)
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingCustomOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn, const scf::SCFTilingOptions::GenerateLoopTerminatorFn &generateLoopTerminatorFn, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using custom loop operation.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using the loop construct specified in options.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.for operation.
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static SmallVector< tensor::InsertSliceOp > cloneAsInsertSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices)
static FailureOr< SmallVector< OpOperand * > > getUntiledConsumerOperandsFromSlices(RewriterBase &rewriter, ArrayRef< Operation * > sliceOps, MutableArrayRef< LoopLikeOpInterface > loops)
A utility to fetch an untiled consumer of tensor.insert_slice/tensor.parallel_insert_slice.
tensor::InsertSliceOp cloneAsInsertSlice< tensor::ParallelInsertSliceOp >(RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp)
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
std::function< LogicalResult( RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
Typedef for function that allows returning additional yielded values during yieldTiledValuesAndReplac...
tensor::InsertSliceOp cloneAsInsertSlice< tensor::InsertSliceOp >(RewriterBase &rewriter, tensor::InsertSliceOp insertSliceOp)
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes)
Function to return the bounds of the loops to be generated.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizesWithForAllOp(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult givenTileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop< scf::ForOp >(scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Implementation of yieldTiledValuesAndReplaceLoop for scf.for.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count, i.e. size - offset.
static SetVector< unsigned > getSanitizedReductionDims(ArrayRef< OpFoldResult > givenTileSizes, const scf::SCFTilingOptions &options)
Get the reduction dims that are tiled.
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.forall operation.
static std::optional< Operation * > getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result)
For a given result of a forallOp return the tensor.parallel_insert_slice op (or combining op) that is...
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, InsertSliceOpTy sliceOp)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static SmallVector< OpFoldResult > getSplitReductionIvs(RewriterBase &rewriter, Location loc, ReductionTilingStrategy reductionStrategy, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
For the case of ReductionTilingStrategy::PartialReductionOuterParallel the PartialReductionOpInterfac...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class represents an argument of a Block.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
iterator_range< use_iterator > use_range
Operation * getOwner() const
Return the owner of this operand.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap an tensor.extract_slice with its producer when the producer implements the TilingInter...
FailureOr< TilingResult > replaceInsertSlicesWithTiledConsumer(OpBuilder &builder, ArrayRef< tensor::InsertSliceOp > sliceOps, ArrayRef< OpOperand * > consumerOperands)
Method to swap tensor.insert_slices with their consumers when the consumer implements the TilingInter...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Include the generated interface declarations.
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Container for result values of tiling.