28 #include "llvm/ADT/ScopeExit.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/Debug.h"
33 #define DEBUG_TYPE "tile-using-interface"
37 scf::SCFTilingOptions &
39 assert(!tileSizeComputationFunction && "tile sizes already set");
40 auto tileSizes = llvm::to_vector(ts);
47 scf::SCFTilingOptions &
49 assert(!numThreadsComputationFunction && "num tiles already set");
50 auto numThreads = llvm::to_vector(nt);
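// A minimal usage sketch (assumed, not part of this file): tile with fixed
// tile sizes using `scf.for` loops. `rewriter` and `op` (an op implementing
// TilingInterface) are assumed to exist; the sizes are illustrative.
//
//   scf::SCFTilingOptions tilingOptions;
//   tilingOptions.setTileSizes(
//       {rewriter.getIndexAttr(8), rewriter.getIndexAttr(16)});
//   tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
//   FailureOr<scf::SCFTilingResult> tiled =
//       scf::tileUsingSCF(rewriter, op, tilingOptions);
//   if (succeeded(tiled))
//     rewriter.replaceOp(op, tiled->replacements);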
61 size_t iterationDomainSize) {
63 if (filledVector.size() < iterationDomainSize) {
64 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
65 filledVector.append(range.begin(), range.end());
67 if (filledVector.size() > iterationDomainSize)
68 filledVector.resize(iterationDomainSize);
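// For example (illustrative): with interchangeVector = {1, 0} and a 3-d
// iteration domain this fills to {1, 0, 2}; with {2, 1, 0, 3} and a 3-d
// domain it is truncated to {2, 1, 0}.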
78 const scf::SCFTilingOptions &options) {
80 if (options.numThreadsComputationFunction &&
81 options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
83 loc, "number of threads can only be specified when loop type is "
84 "set to use `scf.forall`");
88 if (!options.interchangeVector.empty()) {
91 loc, "invalid interchange vector, not a permutation of the entire "
103 const scf::SCFTilingOptions &options) {
106 size_t numLoops = iterationDomain.size();
109 if (options.numThreadsComputationFunction) {
110 numThreads = options.numThreadsComputationFunction(rewriter, op);
111 numThreads.resize(numLoops, zero);
114 if (options.tileSizeComputationFunction) {
115 tileSizes = options.tileSizeComputationFunction(rewriter, op);
116 tileSizes.resize(numLoops, zero);
117 return {tileSizes, numThreads};
129 tileSizes.resize(numLoops, zero);
130 for (auto [index, range, nt] :
131 llvm::enumerate(iterationDomain, numThreads)) {
136 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
138 tileSizes.resize(numLoops, zero);
139 return {tileSizes, numThreads};
146 assert(options.tileSizeComputationFunction &&
147 "expected tile sizes to be specified");
148 tileSizes = options.tileSizeComputationFunction(rewriter, op);
149 tileSizes.resize(numLoops, zero);
151 return {tileSizes, numThreads};
156 scf::SCFTilingOptions::LoopType loopType,
160 auto iterators = op.getLoopIteratorTypes();
161 assert(iterators.size() == givenTileSizes.size() &&
162 "expected as many tile size values as number of loops");
163 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
164 "when specified, expected number of threads to use for each loop");
166 bool isParallelTiling = false;
167 for (auto [index, iterator, givenTileSize] :
168 llvm::enumerate(iterators, givenTileSizes)) {
170 isParallelTiling |= iterator == utils::IteratorType::parallel;
173 if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
177 if (!numThreads.empty()) {
178 if (std::optional<int64_t> constNumThreads =
180 if (constNumThreads.value() > 1 &&
181 iterator != utils::IteratorType::parallel) {
182 op.emitWarning() << "tiling is not thread safe at axis #" << index;
188 if (std::optional<int64_t> constTileSize =
190 if (constTileSize.value() > 0 &&
191 iterator != utils::IteratorType::parallel) {
192 op.emitWarning() << "tiling is not thread safe at axis #" << index;
199 if (isParallelTiling) {
200 return op->emitOpError("tiling parallel dimensions is not supported with "
201 "partial reduction tiling strategies");
211 const scf::SCFTilingOptions &options) {
213 for (auto dim : options.reductionDims) {
216 reductionDims.insert(dim);
218 return reductionDims;
232 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
241 if (ts && ts.value() == 1)
242 return givenTileSize;
246 return givenTileSize;
268 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
270 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
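// Worked example (illustrative): tileSize = 32, numThreads = 4, iterSize = 100
// gives 32 * 3 = 96 < 100, so the offset bounds-check can be omitted; with
// iterSize = 90 it cannot, since the last thread's tile would start out of range.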
281 int materializedLoopNum = 0;
282 for (auto [givenTileSize, loopRange] :
283 llvm::zip_equal(givenTileSizes, iterationDomain)) {
288 offsets.push_back(loopRange.offset);
289 sizes.push_back(loopRange.size);
293 Value iv = ivs[materializedLoopNum++];
295 offsets.push_back(offset);
298 sizes.push_back(size);
300 return {offsets, sizes};
309 for (auto [loopRange, givenTileSize] :
310 llvm::zip_equal(loopRanges, givenTileSizes)) {
314 lbs.push_back(loopRange.offset);
315 ubs.push_back(loopRange.size);
316 steps.push_back(givenTileSize);
318 return {lbs, ubs, steps};
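// For example (illustrative): a loop range with offset 0 and size 100 tiled
// by 32 yields lb = 0, ub = 100, step = 32, so the tile loop visits
// iv = 0, 32, 64, 96.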
372 if (newDestArgs.empty())
374 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
375 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
392 assert(!loopRanges.empty() && "unexpected empty loop ranges");
393 assert(loopRanges.size() == givenTileSizes.size() &&
394 "expected as many tile sizes as loop ranges");
398 std::tie(lbs, ubs, steps) =
409 ValueRange innerDestinationTensors(outerDestinationTensors);
410 for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
412 scf::ForOp::create(rewriter, loc, lb, ub, step, innerDestinationTensors,
415 loops.push_back(loop);
416 ivs.push_back(loop.getInductionVar());
418 innerDestinationTensors = loop.getRegionIterArgs();
425 std::tie(offsets, sizes) =
430 if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
431 innerDestinationTensors, tiledResults, resultOffsets,
434 loc, "failed to generate inner tile loop body");
439 assert(tiledResults.size() == innerDestinationTensors.size() &&
440 "Number of results of body should be equal to number of iter args");
444 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
445 llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
449 auto insertSlice = tensor::InsertSliceOp::create(
450 rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize,
452 yieldedValues.push_back(insertSlice);
454 scf::YieldOp::create(rewriter, loc, yieldedValues);
457 for (auto [outerLoop, innerLoop] :
461 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
462 scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
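// The generated nest has roughly the following shape (illustrative, two tiled
// loops, one destination-style result):
//
//   %r = scf.for %iv0 = %lb0 to %ub0 step %s0 iter_args(%d0 = %init) {
//     %0 = scf.for %iv1 = %lb1 to %ub1 step %s1 iter_args(%d1 = %d0) {
//       ... compute the tile ...
//       %ins = tensor.insert_slice %tile into %d1[...] [...] [1, 1]
//       scf.yield %ins
//     }
//     scf.yield %0
//   }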
479 if (numThreads.empty()) {
485 int materializedLoopNum = 0;
491 offsetExpr = d0 + d1 * s0;
492 residualTileSizeExpr = s1 - (d0 + d1 * s0);
494 for (auto [index, nt, givenTileSize, loopRange] :
495 llvm::enumerate(numThreads, givenTileSizes, iterationDomain)) {
500 offsets.push_back(loopRange.offset);
501 sizes.push_back(loopRange.size);
505 Value iv = ivs[materializedLoopNum++];
507 rewriter, loc, offsetExpr,
510 rewriter, loc, residualTileSizeExpr,
511 {loopRange.offset, nt, givenTileSize, loopRange.size});
517 {offset, loopRange.size});
521 {sizeMinusOffsetPerThread, givenTileSize});
537 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
540 offsets.push_back(offset);
541 sizes.push_back(size);
543 return {offsets, sizes};
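// Worked example (illustrative): domain [0, 100), tile size 32, 4 threads.
// Thread ids 0..3 get offsets 0, 32, 64, 96; threads 0-2 get size 32 and
// thread 3 gets the residual size 4, clamped to 0 if its offset were already
// past the end of the range.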
556static FailureOr<SmallVector<LoopLikeOpInterface>>
564 assert(!loopRanges.empty() && "unexpected empty loop ranges");
565 assert(loopRanges.size() == givenTileSizes.size() &&
566 "expected as many tile sizes as loop ranges");
569 std::optional<ArrayAttr> mappingAttr;
570 if (!mappingVector.empty())
573 scf::ForallOp forallOp;
574 bool useNumThreads = !numThreads.empty();
580 for (auto nt : numThreads) {
583 nonZeroNumThreads.push_back(nt);
585 forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads,
586 outerDestinationTensors, mappingAttr);
589 std::tie(lbs, ubs, steps) =
591 forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
592 outerDestinationTensors, mappingAttr);
594 loops.push_back(forallOp);
597 ValueRange innerDestinationTensors = forallOp.getRegionOutArgs();
603 rewriter, loc, ivs, loopRanges, givenTileSizes, numThreads);
607 if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
608 innerDestinationTensors, tiledResults, resultOffsets,
613 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
614 llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
619 tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue,
620 destinationTensor, resultOffset,
621 resultSize, resultStride);
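// The generated loop has roughly the following shape (illustrative, one
// parallel dimension, one destination-style result):
//
//   %r = scf.forall (%iv) in (%nt) shared_outs(%out = %init) {
//     ... compute the tile ...
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %tile into %out[...] [...] [1]
//     }
//   }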
636static FailureOr<SmallVector<LoopLikeOpInterface>>
640 const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn,
641 const scf::SCFTilingOptions::GenerateLoopTerminatorFn
642 &generateLoopTerminatorFn,
644 assert(!loopRanges.empty() && "unexpected empty loop ranges");
645 assert(loopRanges.size() == givenTileSizes.size() &&
646 "expected as many tile sizes as loop ranges");
647 assert(generateLoopHeaderFn && generateLoopTerminatorFn &&
648 "expected loop header/terminator generation function");
651 FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> loopHeaderInfo =
652 generateLoopHeaderFn(rewriter, loc, loopRanges, givenTileSizes,
653 outerDestinationTensors);
654 if (failed(loopHeaderInfo)) {
661 if (failed(tiledBodyFn(rewriter, loc, ivs, loopHeaderInfo->tileOffset,
662 loopHeaderInfo->tileSizes,
663 loopHeaderInfo->destinationTensors, tiledResults,
664 resultOffsets, resultSizes))) {
668 if (failed(generateLoopTerminatorFn(rewriter, loc, loopHeaderInfo->loops,
669 tiledResults, resultOffsets, resultSizes,
670 loopHeaderInfo->destinationTensors))) {
674 return loopHeaderInfo->loops;
698 llvm::map_to_vector(loopRanges, [](Range r) { return r.offset; });
700 llvm::map_to_vector(loopRanges, [](Range r) { return r.size; });
701 if (failed(tiledBodyFn(rewriter, loc, ValueRange{}, tileOffsets, tileSizes,
702 destinationTensors, tiledResults, resultOffsets,
708 if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
710 destinationTensors, tiledBodyFn);
712 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
714 rewriter, loc, loopRanges, givenTileSizes, numThreads,
715 options.mappingVector, destinationTensors, tiledBodyFn);
717 if (options.loopType == scf::SCFTilingOptions::LoopType::CustomOp) {
719 rewriter, loc, loopRanges, givenTileSizes, destinationTensors,
739 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
741 return op->emitOpError(
742 "PartialReductionOuterReduction tiling strategy is only supported for "
743 "operations implementing PartialReductionOpInterface");
748 AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
750 for (auto [index, domain, tileSize] :
751 llvm::enumerate(iterationDomain, givenTileSizes)) {
752 if (!numThreads.empty()) {
756 rewriter, op.getLoc(), sizeExpr,
757 {domain.size, domain.offset, domain.stride});
767 rewriter, op.getLoc(), sizeExpr,
768 {domain.size, domain.offset, domain.stride});
772 if (reductionStrategy ==
774 sizes[index] = tileSize;
778 assert(reductionStrategy ==
781 rewriter, op.getLoc(), sizeExpr,
782 {domain.size, domain.offset, domain.stride});
784 rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
786 return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
800 splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
805 if (reductionStrategy ==
807 for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
808 if (!numThreads.empty()) {
809 splitReductionIvs[index] = ivs[ivIndex++];
813 rewriter, loc, divExpr,
817 return splitReductionIvs;
820static FailureOr<TilingResult>
829 return op.getTiledImplementation(rewriter, offsets, sizes);
832 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
835 op, "PartialReductionOuterReduction tiling strategy is only "
836 "supported for operations "
837 "implementing PartialReductionOpInterface");
842 numThreads, givenTileSizes, reductionDims);
843 return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
844 regionIterArg, offsets, sizes,
845 reductionDims, splitReductionIvs);
859 return op.getResultTilePosition(rewriter, index, offsets, sizes,
860 resultOffset, resultSize);
862 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
865 op, "PartialReductionOuterReduction tiling strategy is only supported "
866 "for operations implementing PartialReductionOpInterface");
870 numThreads, givenTileSizes, reductionDims);
871 return redOp.getPartialResultTilePosition(
872 rewriter, index, reductionStrategy, offsets, sizes, reductionDims,
873 splitReductionIvs, resultOffset, resultSize);
876static FailureOr<MergeResult>
882 "expected merge to be called for only partial reduction cases");
884 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
887 op, "PartialReductionOuterReduction tiling strategy is only "
888 "supported for operations "
889 "implementing PartialReductionOpInterface");
891 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
902 template <typename LoopType>
903 FailureOr<LoopLikeOpInterface>
919 auto inits = llvm::to_vector(loopOp.getInitArgs());
920 inits.append(newInitOperands.begin(), newInitOperands.end());
921 auto newLoop = scf::ForOp::create(
922 rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
924 loopOp.getUnsignedCmp());
927 Block *loopBody = loopOp.getBody();
928 Block *newLoopBody = newLoop.getBody();
930 loopBody, newLoopBody,
933 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
939 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
940 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
941 newRegionIterArgs, tiledValues, resultOffsets,
948 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
949 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
953 Value insert = tensor::InsertSliceOp::create(
954 rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset,
955 resultSize, resultStride);
956 newYieldValues.push_back(insert);
961 newLoop->getResults().take_front(loopOp.getNumResults()));
962 return cast<LoopLikeOpInterface>(newLoop.getOperation());
973 auto inits = llvm::to_vector(loopOp.getOutputs());
974 inits.append(newInitOperands.begin(), newInitOperands.end());
975 auto newLoop = scf::ForallOp::create(
976 rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
977 loopOp.getMixedStep(), inits, loopOp.getMapping(),
981 Block *loopBody = loopOp.getBody();
982 Block *newLoopBody = newLoop.getBody();
984 loopBody, newLoopBody,
987 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
992 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
993 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
994 regionIterArgs, tiledValues, resultOffsets,
998 "failed to get yielded tiled values");
1004 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
1005 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
1008 tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(),
1009 tiledValue, iterArg, resultOffset,
1010 resultSize, resultStride);
1014 newLoop->getResults().take_front(loopOp.getNumResults()));
1015 return cast<LoopLikeOpInterface>(newLoop.getOperation());
1022 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
1025 loopLikeOp.getOperation())
1026 .Case<scf::ForOp, scf::ForallOp>(
1027 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1029 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
1031 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1050 for (auto &loop : loops.drop_back()) {
1054 auto forLoop = cast<scf::ForOp>(loop.getOperation());
1058 newInits.append(newInitValues.begin(), newInitValues.end());
1059 auto newLoop = scf::ForOp::create(
1060 rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
1061 forLoop.getUpperBound(), forLoop.getStep(), newInits,
1063 forLoop.getUnsignedCmp());
1067 sourceBlockArgs.push_back(newLoop.getInductionVar());
1068 auto newRegionIterArgs = newLoop.getRegionIterArgs();
1069 sourceBlockArgs.append(
1070 newRegionIterArgs.begin(),
1071 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
1072 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
1074 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
1076 ivs.push_back(newLoop.getInductionVar());
1077 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
1081 LoopLikeOpInterface innerMostLoop = loops.back();
1082 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
1084 getNewTiledYieldsFn);
1086 if (failed(newInnerMostLoop))
1087 return innerMostLoop.emitOpError("failed to return additional yields");
1088 loops.back() = newInnerMostLoop.value();
1092 for (auto [outerLoop, innerLoop] :
1093 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1095 auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
1096 auto outerLoopYield =
1097 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
1099 llvm::to_vector(outerLoopYield.getOperands());
1101 innerLoop->getResults().take_back(newInitValues.size());
1102 newYields.append(additionalYields.begin(), additionalYields.end());
1111 FailureOr<scf::SCFTilingResult>
1112 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
1113 const scf::SCFTilingOptions &options) {
1126 std::tie(givenTileSizes, numThreads) =
1132 givenTileSizes, numThreads))) {
1143 if (!options.interchangeVector.empty()) {
1145 iterationDomain.size());
1147 "expected interchange vector to be a permutation");
1151 if (!numThreads.empty())
1155 FailureOr<TilingResult> tilingResult;
1169 if (!interchangeVector.empty()) {
1178 auto clonedOp = cast<TilingInterface>(
1185 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1195 regionIterArgs, tileOffsetsVec, tileSizesVec,
1196 ivs, numThreads, givenTileSizes, reductionDims);
1197 if (failed(tilingResult)) {
1199 return op.emitOpError("failed to tile operation");
1207 for (auto [index, tiledValue] :
1208 llvm::enumerate(tilingResult->tiledValues)) {
1209 tiledResults.push_back(tiledValue);
1212 rewriter, options.reductionStrategy, index, tiledValue, op,
1213 tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes,
1214 reductionDims, resultOffset, resultSize))) {
1215 for (auto op : tilingResult->tiledOps) {
1219 op, "failed to get slice of result produced");
1221 resultOffsets.emplace_back(std::move(resultOffset));
1222 resultSizes.emplace_back(std::move(resultSize));
1230 rewriter, op, options.reductionStrategy, iterationDomain, numThreads,
1231 givenTileSizes, reductionDims);
1232 if (failed(maybeInits)) {
1234 op, "unable to create initial tensors for tiling");
1242 rewriter, op.getLoc(), options, iterationDomain, givenTileSizes,
1243 numThreads, initTensors, innerYieldTiledValuesFn);
1245 return op.emitOpError("failed to generate tiling loops");
1246 assert(succeeded(tilingResult) &&
1247 "expected tiling result to be computed after loop generation");
1248 std::swap(loops, loopsOr.value());
1251 if (loops.empty()) {
1254 return scf::SCFTilingResult{tilingResult->tiledOps,
1257 tilingResult->tiledValues,
1258 tilingResult->generatedSlices,
1262 auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1267 return scf::SCFTilingResult{
1268 tilingResult->tiledOps, initTensors, loops, loopResults,
1269 tilingResult->generatedSlices, {}};
1274 rewriter, op, options.reductionStrategy, reductionDims, loopResults);
1275 if (failed(mergeResult)) {
1277 op, "Failed to merge partial results from tiling");
1279 return scf::SCFTilingResult{tilingResult->tiledOps,
1282 mergeResult->replacements,
1283 tilingResult->generatedSlices,
1284 mergeResult->mergeOps};
1287 FailureOr<scf::SCFTilingResult>
1289 PartialReductionOpInterface op,
1291 scf::SCFTilingOptions options;
1292 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
1293 options.setReductionTilingStrategy(
1295 options.setTileSizes(tileSize);
1297 for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
1298 if (iteratorType == utils::IteratorType::reduction)
1299 reductionDims.push_back(index);
1300 options.setReductionDims(reductionDims);
1301 return tileUsingSCF(b, op, options);
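// Conceptually (illustrative): tiling a 1-d reduction of size %n with tile
// size 64 under PartialReductionOuterReduction produces a 64-element partial
// result tensor, an scf.for loop that accumulates tile-wise partial
// reductions into it, and a final merge op that reduces the 64 partial values
// into the original result.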
1314 static std::tuple<OpResult, std::optional<OpOperand *>>
1317 std::optional<OpOperand *> destinationIterArg;
1318 assert(!loops.empty() && "expected non empty loops container");
1319 auto loopIt = loops.rbegin();
1320 while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
1321 auto iterArg = cast<BlockArgument>(source->get());
1322 auto loop = *loopIt;
1323 if (iterArg.getOwner()->getParentOp() != loop)
1325 source = loop.getTiedLoopInit(iterArg);
1328 if (loopIt == loops.rend())
1329 destinationIterArg = source;
1330 return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1335 std::optional<scf::SCFFuseProducerOfSliceResult>
1336 mlir::scf::tileAndFuseProducerOfSlice(
1337 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1341 auto [fusableProducer, destinationInitArg] =
1344 if (!fusableProducer)
1345 return std::nullopt;
1346 unsigned resultNumber = fusableProducer.getResultNumber();
1354 Operation *fusableProducerOp = fusableProducer.getOwner();
1355 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1357 rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1358 origDestinationTensors)))
1359 return std::nullopt;
1361 clonedOpDestinationTensors = origDestinationTensors;
1362 if (destinationInitArg &&
1363 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1367 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1371 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1376 llvm::to_vector(candidateSliceOp->getOperands());
1377 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1378 tensor::ExtractSliceOp clonedCandidateSliceOp =
1380 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1383 FailureOr<TilingResult> tileAndFuseResult =
1385 rewriter, clonedCandidateSliceOp,
1386 clonedProducerOp->getResult(resultNumber));
1387 if (failed(tileAndFuseResult))
1388 return std::nullopt;
1392 tileAndFuseResult->tiledValues[0]);
1393 rewriter.eraseOp(clonedCandidateSliceOp);
1394 rewriter.eraseOp(clonedProducerOp);
1439 if (destinationInitArg &&
1440 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1442 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1443 .set(origDestinationTensors[resultNumber]);
1445 return scf::SCFFuseProducerOfSliceResult{
1446 fusableProducer, tileAndFuseResult->tiledValues[0],
1447 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
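// For example (illustrative): given, inside an already tiled loop nest,
//
//   %p = "producer"(...) : tensor<128xf32>
//   %s = tensor.extract_slice %p[%off] [%sz] [1]
//   %c = "consumer"(%s) ...
//
// the extract_slice is replaced by a clone of the producer that computes only
// the requested tile, so the producer is fused into the consumer's loop body.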
1451FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1452 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1453 scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1459 Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1460 *tiledOwner = fusedProducerInfo.tiledOps[0];
1465 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1467 : llvm::to_vector(yieldResultNumber);
1469 for (const auto &resultNumber : initNumberList) {
1471 rewriter, loc, originalOwner->getResult(resultNumber));
1472 if (succeeded(initValue)) {
1473 initValueList.push_back(initValue.value());
1489 sliceSizes = sliceOp.getMixedSizes();
1492 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
1495 unsigned sliceResultNumber =
1496 fusedProducerInfo.origProducer.getResultNumber();
1498 auto tilableOp = cast<TilingInterface>(originalOwner);
1502 if (tilableOp->getNumResults() > 1 &&
1503 failed(tilableOp.getIterationDomainTileFromResultTile(
1504 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1505 iterDomainOffset, iterDomainSizes))) {
1520 for (const auto &resultNumber : initNumberList) {
1521 if (resultNumber == sliceResultNumber) {
1522 offsetList.push_back(sliceOffset);
1523 sizesList.push_back(sliceSizes);
1525 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1528 if (failed(tilableOp.getResultTilePosition(
1529 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1533 offsetList.push_back(offset);
1534 sizesList.push_back(sizes);
1540 if (auto tiledDestStyleOp =
1541 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1543 for (const auto &&[index, newRegionArg] :
1544 llvm::enumerate(newRegionIterArgs)) {
1545 auto destSlice = tensor::ExtractSliceOp::create(
1546 rewriter, loc, newRegionArg, offsetList[index], sizesList[index],
1549 generatedSlices.push_back(destSlice);
1550 unsigned resultNumber = initNumberList[index];
1552 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1561 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1562 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1563 tiledOffset.emplace_back(offsetList[index]);
1564 tiledSizes.emplace_back(sizesList[index]);
1570 newYieldValuesFn))) {
1573 return generatedSlices;
1587 explicit SliceTrackingListener(
1588 std::optional<FrozenRewritePatternSet> patterns);
1589 SliceTrackingListener() = default;
1595 LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1598 void notifyOperationInserted(Operation *op,
1599 OpBuilder::InsertPoint previous) override;
1602 void removeOp(Operation *op);
1605 void notifyOperationErased(Operation *op) override;
1612 std::deque<tensor::ExtractSliceOp> worklist;
1617 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1620SliceTrackingListener::SliceTrackingListener(
1621 std::optional<FrozenRewritePatternSet> p) {
1626 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1627 for (Operation *op : ops) {
1628 if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1629 worklist.push_back(slice);
1637 GreedyRewriteConfig().setListener(this).setStrictness(
1638 GreedyRewriteStrictness::ExistingAndNewOps));
1641void SliceTrackingListener::notifyOperationInserted(
1642 Operation *op, OpBuilder::InsertPoint previous) {
1643 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1646 worklist.push_back(slice);
1652void SliceTrackingListener::removeOp(Operation *op) {
1653 if (!isa<tensor::ExtractSliceOp>(op))
1655 auto iter = worklist.begin();
1656 while (iter != worklist.end()) {
1661 if (iter == worklist.end())
1664 worklist.erase(iter);
1667void SliceTrackingListener::notifyOperationErased(Operation *op) {
1671void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1683 class ReplacementListener : public RewriterBase::ForwardingListener {
1685 ReplacementListener(DenseMap<Value, Value> &replacements,
1686 OpBuilder::Listener *listener)
1687 : ForwardingListener(listener), replacements(replacements) {}
1689 void updateReplacementValues(ValueRange origValues,
1693 for (auto &[key, val] : replacements) {
1694 for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1702 void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1703 ForwardingListener::notifyOperationReplaced(op, newOp);
1707 void notifyOperationReplaced(Operation *op, ValueRange values) override {
1708 ForwardingListener::notifyOperationReplaced(op, values);
1709 updateReplacementValues(op->getResults(), values);
1713 DenseMap<Value, Value> &replacements;
1719 FailureOr<scf::SCFTileAndFuseResult>
1720 mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1721 RewriterBase &rewriter, TilingInterface consumer,
1722 const scf::SCFTileAndFuseOptions &options) {
1725 if (!consumer->getNumResults()) {
1727 consumer, "invalid pattern for op with no results");
1733 FailureOr<scf::SCFTilingResult> tilingResult =
1734 tileUsingSCF(rewriter, consumer, options.tilingOptions);
1736 if (failed(tilingResult))
1738 tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1740 DenseMap<Value, Value> replacements;
1742 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1747 auto &loops = tilingResult->loops;
1748 if (loops.empty()) {
1749 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1756 OpBuilder::Listener *previousListener = rewriter.getListener();
1757 auto resetListener =
1758 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1759 ReplacementListener replaceListener(replacements, previousListener);
1769 struct WorklistItem {
1770 tensor::ExtractSliceOp candidateSlice;
1771 SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1774 SliceTrackingListener sliceTracker =
1775 SliceTrackingListener(options.cleanupPatterns);
1778 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1781 OpBuilder::InsertionGuard g(rewriter);
1782 while (!sliceTracker.worklist.empty()) {
1783 auto candidateSlice = sliceTracker.worklist.front();
1784 sliceTracker.worklist.pop_front();
1786 auto [fusableProducer, destinationInitArg] =
1789 if (!fusableProducer)
1792 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1793 options.fusionControlFn(candidateSlice, fusableProducer,
1794 destinationInitArg.has_value());
1795 if (!controlFnResult)
1798 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1803 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1804 tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1809 SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1811 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1816 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1817 FailureOr<SmallVector<Operation *>> newSlices =
1818 yieldReplacementForFusedProducer(rewriter,
1819 worklistItem.candidateSlice,
1820 fusedResult.value(), loops);
1823 fusableProducerOp, "failed to get replacement value for this "
1824 "operation from within the tiled loop");
1826 worklistCandidates.append(newSlices.value());
1827 for (auto [index, result] :
1828 llvm::enumerate(fusableProducerOp->getResults())) {
1829 replacements[result] = loops.front()->getResult(
1830 loops.front()->getNumResults() -
1834 if (Operation *tiledAndFusedOp =
1835 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1836 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1837 tiledAndFusedOps.insert(tiledAndFusedOp);
1840 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1845 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1859 if (!llvm::hasSingleElement(uses)) {
1860 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1863 OpOperand &operandUse = (*uses.begin());
1865 if (!isa<scf::YieldOp>(userOp)) {
1866 LLVM_DEBUG(llvm::dbgs()
1867 << "Expected scf.yield to be the only user, but got -> "
1872 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1873 "be in the same block\n");
1882 if (!isa<LoopLikeOpInterface>(loopOp))
1901 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1902 userOp = userOp->getParentOfType<scf::InParallelOp>();
1904 if (loopOp->getBlock() != userOp->getBlock())
1908 firstUserOfLoop = userOp;
1910 return firstUserOfLoop;
1951static FailureOr<llvm::SetVector<Operation *>>
1953 bool reorderOperations) {
1955 if (failed(firstUserOfLoop))
1961 options.omitBlockArguments = true;
1962 bool includeLoopOp = false;
1965 includeLoopOp = true;
1975 assert(result.succeeded() && "expected a backward slice");
1979 if (!slice.empty()) {
1989 if (includeLoopOp || !reorderOperations)
2001 unsigned resultNumber) {
2002 if (!isa<LoopLikeOpInterface>(loopOp))
2007 Operation *consumerOp = opOperand.getOwner();
2009 if (!isa<TilingInterface>(consumerOp) ||
2010 !isa<DestinationStyleOpInterface>(consumerOp)) {
2017 if (loopBlock != consumerOp->getBlock())
2024 FailureOr<llvm::SetVector<Operation *>> slice =
2030 if (!slice->empty()) {
2033 assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
2034 for (auto op : *slice) {
2049static FailureOr<OpOperand *>
2051 tensor::InsertSliceOp candidateSliceOp,
2053 assert(!loops.empty() && "unexpected loops to be empty");
2056 if (containingOp != loops.back()) {
2059 "expected slice to be within body of inner-most loop");
2065 candidateSliceOp, "expected passed loops to be perfectly nested.");
2070 Value sliceResult = candidateSliceOp.getResult();
2076 scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
2083static FailureOr<OpOperand *>
2085 tensor::ParallelInsertSliceOp candidateSliceOp,
2087 assert(!loops.empty() && "unexpected loops to be empty");
2089 if (loops.size() != 1) {
2091 candidateSliceOp, "expected single surrounding scf.forall");
2093 auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
2096 candidateSliceOp, "expected single surrounding scf.forall");
2100 Value sliceDest = candidateSliceOp.getDest();
2101 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
2104 if (iterArg.getOwner()->getParentOp() != forallOp)
2107 unsigned resultNumber =
2108 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
2119 assert(!loops.empty() && "unexpected empty loops");
2120 assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
2122 for (auto sliceOp : sliceOps) {
2123 FailureOr<OpOperand *> fusedOperand =
2125 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2132 if (failed(fusedOperand)) {
2135 if (!fusedOperands.empty() &&
2136 fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
2138 fusedOperand.value()->getOwner(),
2139 "all candidate slices must be to the same consumer");
2141 fusedOperands.push_back(fusedOperand.value());
2143 return fusedOperands;
2146 template <typename InsertSliceOpTy>
2148 InsertSliceOpTy sliceOp);
2151 tensor::InsertSliceOp
2153 tensor::InsertSliceOp insertSliceOp) {
2154 return cast<tensor::InsertSliceOp>(
2155 rewriter.clone(*insertSliceOp.getOperation()));
2160 RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
2161 return tensor::InsertSliceOp::create(
2162 rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(),
2163 insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
2164 insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2167static SmallVector<tensor::InsertSliceOp>
2170 assert(!candidateSlices.empty() &&
2171 "unexpected empty list of slices to clone");
2173 for (auto sliceOp : candidateSlices) {
2175 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2178 clonedSlices.push_back(clonedOp);
2181 .DefaultUnreachable(
2182 "unexpected slice type while cloning as insert slice");
2184 return clonedSlices;
2187static FailureOr<scf::SCFFuseConsumerOfSliceResult>
2192 assert(!loops.empty() && "expected loops to be not empty");
2197 loops.front(), "the first user of loop should not dominate any define "
2198 "of consumer operand(s)");
2201 LoopLikeOpInterface outerMostLoop = loops.front();
2202 LoopLikeOpInterface innerMostLoop = loops.back();
2206 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2209 "consumer op is not DPS operation");
2210 if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
2211 return dstOp.isDpsInit(opOperand);
2215 "consumer op taking the result of scf.for as init is not supported");
2222 if (failed(firstUserOfLoop)) {
2224 outerMostLoop, "could not find the first user of outer most loop");
2226 rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2233 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
2234 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2244 auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2246 llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
2250 llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2251 return &clonedConsumerOp->getOpOperand(operandNum);
2257 for (auto [operandToReplace, clonedSliceOp] :
2258 llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
2259 operandToReplace->set(clonedSliceOp.getResult());
2265 FailureOr<TilingResult> tileAndFuseResult =
2267 clonedOpFusedOperandsList);
2268 if (failed(tileAndFuseResult)) {
2272 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2273 for (auto [operandNum, clonedSliceOp] :
2274 llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
2276 clonedSliceOp.getSource());
2290 for (auto candidateSliceOp : clonedInsertSlices) {
2298 candidateSliceOp, "containingOp's result yield with stride");
2301 allOffsets.emplace_back(std::move(offsets));
2302 allSizes.emplace_back(std::move(sizes));
2312 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
2313 rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
2314 iterDomainSizes))) {
2317 "can't get iter domain position from input position");
2323 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2325 totalNumResultsOfConsumer);
2327 totalNumResultsOfConsumer);
2328 for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2329 if (failed(tiledConsumerOp.getResultTilePosition(
2330 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2331 resultOffsets[idx], resultSizes[idx]))) {
2334 "can't get result domain position from iter domain position");
2340 if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2341 tiledConsumerOp.getOperation())) {
2343 for (const auto &&[index, newRegionArg] :
2344 llvm::enumerate(newRegionIterArgs)) {
2345 auto destSlice = tensor::ExtractSliceOp::create(
2346 rewriter, loc, newRegionArg, resultOffsets[index],
2352 auto dstNumber = index;
2354 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2364 llvm::enumerate(tiledConsumerOp->getResults())) {
2365 tiledResult.push_back(result);
2366 tiledOffset.emplace_back(resultOffsets[index]);
2367 tiledSizes.emplace_back(resultSizes[index]);
2373 newYieldValuesFn))) {
2375 "unable to add new inits to nest loop");
2381 for (auto &&[oldResult, newResult] :
2383 loops.front()->getResults().take_back(newInits.size()))) {
2388 rewriter.eraseOp(clonedConsumerOp);
2391 llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2392 return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
2394 auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
2395 return scf::SCFFuseConsumerOfSliceResult{
2396 std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
2397 std::move(tileAndFuseResult->tiledOps)};
2402FailureOr<scf::SCFFuseConsumerOfSliceResult>
2403mlir::scf::tileAndFuseConsumerOfSlices(
2404 RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2405 MutableArrayRef<LoopLikeOpInterface> loops) {
2406 if (candidateSlices.empty()) {
2409 "no candidate slices provided for consumer fusion");
2413 if (loops.empty()) {
2415 candidateSlices.front(),
2416 "cannot call tile and fuse consumer with an empty loop nest");
2419 if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2420 llvm::all_of(candidateSlices,
2421 llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2423 candidateSlices.front(),
2424 "candidates slices need to be all `tensor.extract_slice`s or "
2425 "`tensor.parallel_insert_slice`s");
2430 FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
2432 if (failed(maybeConsumerOpOperands)) {
2434 "could not fetch consumer to fuse");
2436 Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();
2439 maybeConsumerOpOperands.value(),
2440 candidateSlices, loops);
2446static std::optional<Operation *>
2448 if (result.getOwner() != forallOp)
2449 return std::nullopt;
2454 if (combiningOps.size() != 1)
2455 return std::nullopt;
2456 return combiningOps[0];
2461static std::optional<Operation *>
2464 assert(!loops.empty() && "Expected loops to be not empty");
2465 LoopLikeOpInterface outerMostLoop = loops.front();
2466 if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) {
2467 assert(loops.size() == 1 &&
2468 "expected only a single loop when tiling using scf.forall");
2474 while (loops.size() != 1) {
2475 LoopLikeOpInterface loop = loops.front();
2476 if (result.getOwner() != loop)
2477 return std::nullopt;
2478 auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
2480 return std::nullopt;
2481 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2482 auto innerForResult =
2483 dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
2484 if (!innerForResult)
2485 return std::nullopt;
2487 loops = loops.drop_front();
2489 LoopLikeOpInterface loop = loops.front();
2490 if (result.getOwner() != loop)
2491 return std::nullopt;
2492 auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
2494 return std::nullopt;
2495 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2496 auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
2497 .getDefiningOp<tensor::InsertSliceOp>();
2499 return std::nullopt;
2500 return insertSliceOp;
2503FailureOr<scf::SCFFuseConsumerOfSliceResult>
2504mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
2505 MutableArrayRef<LoopLikeOpInterface> loops) {
2506 if (!isa<TilingInterface>(consumer)) {
2508 consumer, "unhandled consumer that does not implement TilingInterface");
2513 if (loops.empty()) {
2515 consumer, "cannot call tile and fuse consumer with an empty loop nest");
2518 LoopLikeOpInterface outermostLoop = loops.front();
2522 SmallVector<OpOperand *> consumerFusableOperands;
2524 if (opOperand.get().getDefiningOp() == outermostLoop) {
2525 consumerFusableOperands.push_back(&opOperand);
2530 if (consumerFusableOperands.empty()) {
2531 return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
2532 SmallVector<OpOperand *>{},
2533 SmallVector<Operation *>{}};
2538 SmallVector<Operation *> candidateSlices;
2539 candidateSlices.reserve(consumerFusableOperands.size());
2540 for (OpOperand *opOperand : consumerFusableOperands) {
2541 std::optional<Operation *> slice =
2546 "couldnt find producing insert-slice like operation for operand");
2548 candidateSlices.push_back(slice.value());
2551 rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
2558FailureOr<SmallVector<scf::ForOp>>
2559mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
2560 TilingInterface op) {
2562 if (op->getNumResults() > 0) {
2564 op, "unable to lower to loops operations with return values");
2567 SmallVector<Range> domain = op.getIterationDomain(rewriter);
2568 SmallVector<Value> ivs;
2569 SmallVector<scf::ForOp> loops;
2570 Location loc = op.getLoc();
2571 for (auto loopRange : domain) {
2578 auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal,
2580 loops.push_back(loop);
2581 ivs.push_back(loop.getInductionVar());
2584 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, const SetVector< unsigned > &reductionDims, ValueRange partialResults)
static std::optional< Operation * > getProducingInsertSliceLikeOp(OpResult result, ArrayRef< LoopLikeOpInterface > loops)
For a given result of the loop nest that is a tiled loop nest, return the insert slice-like op that i...
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop< scf::ForallOp >(scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Implementation of yieldTiledValuesAndReplaceLoop for scf.forall
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
std::function< LogicalResult( RewriterBase &rewriter, Location Loc, ValueRange ivs, ArrayRef< OpFoldResult > tileOffsets, ArrayRef< OpFoldResult > tileSizes, ValueRange outerDestinationTensors, SmallVector< Value > &tiledResults, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> GenerateTiledBodyFn
Typedef for function that implements the body of a tiled loop.
static LogicalResult checkTileSizes(TilingInterface op, scf::SCFTilingOptions::LoopType loopType, ReductionTilingStrategy reductionStrategy, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, ArrayRef< OpOperand * > consumerOpOperands, ArrayRef< Operation * > candidateSlices, MutableArrayRef< LoopLikeOpInterface > loops)
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingCustomOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn, const scf::SCFTilingOptions::GenerateLoopTerminatorFn &generateLoopTerminatorFn, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using custom loop operation.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using the loop construct specifed in options.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.for operation.
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static SmallVector< tensor::InsertSliceOp > cloneAsInsertSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices)
static FailureOr< SmallVector< OpOperand * > > getUntiledConsumerOperandsFromSlices(RewriterBase &rewriter, ArrayRef< Operation * > sliceOps, MutableArrayRef< LoopLikeOpInterface > loops)
A utility to fetch an untiled consumer of tensor.insert_slice/tensor.parallel_insert_slice.
tensor::InsertSliceOp cloneAsInsertSlice< tensor::ParallelInsertSliceOp >(RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp)
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
An utility to get the first user of the given loopOp.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
std::function< LogicalResult( RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
Typedef for function that allows returning additional yielded values during yieldTiledValuesAndReplac...
tensor::InsertSliceOp cloneAsInsertSlice< tensor::InsertSliceOp >(RewriterBase &rewriter, tensor::InsertSliceOp insertSliceOp)
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes)
Function to return the bounds of the loops to be generated.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizesWithForAllOp(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult givenTileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop< scf::ForOp >(scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Implementation of yieldTiledValuesAndReplaceLoop for scf.for.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static SetVector< unsigned > getSanitizedReductionDims(ArrayRef< OpFoldResult > givenTileSizes, const scf::SCFTilingOptions &options)
Get the reduction dims that are tiled.
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.forall operation.
static std::optional< Operation * > getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result)
For a given result of a forallOp return the tensor.parallel_insert_slice op (or combining op) that is...
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, InsertSliceOpTy sliceOp)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static SmallVector< OpFoldResult > getSplitReductionIvs(RewriterBase &rewriter, Location loc, ReductionTilingStrategy reductionStrategy, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
For the case of ReductionTilingStrategy::PartialReductionOuterParallel the PartialReductionOpInterfac...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...