28#include "llvm/ADT/ScopeExit.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/Support/Debug.h"
33#define DEBUG_TYPE "tile-using-interface"
37scf::SCFTilingOptions &
39 assert(!tileSizeComputationFunction &&
"tile sizes already set");
40 auto tileSizes = llvm::to_vector(ts);
47scf::SCFTilingOptions &
49 assert(!numThreadsComputationFunction &&
"num tiles already set");
50 auto numThreads = llvm::to_vector(nt);
61 size_t iterationDomainSize) {
63 if (filledVector.size() < iterationDomainSize) {
64 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
65 filledVector.append(range.begin(), range.end());
67 if (filledVector.size() > iterationDomainSize)
68 filledVector.resize(iterationDomainSize);
78 const scf::SCFTilingOptions &
options) {
80 if (
options.numThreadsComputationFunction &&
81 options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
83 loc,
"number of threads can only by specified when loop type is "
84 "set to use `scf.forall`");
88 if (!
options.interchangeVector.empty()) {
91 loc,
"invalid interchange vector, not a permutation of the entire "
103 const scf::SCFTilingOptions &
options) {
106 size_t numLoops = iterationDomain.size();
109 if (
options.numThreadsComputationFunction) {
110 numThreads =
options.numThreadsComputationFunction(rewriter, op);
111 numThreads.resize(numLoops, zero);
114 if (
options.tileSizeComputationFunction) {
115 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
116 tileSizes.resize(numLoops, zero);
117 return {tileSizes, numThreads};
129 tileSizes.resize(numLoops, zero);
130 for (
auto [
index, range, nt] :
131 llvm::enumerate(iterationDomain, numThreads)) {
136 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
138 tileSizes.resize(numLoops, zero);
139 return {tileSizes, numThreads};
146 assert(
options.tileSizeComputationFunction &&
147 "expected tile sizes to be specified");
148 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
149 tileSizes.resize(numLoops, zero);
151 return {tileSizes, numThreads};
156 scf::SCFTilingOptions::LoopType loopType,
160 auto iterators = op.getLoopIteratorTypes();
161 assert(iterators.size() == givenTileSizes.size() &&
162 "expected as many tile size values as number of loops");
163 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
164 "when specified, expected number of threads to use for each loop");
166 bool isParallelTiling =
false;
167 for (
auto [
index, iterator, givenTileSize] :
168 llvm::enumerate(iterators, givenTileSizes)) {
170 isParallelTiling |= iterator == utils::IteratorType::parallel;
173 if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
177 if (!numThreads.empty()) {
178 if (std::optional<int64_t> constNumThreads =
180 if (constNumThreads.value() > 1 &&
181 iterator != utils::IteratorType::parallel) {
182 op.emitWarning() <<
"tiling is not thread safe at axis #" <<
index;
188 if (std::optional<int64_t> constTileSize =
190 if (constTileSize.value() > 0 &&
191 iterator != utils::IteratorType::parallel) {
192 op.emitWarning() <<
"tiling is not thread safe at axis #" <<
index;
199 if (isParallelTiling) {
200 return op->emitOpError(
"tiling parallel dimensions is not supported with "
201 "partial reduction tiling strategies");
211 const scf::SCFTilingOptions &
options) {
213 for (
auto dim :
options.reductionDims) {
216 reductionDims.insert(dim);
218 return reductionDims;
232 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
241 if (ts && ts.value() == 1)
242 return givenTileSize;
246 return givenTileSize;
268 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
270 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
281 int materializedLoopNum = 0;
282 for (
auto [givenTileSize, loopRange] :
283 llvm::zip_equal(givenTileSizes, iterationDomain)) {
288 offsets.push_back(loopRange.offset);
289 sizes.push_back(loopRange.size);
293 Value iv = ivs[materializedLoopNum++];
295 offsets.push_back(offset);
298 sizes.push_back(size);
300 return {offsets, sizes};
309 for (
auto [loopRange, givenTileSize] :
310 llvm::zip_equal(loopRanges, givenTileSizes)) {
314 lbs.push_back(loopRange.offset);
315 ubs.push_back(loopRange.size);
316 steps.push_back(givenTileSize);
318 return {lbs, ubs, steps};
372 if (newDestArgs.empty())
374 if (
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
375 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
392 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
393 assert(loopRanges.size() == givenTileSizes.size() &&
394 "expected as many tile sizes as loop ranges");
398 std::tie(lbs, ubs, steps) =
409 ValueRange innerDestinationTensors(outerDestinationTensors);
410 for (
auto [lb,
ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
412 scf::ForOp::create(rewriter, loc, lb,
ub, step, innerDestinationTensors,
415 loops.push_back(loop);
416 ivs.push_back(loop.getInductionVar());
418 innerDestinationTensors = loop.getRegionIterArgs();
425 std::tie(offsets, sizes) =
430 if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
431 innerDestinationTensors, tiledResults, resultOffsets,
434 loc,
"failed to generate inner tile loop body");
439 assert(tiledResults.size() == innerDestinationTensors.size() &&
440 "Number of results of body should be equal to number of iter args");
444 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
445 llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
449 auto insertSlice = tensor::InsertSliceOp::create(
450 rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize,
452 yieldedValues.push_back(insertSlice);
454 scf::YieldOp::create(rewriter, loc, yieldedValues);
457 for (
auto [outerLoop, innerLoop] :
461 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
462 scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
479 if (numThreads.empty()) {
485 int materializedLoopNum = 0;
491 offsetExpr = d0 + d1 * s0;
492 residualTileSizeExpr = s1 - (d0 + d1 * s0);
494 for (
auto [
index, nt, givenTileSize, loopRange] :
495 llvm::enumerate(numThreads, givenTileSizes, iterationDomain)) {
500 offsets.push_back(loopRange.offset);
501 sizes.push_back(loopRange.size);
505 Value iv = ivs[materializedLoopNum++];
507 rewriter, loc, offsetExpr,
510 rewriter, loc, residualTileSizeExpr,
511 {loopRange.offset, nt, givenTileSize, loopRange.size});
517 {offset, loopRange.size});
521 {sizeMinusOffsetPerThread, givenTileSize});
537 rewriter, loc, maxMap, {rewriter.
getIndexAttr(0), size});
540 offsets.push_back(offset);
541 sizes.push_back(size);
543 return {offsets, sizes};
556static FailureOr<SmallVector<LoopLikeOpInterface>>
564 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
565 assert(loopRanges.size() == givenTileSizes.size() &&
566 "expected as many tile sizes as loop ranges");
569 std::optional<ArrayAttr> mappingAttr;
570 if (!mappingVector.empty())
573 scf::ForallOp forallOp;
574 bool useNumThreads = !numThreads.empty();
580 for (
auto nt : numThreads) {
583 nonZeroNumThreads.push_back(nt);
585 forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads,
586 outerDestinationTensors, mappingAttr);
589 std::tie(lbs, ubs, steps) =
591 forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
592 outerDestinationTensors, mappingAttr);
594 loops.push_back(forallOp);
597 ValueRange innerDestinationTensors = forallOp.getRegionOutArgs();
603 rewriter, loc, ivs, loopRanges, givenTileSizes, numThreads);
607 if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
608 innerDestinationTensors, tiledResults, resultOffsets,
613 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
614 llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
619 tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue,
620 destinationTensor, resultOffset,
621 resultSize, resultStride);
636static FailureOr<SmallVector<LoopLikeOpInterface>>
640 const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn,
641 const scf::SCFTilingOptions::GenerateLoopTerminatorFn
642 &generateLoopTerminatorFn,
644 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
645 assert(loopRanges.size() == givenTileSizes.size() &&
646 "expected as many tile sizes as loop ranges");
647 assert(generateLoopHeaderFn && generateLoopTerminatorFn &&
648 "expected loop header/terminator generation function");
651 FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> loopHeaderInfo =
652 generateLoopHeaderFn(rewriter, loc, loopRanges, givenTileSizes,
653 outerDestinationTensors);
654 if (failed(loopHeaderInfo)) {
661 if (failed(tiledBodyFn(rewriter, loc, ivs, loopHeaderInfo->tileOffset,
662 loopHeaderInfo->tileSizes,
663 loopHeaderInfo->destinationTensors, tiledResults,
664 resultOffsets, resultSizes))) {
668 if (failed(generateLoopTerminatorFn(rewriter, loc, loopHeaderInfo->loops,
669 tiledResults, resultOffsets, resultSizes,
670 loopHeaderInfo->destinationTensors))) {
674 return loopHeaderInfo->loops;
698 llvm::map_to_vector(loopRanges, [](
Range r) {
return r.
offset; });
700 llvm::map_to_vector(loopRanges, [](
Range r) {
return r.
size; });
701 if (failed(tiledBodyFn(rewriter, loc,
ValueRange{}, tileOffsets, tileSizes,
702 destinationTensors, tiledResults, resultOffsets,
708 if (
options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
710 destinationTensors, tiledBodyFn);
712 if (
options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
714 rewriter, loc, loopRanges, givenTileSizes, numThreads,
715 options.mappingVector, destinationTensors, tiledBodyFn);
717 if (
options.loopType == scf::SCFTilingOptions::LoopType::CustomOp) {
719 rewriter, loc, loopRanges, givenTileSizes, destinationTensors,
739 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
741 return op->emitOpError(
742 "PartialReductionOuterReduction tiling strategy is only supported for "
743 "operations implementing PartialReductionOpInterface");
748 AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
750 for (
auto [
index, domain, tileSize] :
751 llvm::enumerate(iterationDomain, givenTileSizes)) {
752 if (!numThreads.empty()) {
756 rewriter, op.getLoc(), sizeExpr,
757 {domain.size, domain.offset, domain.stride});
767 rewriter, op.getLoc(), sizeExpr,
768 {domain.size, domain.offset, domain.stride});
772 if (reductionStrategy ==
774 sizes[
index] = tileSize;
778 assert(reductionStrategy ==
781 rewriter, op.getLoc(), sizeExpr,
782 {domain.size, domain.offset, domain.stride});
784 rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
786 return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
800 splitReductionIvs.resize(reductionDims.size(), rewriter.
getIndexAttr(0));
805 if (reductionStrategy ==
807 for (
auto [
index, reductionDim] : llvm::enumerate(reductionDims)) {
808 if (!numThreads.empty()) {
809 splitReductionIvs[
index] = ivs[ivIndex++];
813 rewriter, loc, divExpr,
817 return splitReductionIvs;
820static FailureOr<TilingResult>
829 return op.getTiledImplementation(rewriter, offsets, sizes);
832 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
835 op,
"PartialReductionOuterReduction tiling strategy is only "
836 "supported for operations "
837 "implementing PartialReductionOpInterface");
842 numThreads, givenTileSizes, reductionDims);
843 return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
844 regionIterArg, offsets, sizes,
845 reductionDims, splitReductionIvs);
859 return op.getResultTilePosition(rewriter,
index, offsets, sizes,
860 resultOffset, resultSize);
862 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
865 op,
"PartialReductionOuterReduction tiling strategy is only supported"
866 "for operations implementing PartialReductionOpInterface");
870 numThreads, givenTileSizes, reductionDims);
871 return redOp.getPartialResultTilePosition(
872 rewriter,
index, reductionStrategy, offsets, sizes, reductionDims,
873 splitReductionIvs, resultOffset, resultSize);
876static FailureOr<MergeResult>
882 "expected merge to be called for only partial reduction cases");
884 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
887 op,
"PartialReductionOuterReduction tiling strategy is only "
888 "supported for operations "
889 "implementing PartialReductionOpInterface");
891 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
902template <
typename LoopType>
903FailureOr<LoopLikeOpInterface>
919 auto inits = llvm::to_vector(loopOp.getInitArgs());
920 inits.append(newInitOperands.begin(), newInitOperands.end());
921 auto newLoop = scf::ForOp::create(
922 rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
924 loopOp.getUnsignedCmp());
927 Block *loopBody = loopOp.getBody();
928 Block *newLoopBody = newLoop.getBody();
930 loopBody, newLoopBody,
933 auto yieldOp = cast<scf::YieldOp>(newLoopBody->
getTerminator());
939 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
940 if (
failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
941 newRegionIterArgs, tiledValues, resultOffsets,
948 for (
auto [tiledValue, regionIterArg, resultOffset, resultSize] :
949 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
953 Value insert = tensor::InsertSliceOp::create(
954 rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset,
955 resultSize, resultStride);
956 newYieldValues.push_back(insert);
961 newLoop->getResults().take_front(loopOp.getNumResults()));
962 return cast<LoopLikeOpInterface>(newLoop.getOperation());
973 auto inits = llvm::to_vector(loopOp.getOutputs());
974 inits.append(newInitOperands.begin(), newInitOperands.end());
975 auto newLoop = scf::ForallOp::create(
976 rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
977 loopOp.getMixedStep(), inits, loopOp.getMapping(),
981 Block *loopBody = loopOp.getBody();
982 Block *newLoopBody = newLoop.getBody();
984 loopBody, newLoopBody,
987 auto terminator = cast<scf::InParallelOp>(newLoopBody->
getTerminator());
992 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
993 if (
failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
994 regionIterArgs, tiledValues, resultOffsets,
998 "failed to get yielded tiled values");
1004 for (
auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
1005 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
1008 tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(),
1009 tiledValue, iterArg, resultOffset,
1010 resultSize, resultStride);
1014 newLoop->getResults().take_front(loopOp.getNumResults()));
1015 return cast<LoopLikeOpInterface>(newLoop.getOperation());
1022 LoopLikeOpInterface loopLikeOp,
RewriterBase &rewriter,
1025 loopLikeOp.getOperation())
1026 .Case<scf::ForOp, scf::ForallOp>(
1027 [&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1029 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
1031 .Default([&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1050 for (
auto &loop : loops.drop_back()) {
1054 auto forLoop = cast<scf::ForOp>(loop.getOperation());
1058 newInits.append(newInitValues.begin(), newInitValues.end());
1059 auto newLoop = scf::ForOp::create(
1060 rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
1061 forLoop.getUpperBound(), forLoop.getStep(), newInits,
1063 forLoop.getUnsignedCmp());
1067 sourceBlockArgs.push_back(newLoop.getInductionVar());
1068 auto newRegionIterArgs = newLoop.getRegionIterArgs();
1069 sourceBlockArgs.append(
1070 newRegionIterArgs.begin(),
1071 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
1072 rewriter.
mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
1074 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
1076 ivs.push_back(newLoop.getInductionVar());
1077 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
1081 LoopLikeOpInterface innerMostLoop = loops.back();
1082 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
1084 getNewTiledYieldsFn);
1086 if (failed(newInnerMostLoop))
1087 return innerMostLoop.emitOpError(
"failed to return additional yields");
1088 loops.back() = newInnerMostLoop.value();
1092 for (
auto [outerLoop, innerLoop] :
1093 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1095 auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
1096 auto outerLoopYield =
1097 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
1099 llvm::to_vector(outerLoopYield.getOperands());
1101 innerLoop->getResults().take_back(newInitValues.size());
1102 newYields.append(additionalYields.begin(), additionalYields.end());
1111FailureOr<scf::SCFTilingResult>
1112mlir::scf::tileUsingSCF(
RewriterBase &rewriter, TilingInterface op,
1113 const scf::SCFTilingOptions &
options) {
1126 std::tie(givenTileSizes, numThreads) =
1132 givenTileSizes, numThreads))) {
1143 if (!
options.interchangeVector.empty()) {
1145 iterationDomain.size());
1147 "expected interchange vector to be a permutation");
1151 if (!numThreads.empty())
1155 FailureOr<TilingResult> tilingResult;
1169 if (!interchangeVector.empty()) {
1178 auto clonedOp = cast<TilingInterface>(
1185 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1195 regionIterArgs, tileOffsetsVec, tileSizesVec,
1196 ivs, numThreads, givenTileSizes, reductionDims);
1197 if (
failed(tilingResult)) {
1199 return op.emitOpError(
"faild to tile operation");
1207 for (
auto [
index, tiledValue] :
1208 llvm::enumerate(tilingResult->tiledValues)) {
1209 tiledResults.push_back(tiledValue);
1212 rewriter,
options.reductionStrategy,
index, tiledValue, op,
1213 tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes,
1214 reductionDims, resultOffset, resultSize))) {
1215 for (
auto op : tilingResult->tiledOps) {
1219 op,
"failed to get slice of result produced");
1221 resultOffsets.emplace_back(std::move(resultOffset));
1222 resultSizes.emplace_back(std::move(resultSize));
1230 rewriter, op,
options.reductionStrategy, iterationDomain, numThreads,
1231 givenTileSizes, reductionDims);
1232 if (
failed(maybeInits)) {
1234 op,
"unable to create initial tensors for tiling");
1242 rewriter, op.getLoc(),
options, iterationDomain, givenTileSizes,
1243 numThreads, initTensors, innerYieldTiledValuesFn);
1245 return op.emitOpError(
"failed to generate tiling loops");
1246 assert(succeeded(tilingResult) &&
1247 "expected tiling result to be computed after loop generation");
1248 std::swap(loops, loopsOr.value());
1251 if (loops.empty()) {
1254 return scf::SCFTilingResult{tilingResult->tiledOps,
1257 tilingResult->tiledValues,
1258 tilingResult->generatedSlices,
1262 auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1267 return scf::SCFTilingResult{
1268 tilingResult->tiledOps, initTensors, loops, loopResults,
1269 tilingResult->generatedSlices, {}};
1274 rewriter, op,
options.reductionStrategy, reductionDims, loopResults);
1275 if (
failed(mergeResult)) {
1277 op,
"Failed to merge partial results from tiling");
1279 return scf::SCFTilingResult{tilingResult->tiledOps,
1282 mergeResult->replacements,
1283 tilingResult->generatedSlices,
1284 mergeResult->mergeOps};
1287FailureOr<scf::SCFTilingResult>
1289 PartialReductionOpInterface op,
1291 scf::SCFTilingOptions
options;
1292 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
1293 options.setReductionTilingStrategy(
1295 options.setTileSizes(tileSize);
1297 for (
auto [
index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
1298 if (iteratorType == utils::IteratorType::reduction)
1299 reductionDims.push_back(
index);
1300 options.setReductionDims(reductionDims);
1301 return tileUsingSCF(
b, op,
options);
1314static std::tuple<OpResult, std::optional<OpOperand *>>
1317 std::optional<OpOperand *> destinationIterArg;
1318 assert(!loops.empty() &&
"expected non empty loops container");
1319 auto loopIt = loops.rbegin();
1320 while (loopIt != loops.rend() && isa<BlockArgument>(source->
get())) {
1321 auto iterArg = cast<BlockArgument>(source->
get());
1322 auto loop = *loopIt;
1323 if (iterArg.getOwner()->getParentOp() != loop)
1325 source = loop.getTiedLoopInit(iterArg);
1328 if (loopIt == loops.rend())
1329 destinationIterArg = source;
1331 auto result = dyn_cast<OpResult>(source->
get());
1334 Operation *innermostLoop = loops.back();
1341 return {
result, destinationIterArg};
1346std::optional<scf::SCFFuseProducerOfSliceResult>
1347mlir::scf::tileAndFuseProducerOfSlice(
1348 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1352 auto [fusableProducer, destinationInitArg] =
1355 if (!fusableProducer)
1356 return std::nullopt;
1357 unsigned resultNumber = fusableProducer.getResultNumber();
1365 Operation *fusableProducerOp = fusableProducer.getOwner();
1366 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1368 rewriter, fusableProducerOp->
getLoc(), fusableProducerOp,
1369 origDestinationTensors)))
1370 return std::nullopt;
1372 clonedOpDestinationTensors = origDestinationTensors;
1373 if (destinationInitArg &&
1374 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1378 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1382 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1387 llvm::to_vector(candidateSliceOp->getOperands());
1388 candidateSliceOpOperands[0] = clonedProducerOp->
getResult(resultNumber);
1389 tensor::ExtractSliceOp clonedCandidateSliceOp =
1391 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1394 FailureOr<TilingResult> tileAndFuseResult =
1396 rewriter, clonedCandidateSliceOp,
1397 clonedProducerOp->
getResult(resultNumber));
1398 if (
failed(tileAndFuseResult))
1399 return std::nullopt;
1403 tileAndFuseResult->tiledValues[0]);
1404 rewriter.
eraseOp(clonedCandidateSliceOp);
1405 rewriter.
eraseOp(clonedProducerOp);
1450 if (destinationInitArg &&
1451 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1453 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1454 .set(origDestinationTensors[resultNumber]);
1456 return scf::SCFFuseProducerOfSliceResult{
1457 fusableProducer, tileAndFuseResult->tiledValues[0],
1458 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1462FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1463 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1464 scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1470 Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1471 *tiledOwner = fusedProducerInfo.tiledOps[0];
1476 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1478 :
llvm::to_vector(yieldResultNumber);
1480 for (
const auto &resultNumber : initNumberList) {
1482 rewriter, loc, originalOwner->
getResult(resultNumber));
1483 if (succeeded(initValue)) {
1484 initValueList.push_back(initValue.value());
1500 sliceSizes = sliceOp.getMixedSizes();
1503 if (!llvm::all_of(sliceOp.getMixedStrides(),
isOneInteger))
1506 unsigned sliceResultNumber =
1507 fusedProducerInfo.origProducer.getResultNumber();
1509 auto tilableOp = cast<TilingInterface>(originalOwner);
1516 if (
auto tiledDestStyleOp =
1517 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1521 if (tilableOp->getNumResults() > 1 &&
1522 failed(tilableOp.getIterationDomainTileFromResultTile(
1523 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1524 iterDomainOffset, iterDomainSizes))) {
1539 for (
const auto &resultNumber : initNumberList) {
1540 if (resultNumber == sliceResultNumber) {
1541 offsetList.push_back(sliceOffset);
1542 sizesList.push_back(sliceSizes);
1544 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1547 if (
failed(tilableOp.getResultTilePosition(
1548 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1552 offsetList.push_back(offset);
1553 sizesList.push_back(sizes);
1559 if (
auto tiledDestStyleOp =
1560 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1561 for (
const auto &&[
index, newRegionArg] :
1562 llvm::enumerate(newRegionIterArgs)) {
1563 auto destSlice = tensor::ExtractSliceOp::create(
1564 rewriter, loc, newRegionArg, offsetList[
index], sizesList[
index],
1567 generatedSlices.push_back(destSlice);
1568 unsigned resultNumber = initNumberList[
index];
1570 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1579 for (
const auto &&[
index, resultNumber] : llvm::enumerate(initNumberList)) {
1580 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1581 tiledOffset.emplace_back(offsetList[
index]);
1582 tiledSizes.emplace_back(sizesList[
index]);
1588 newYieldValuesFn))) {
1591 return generatedSlices;
1605 explicit SliceTrackingListener(
1606 std::optional<FrozenRewritePatternSet> patterns);
1607 SliceTrackingListener() =
default;
1613 LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1616 void notifyOperationInserted(Operation *op,
1617 OpBuilder::InsertPoint previous)
override;
1620 void removeOp(Operation *op);
1623 void notifyOperationErased(Operation *op)
override;
1630 std::deque<tensor::ExtractSliceOp> worklist;
1635 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1638SliceTrackingListener::SliceTrackingListener(
1639 std::optional<FrozenRewritePatternSet> p) {
1644SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1645 for (Operation *op : ops) {
1646 if (
auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1647 worklist.push_back(slice);
1655 GreedyRewriteConfig().setListener(
this).setStrictness(
1656 GreedyRewriteStrictness::ExistingAndNewOps));
1659void SliceTrackingListener::notifyOperationInserted(
1660 Operation *op, OpBuilder::InsertPoint previous) {
1661 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1664 worklist.push_back(slice);
1670void SliceTrackingListener::removeOp(Operation *op) {
1671 if (!isa<tensor::ExtractSliceOp>(op))
1673 auto iter = worklist.begin();
1674 while (iter != worklist.end()) {
1679 if (iter == worklist.end())
1682 worklist.erase(iter);
1685void SliceTrackingListener::notifyOperationErased(Operation *op) {
1689void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1701class ReplacementListener :
public RewriterBase::ForwardingListener {
1703 ReplacementListener(DenseMap<Value, Value> &replacements,
1704 OpBuilder::Listener *listener)
1705 : ForwardingListener(listener), replacements(replacements) {}
1707 void updateReplacementValues(
ValueRange origValues,
1711 for (
auto &[key, val] : replacements) {
1712 for (
auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1720 void notifyOperationReplaced(Operation *op, Operation *newOp)
override {
1721 ForwardingListener::notifyOperationReplaced(op, newOp);
1725 void notifyOperationReplaced(Operation *op,
ValueRange values)
override {
1726 ForwardingListener::notifyOperationReplaced(op, values);
1727 updateReplacementValues(op->
getResults(), values);
1731 DenseMap<Value, Value> &replacements;
1737FailureOr<scf::SCFTileAndFuseResult>
1738mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1739 RewriterBase &rewriter, TilingInterface consumer,
1740 const scf::SCFTileAndFuseOptions &
options) {
1743 if (!consumer->getNumResults()) {
1745 consumer,
"invalid pattern for op with no results");
1751 FailureOr<scf::SCFTilingResult> tilingResult =
1752 tileUsingSCF(rewriter, consumer,
options.tilingOptions);
1754 if (
failed(tilingResult))
1756 tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1758 DenseMap<Value, Value> replacements;
1760 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1765 auto &loops = tilingResult->loops;
1766 if (loops.empty()) {
1767 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1774 OpBuilder::Listener *previousListener = rewriter.
getListener();
1775 llvm::scope_exit resetListener(
1776 [&]() { rewriter.
setListener(previousListener); });
1777 ReplacementListener replaceListener(replacements, previousListener);
1787 struct WorklistItem {
1788 tensor::ExtractSliceOp candidateSlice;
1789 SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1792 SliceTrackingListener sliceTracker =
1793 SliceTrackingListener(
options.cleanupPatterns);
1796 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1799 OpBuilder::InsertionGuard g(rewriter);
1800 while (!sliceTracker.worklist.empty()) {
1801 auto candidateSlice = sliceTracker.worklist.front();
1802 sliceTracker.worklist.pop_front();
1804 auto [fusableProducer, destinationInitArg] =
1807 if (!fusableProducer)
1810 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1811 options.fusionControlFn(candidateSlice, fusableProducer,
1812 destinationInitArg.has_value());
1813 if (!controlFnResult)
1816 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1821 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1822 tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1827 SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1829 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1834 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1835 FailureOr<SmallVector<Operation *>> newSlices =
1836 yieldReplacementForFusedProducer(rewriter,
1837 worklistItem.candidateSlice,
1838 fusedResult.value(), loops);
1841 fusableProducerOp,
"failed to replacement value for this "
1842 "operation from within the tiled loop");
1844 worklistCandidates.append(newSlices.value());
1845 for (
auto [index,
result] :
1846 llvm::enumerate(fusableProducerOp->
getResults())) {
1847 replacements[
result] = loops.front()->getResult(
1848 loops.front()->getNumResults() -
1852 if (Operation *tiledAndFusedOp =
1853 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1854 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1855 tiledAndFusedOps.insert(tiledAndFusedOp);
1858 if (
failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1863 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1877 if (!llvm::hasSingleElement(uses)) {
1878 LLVM_DEBUG(llvm::dbgs() <<
"Too many uses of the candidate slice op\n");
1881 OpOperand &operandUse = (*uses.begin());
1883 if (!isa<scf::YieldOp>(userOp)) {
1884 LLVM_DEBUG(llvm::dbgs()
1885 <<
"Expected scf.yield to be the only user, but got -> "
1890 LLVM_DEBUG(llvm::dbgs() <<
"Expected tensor.insert_slice and scf.yield to "
1891 "be in the same block\n");
1900 if (!isa<LoopLikeOpInterface>(loopOp))
1919 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1920 userOp = userOp->getParentOfType<scf::InParallelOp>();
1922 if (loopOp->
getBlock() != userOp->getBlock())
1926 firstUserOfLoop = userOp;
1928 return firstUserOfLoop;
1969static FailureOr<llvm::SetVector<Operation *>>
1971 bool reorderOperations) {
1973 if (failed(firstUserOfLoop))
1979 options.omitBlockArguments =
true;
1980 bool includeLoopOp =
false;
1983 includeLoopOp =
true;
1993 assert(
result.succeeded() &&
"expected a backward slice");
1997 if (!slice.empty()) {
2007 if (includeLoopOp || !reorderOperations)
2019 unsigned resultNumber) {
2020 if (!isa<LoopLikeOpInterface>(loopOp))
2025 Operation *consumerOp = opOperand.getOwner();
2027 if (!isa<TilingInterface>(consumerOp) ||
2028 !isa<DestinationStyleOpInterface>(consumerOp)) {
2035 if (loopBlock != consumerOp->
getBlock())
2042 FailureOr<llvm::SetVector<Operation *>> slice =
2048 if (!slice->empty()) {
2051 assert(succeeded(firstUserOfLoop) &&
"First user of loop is not found");
2052 for (
auto op : *slice) {
2067static FailureOr<OpOperand *>
2069 tensor::InsertSliceOp candidateSliceOp,
2071 assert(!loops.empty() &&
"unexpected loops to be empty");
2074 if (containingOp != loops.back()) {
2077 "expected slice to be within body of inner-most loop");
2083 candidateSliceOp,
"expected passed loops to be perfectly nested.");
2088 Value sliceResult = candidateSliceOp.getResult();
2094 scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
2101static FailureOr<OpOperand *>
2103 tensor::ParallelInsertSliceOp candidateSliceOp,
2105 assert(!loops.empty() &&
"unexpected loops to be empty");
2107 if (loops.size() != 1) {
2109 candidateSliceOp,
"expected single surrounding scf.forall");
2111 auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
2114 candidateSliceOp,
"expected single surrounding scf.forall");
2118 Value sliceDest = candidateSliceOp.getDest();
2119 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
2122 if (iterArg.getOwner()->getParentOp() != forallOp)
2125 unsigned resultNumber =
2126 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
2137 assert(!loops.empty() &&
"unexpected empty loops");
2138 assert(!sliceOps.empty() &&
"unexpected empty list of candidate slices");
2140 for (
auto sliceOp : sliceOps) {
2141 FailureOr<OpOperand *> fusedOperand =
2143 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2150 if (failed(fusedOperand)) {
2153 if (!fusedOperands.empty() &&
2154 fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
2156 fusedOperand.value()->getOwner(),
2157 "all candidate slices must be to the same consumer");
2159 fusedOperands.push_back(fusedOperand.value());
2161 return fusedOperands;
2164template <
typename InsertSliceOpTy>
2166 InsertSliceOpTy sliceOp);
2169tensor::InsertSliceOp
2171 tensor::InsertSliceOp insertSliceOp) {
2172 return cast<tensor::InsertSliceOp>(
2173 rewriter.
clone(*insertSliceOp.getOperation()));
2178 RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
2179 return tensor::InsertSliceOp::create(
2180 rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(),
2181 insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
2182 insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2185static SmallVector<tensor::InsertSliceOp>
2188 assert(!candidateSlices.empty() &&
2189 "unexpected empty list of slices to clone");
2191 for (
auto sliceOp : candidateSlices) {
2193 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2196 clonedSlices.push_back(clonedOp);
2199 .DefaultUnreachable(
2200 "unexpected slice type while cloning as insert slice");
2202 return clonedSlices;
2205static FailureOr<scf::SCFFuseConsumerOfSliceResult>
2210 assert(!loops.empty() &&
"expected loops to be not empty");
2215 loops.front(),
"the first user of loop should not dominate any define "
2216 "of consumer operand(s)");
2219 LoopLikeOpInterface outerMostLoop = loops.front();
2220 LoopLikeOpInterface innerMostLoop = loops.back();
2224 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2227 "consumer op is not DPS operation");
2228 if (llvm::any_of(consumerOpOperands, [&](
OpOperand *opOperand) {
2229 return dstOp.isDpsInit(opOperand);
2233 "consumer op taking the result of scf.for as init is not supported");
2240 if (failed(firstUserOfLoop)) {
2242 outerMostLoop,
"could not find the first user of outer most loop");
2244 rewriter.
moveOpBefore(outerMostLoop, *firstUserOfLoop);
2251 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
2252 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2262 auto clonedConsumerOp = cast<TilingInterface>(rewriter.
clone(*consumerOp));
2264 llvm::map_to_vector(consumerOpOperands, [](
OpOperand *opOperand) {
2268 llvm::map_to_vector(operandNumbers, [&](
unsigned operandNum) {
2269 return &clonedConsumerOp->getOpOperand(operandNum);
2275 for (
auto [operandToReplace, clonedSliceOp] :
2276 llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
2277 operandToReplace->set(clonedSliceOp.getResult());
2283 FailureOr<TilingResult> tileAndFuseResult =
2285 clonedOpFusedOperandsList);
2286 if (failed(tileAndFuseResult)) {
2290 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2291 for (
auto [operandNum, clonedSliceOp] :
2292 llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
2294 clonedSliceOp.getSource());
2308 for (
auto candidateSliceOp : clonedInsertSlices) {
2316 candidateSliceOp,
"containingOp's result yield with stride");
2319 allOffsets.emplace_back(std::move(offsets));
2320 allSizes.emplace_back(std::move(sizes));
2330 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
2331 rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
2332 iterDomainSizes))) {
2335 "can't get iter domain position from input position");
2341 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2343 totalNumResultsOfConsumer);
2345 totalNumResultsOfConsumer);
2346 for (
auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2347 if (failed(tiledConsumerOp.getResultTilePosition(
2348 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2349 resultOffsets[idx], resultSizes[idx]))) {
2352 "can't get result domain position from iter domain position");
2358 if (
auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2359 tiledConsumerOp.getOperation())) {
2361 for (
const auto &&[
index, newRegionArg] :
2362 llvm::enumerate(newRegionIterArgs)) {
2363 auto destSlice = tensor::ExtractSliceOp::create(
2364 rewriter, loc, newRegionArg, resultOffsets[
index],
2370 auto dstNumber =
index;
2372 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2382 llvm::enumerate(tiledConsumerOp->getResults())) {
2383 tiledResult.push_back(
result);
2384 tiledOffset.emplace_back(resultOffsets[
index]);
2385 tiledSizes.emplace_back(resultSizes[
index]);
2391 newYieldValuesFn))) {
2393 "unable to add new inits to nest loop");
2399 for (
auto &&[oldResult, newResult] :
2401 loops.front()->getResults().take_back(newInits.size()))) {
2406 rewriter.
eraseOp(clonedConsumerOp);
2409 llvm::map_to_vector(operandNumbers, [&](
unsigned operandNum) {
2410 return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
2412 auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
2413 return scf::SCFFuseConsumerOfSliceResult{
2414 std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
2415 std::move(tileAndFuseResult->tiledOps)};
2420FailureOr<scf::SCFFuseConsumerOfSliceResult>
2421mlir::scf::tileAndFuseConsumerOfSlices(
2422 RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2423 MutableArrayRef<LoopLikeOpInterface> loops) {
2424 if (candidateSlices.empty()) {
2427 "no candidate slices provided for consumer fusion");
2431 if (loops.empty()) {
2433 candidateSlices.front(),
2434 "cannot call tile and fuse consumer with an empty loop nest");
2437 if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2438 llvm::all_of(candidateSlices,
2439 llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2441 candidateSlices.front(),
2442 "candidates slices need to be all `tensor.extract_slice`s or "
2443 "`tensor.parallel_insert_slice`s");
2448 FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
2450 if (
failed(maybeConsumerOpOperands)) {
2452 "could not fetch consumer to fuse");
2454 Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();
2457 maybeConsumerOpOperands.value(),
2458 candidateSlices, loops);
2464static std::optional<Operation *>
2466 if (
result.getOwner() != forallOp)
2467 return std::nullopt;
2472 if (combiningOps.size() != 1)
2473 return std::nullopt;
2474 return combiningOps[0];
2479static std::optional<Operation *>
2482 assert(!loops.empty() &&
"Expected loops to be not empty");
2483 LoopLikeOpInterface outerMostLoop = loops.front();
2484 if (
auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) {
2485 assert(loops.size() == 1 &&
2486 "expected only a single loop when tiling using scf.forall");
2492 while (loops.size() != 1) {
2493 LoopLikeOpInterface loop = loops.front();
2494 if (
result.getOwner() != loop)
2495 return std::nullopt;
2496 auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
2498 return std::nullopt;
2499 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2500 auto innerForResult =
2501 dyn_cast<OpResult>(yieldOp.getOperand(
result.getResultNumber()));
2502 if (!innerForResult)
2503 return std::nullopt;
2505 loops = loops.drop_front();
2507 LoopLikeOpInterface loop = loops.front();
2508 if (
result.getOwner() != loop)
2509 return std::nullopt;
2510 auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
2512 return std::nullopt;
2513 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2514 auto insertSliceOp = yieldOp.getOperand(
result.getResultNumber())
2515 .getDefiningOp<tensor::InsertSliceOp>();
2517 return std::nullopt;
2518 return insertSliceOp;
2521FailureOr<scf::SCFFuseConsumerOfSliceResult>
2522mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
2523 MutableArrayRef<LoopLikeOpInterface> loops) {
2524 if (!isa<TilingInterface>(consumer)) {
2526 consumer,
"unhandled consumer that does not implement TilingInterface");
2531 if (loops.empty()) {
2533 consumer,
"cannot call tile and fuse consumer with an empty loop nest");
2536 LoopLikeOpInterface outermostLoop = loops.front();
2540 SmallVector<OpOperand *> consumerFusableOperands;
2542 if (opOperand.get().getDefiningOp() == outermostLoop) {
2543 consumerFusableOperands.push_back(&opOperand);
2548 if (consumerFusableOperands.empty()) {
2549 return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
2550 SmallVector<OpOperand *>{},
2551 SmallVector<Operation *>{}};
2556 SmallVector<Operation *> candidateSlices;
2557 candidateSlices.reserve(consumerFusableOperands.size());
2558 for (OpOperand *opOperand : consumerFusableOperands) {
2559 std::optional<Operation *> slice =
2564 "couldnt find producing insert-slice like operation for operand");
2566 candidateSlices.push_back(slice.value());
2569 rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
2576FailureOr<SmallVector<scf::ForOp>>
2577mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
2578 TilingInterface op) {
2580 if (op->getNumResults() > 0) {
2582 op,
"unable to lower to loops operations with return values");
2585 SmallVector<Range> domain = op.getIterationDomain(rewriter);
2586 SmallVector<Value> ivs;
2587 SmallVector<scf::ForOp> loops;
2588 Location loc = op.getLoc();
2589 for (
auto loopRange : domain) {
2596 auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal,
2598 loops.push_back(loop);
2599 ivs.push_back(loop.getInductionVar());
2602 if (
failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
if copies could not be generated due to yet unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming copies and outgoing copies should be inserted. The output argument nBegin is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * (numThreads - 1) is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, const SetVector< unsigned > &reductionDims, ValueRange partialResults)
static std::optional< Operation * > getProducingInsertSliceLikeOp(OpResult result, ArrayRef< LoopLikeOpInterface > loops)
For a given result of the loop nest that is a tiled loop nest, return the insert slice-like op that i...
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop< scf::ForallOp >(scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Implementation of yieldTiledValuesAndReplaceLoop for scf.forall
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
std::function< LogicalResult( RewriterBase &rewriter, Location Loc, ValueRange ivs, ArrayRef< OpFoldResult > tileOffsets, ArrayRef< OpFoldResult > tileSizes, ValueRange outerDestinationTensors, SmallVector< Value > &tiledResults, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> GenerateTiledBodyFn
Typedef for function that implements the body of a tiled loop.
static LogicalResult checkTileSizes(TilingInterface op, scf::SCFTilingOptions::LoopType loopType, ReductionTilingStrategy reductionStrategy, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, ArrayRef< OpOperand * > consumerOpOperands, ArrayRef< Operation * > candidateSlices, MutableArrayRef< LoopLikeOpInterface > loops)
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingCustomOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn, const scf::SCFTilingOptions::GenerateLoopTerminatorFn &generateLoopTerminatorFn, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using custom loop operation.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using the loop construct specified in options.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.for operation.
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static SmallVector< tensor::InsertSliceOp > cloneAsInsertSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices)
static FailureOr< SmallVector< OpOperand * > > getUntiledConsumerOperandsFromSlices(RewriterBase &rewriter, ArrayRef< Operation * > sliceOps, MutableArrayRef< LoopLikeOpInterface > loops)
A utility to fetch an untiled consumer of tensor.insert_slice/tensor.parallel_insert_slice.
tensor::InsertSliceOp cloneAsInsertSlice< tensor::ParallelInsertSliceOp >(RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp)
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
std::function< LogicalResult( RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
Typedef for function that allows returning additional yielded values during yieldTiledValuesAndReplac...
tensor::InsertSliceOp cloneAsInsertSlice< tensor::InsertSliceOp >(RewriterBase &rewriter, tensor::InsertSliceOp insertSliceOp)
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes)
Function to return the bounds of the loops to be generated.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizesWithForAllOp(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult givenTileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop< scf::ForOp >(scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Implementation of yieldTiledValuesAndReplaceLoop for scf.for.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static SetVector< unsigned > getSanitizedReductionDims(ArrayRef< OpFoldResult > givenTileSizes, const scf::SCFTilingOptions &options)
Get the reduction dims that are tiled.
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.forall operation.
static std::optional< Operation * > getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result)
For a given result of a forallOp return the tensor.parallel_insert_slice op (or combining op) that is...
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, InsertSliceOpTy sliceOp)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static SmallVector< OpFoldResult > getSplitReductionIvs(RewriterBase &rewriter, Location loc, ReductionTilingStrategy reductionStrategy, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
For the case of ReductionTilingStrategy::PartialReductionOuterParallel the PartialReductionOpInterfac...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class represents an argument of a Block.
Block represents an ordered list of Operations.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
iterator_range< use_iterator > use_range
Operation * getOwner() const
Return the owner of this operand.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
FailureOr< TilingResult > replaceInsertSlicesWithTiledConsumer(OpBuilder &builder, ArrayRef< tensor::InsertSliceOp > sliceOps, ArrayRef< OpOperand * > consumerOperands)
Method to swap tensor.insert_slices with their consumers when the consumer implements the TilingInter...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Include the generated interface declarations.
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Container for result values of tiling.