28 #include "llvm/ADT/ScopeExit.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/Debug.h"
33 #define DEBUG_TYPE "tile-using-interface"
40 auto tileSizes = llvm::to_vector(ts);
49 assert(!numThreadsComputationFunction && "num tiles already set");
50 auto numThreads = llvm::to_vector(nt);
61 size_t iterationDomainSize) {
63 if (filledVector.size() < iterationDomainSize) {
64 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
65 filledVector.append(range.begin(), range.end());
67 if (filledVector.size() > iterationDomainSize)
68 filledVector.resize(iterationDomainSize);
80   if (options.numThreadsComputationFunction &&
83         loc, "number of threads can only be specified when loop type is "
84              "set to use `scf.forall`");
88   if (!options.interchangeVector.empty()) {
91         loc, "invalid interchange vector, not a permutation of the entire "
106 size_t numLoops = iterationDomain.size();
109   if (options.numThreadsComputationFunction) {
110     numThreads = options.numThreadsComputationFunction(rewriter, op);
111     numThreads.resize(numLoops, zero);
114   if (options.tileSizeComputationFunction) {
115     tileSizes = options.tileSizeComputationFunction(rewriter, op);
116     tileSizes.resize(numLoops, zero);
117     return {tileSizes, numThreads};
129 tileSizes.resize(numLoops, zero);
130     for (auto [index, range, nt] :
136 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
138 tileSizes.resize(numLoops, zero);
139 return {tileSizes, numThreads};
146   assert(options.tileSizeComputationFunction &&
147          "expected tile sizes to be specified");
148   tileSizes = options.tileSizeComputationFunction(rewriter, op);
149 tileSizes.resize(numLoops, zero);
151 return {tileSizes, numThreads};
160 auto iterators = op.getLoopIteratorTypes();
161 assert(iterators.size() == givenTileSizes.size() &&
162 "expected as many tile size values as number of loops");
163 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
164 "when specified, expected number of threads to use for each loop");
166   bool isParallelTiling = false;
167   for (auto [index, iterator, givenTileSize] :
170 isParallelTiling |= iterator == utils::IteratorType::parallel;
177 if (!numThreads.empty()) {
178 if (std::optional<int64_t> constNumThreads =
180 if (constNumThreads.value() > 1 &&
181 iterator != utils::IteratorType::parallel) {
182         op.emitWarning() << "tiling is not thread safe at axis #" << index;
188 if (std::optional<int64_t> constTileSize =
190 if (constTileSize.value() > 0 &&
191 iterator != utils::IteratorType::parallel) {
192         op.emitWarning() << "tiling is not thread safe at axis #" << index;
199 if (isParallelTiling) {
200     return op->emitOpError("tiling parallel dimensions is not supported with "
201                            "partial reduction tiling strategies");
213   for (auto dim : options.reductionDims) {
216 reductionDims.insert(dim);
218 return reductionDims;
232 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
241 if (ts && ts.value() == 1)
242 return givenTileSize;
246 return givenTileSize;
268 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
270 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
281 int materializedLoopNum = 0;
282   for (auto [givenTileSize, loopRange] :
283 llvm::zip_equal(givenTileSizes, iterationDomain)) {
288 offsets.push_back(loopRange.offset);
289 sizes.push_back(loopRange.size);
293 Value iv = ivs[materializedLoopNum++];
295 offsets.push_back(offset);
298 sizes.push_back(size);
300 return {offsets, sizes};
309   for (auto [loopRange, givenTileSize] :
310 llvm::zip_equal(loopRanges, givenTileSizes)) {
314 lbs.push_back(loopRange.offset);
315 ubs.push_back(loopRange.size);
316 steps.push_back(givenTileSize);
318 return {lbs, ubs, steps};
372 if (newDestArgs.empty())
374   if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
375 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
392   assert(!loopRanges.empty() && "unexpected empty loop ranges");
393 assert(loopRanges.size() == givenTileSizes.size() &&
394 "expected as many tile sizes as loop ranges");
398 std::tie(lbs, ubs, steps) =
409 ValueRange innerDestinationTensors(outerDestinationTensors);
410   for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
412 scf::ForOp::create(rewriter, loc, lb, ub, step, innerDestinationTensors,
415 loops.push_back(loop);
416 ivs.push_back(loop.getInductionVar());
418 innerDestinationTensors = loop.getRegionIterArgs();
425 std::tie(offsets, sizes) =
430     if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
431                            innerDestinationTensors, tiledResults, resultOffsets,
434           loc, "failed to generate inner tile loop body");
439 assert(tiledResults.size() == innerDestinationTensors.size() &&
440 "Number of results of body should be equal to number of iter args");
444     for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
445 llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
449 auto insertSlice = tensor::InsertSliceOp::create(
450 rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize,
452 yieldedValues.push_back(insertSlice);
454 scf::YieldOp::create(rewriter, loc, yieldedValues);
457   for (auto [outerLoop, innerLoop] :
461 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
462 scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
479 if (numThreads.empty()) {
485 int materializedLoopNum = 0;
491 offsetExpr = d0 + d1 * s0;
492 residualTileSizeExpr = s1 - (d0 + d1 * s0);
494   for (auto [index, nt, givenTileSize, loopRange] :
500 offsets.push_back(loopRange.offset);
501 sizes.push_back(loopRange.size);
505 Value iv = ivs[materializedLoopNum++];
507 rewriter, loc, offsetExpr,
510 rewriter, loc, residualTileSizeExpr,
511 {loopRange.offset, nt, givenTileSize, loopRange.size});
517 {offset, loopRange.size});
521 {sizeMinusOffsetPerThread, givenTileSize});
537           rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
540 offsets.push_back(offset);
541 sizes.push_back(size);
543 return {offsets, sizes};
556 static FailureOr<SmallVector<LoopLikeOpInterface>>
564   assert(!loopRanges.empty() && "unexpected empty loop ranges");
565 assert(loopRanges.size() == givenTileSizes.size() &&
566 "expected as many tile sizes as loop ranges");
569 std::optional<ArrayAttr> mappingAttr;
570 if (!mappingVector.empty())
573 scf::ForallOp forallOp;
574 bool useNumThreads = !numThreads.empty();
580     for (auto nt : numThreads) {
583 nonZeroNumThreads.push_back(nt);
585 forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads,
586 outerDestinationTensors, mappingAttr);
589 std::tie(lbs, ubs, steps) =
591 forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
592 outerDestinationTensors, mappingAttr);
594 loops.push_back(forallOp);
597 ValueRange innerDestinationTensors = forallOp.getRegionOutArgs();
603 rewriter, loc, ivs, loopRanges, givenTileSizes, numThreads);
607   if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
608 innerDestinationTensors, tiledResults, resultOffsets,
613   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
614 llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
619 tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue,
620 destinationTensor, resultOffset,
621 resultSize, resultStride);
636 static FailureOr<SmallVector<LoopLikeOpInterface>>
642 &generateLoopTerminatorFn,
644   assert(!loopRanges.empty() && "unexpected empty loop ranges");
645 assert(loopRanges.size() == givenTileSizes.size() &&
646 "expected as many tile sizes as loop ranges");
647 assert(generateLoopHeaderFn && generateLoopTerminatorFn &&
648 "expected loop header/terminator generation function");
651 FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> loopHeaderInfo =
652 generateLoopHeaderFn(rewriter, loc, loopRanges, givenTileSizes,
653 outerDestinationTensors);
654   if (failed(loopHeaderInfo)) {
661   if (failed(tiledBodyFn(rewriter, loc, ivs, loopHeaderInfo->tileOffset,
662 loopHeaderInfo->tileSizes,
663 loopHeaderInfo->destinationTensors, tiledResults,
664 resultOffsets, resultSizes))) {
668   if (failed(generateLoopTerminatorFn(rewriter, loc, tiledResults,
669 resultOffsets, resultSizes,
670 loopHeaderInfo->destinationTensors))) {
674 return loopHeaderInfo->loops;
698       llvm::map_to_vector(loopRanges, [](Range r) { return r.offset; });
700       llvm::map_to_vector(loopRanges, [](Range r) { return r.size; });
702 destinationTensors, tiledResults, resultOffsets,
710 destinationTensors, tiledBodyFn);
714 rewriter, loc, loopRanges, givenTileSizes, numThreads,
715 options.mappingVector, destinationTensors, tiledBodyFn);
719 rewriter, loc, loopRanges, givenTileSizes, destinationTensors,
739 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
741 return op->emitOpError(
742 "PartialReductionOuterReduction tiling strategy is only supported for "
743 "operations implementing PartialReductionOpInterface");
748 AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
750   for (auto [index, domain, tileSize] :
752 if (!numThreads.empty()) {
756 rewriter, op.getLoc(), sizeExpr,
757 {domain.size, domain.offset, domain.stride});
760 sizes[index] = numThreads[index];
767 rewriter, op.getLoc(), sizeExpr,
768 {domain.size, domain.offset, domain.stride});
772 if (reductionStrategy ==
774 sizes[index] = tileSize;
778 assert(reductionStrategy ==
781 rewriter, op.getLoc(), sizeExpr,
782 {domain.size, domain.offset, domain.stride});
784 rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
786 return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
800   splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
805 if (reductionStrategy ==
808 if (!numThreads.empty()) {
809 splitReductionIvs[index] = ivs[ivIndex++];
813 rewriter, loc, divExpr,
817 return splitReductionIvs;
820 static FailureOr<TilingResult>
829 return op.getTiledImplementation(rewriter, offsets, sizes);
832 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
835         op, "PartialReductionOuterReduction tiling strategy is only "
836 "supported for operations "
837 "implementing PartialReductionOpInterface");
842 numThreads, givenTileSizes, reductionDims);
843 return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
844 regionIterArg, offsets, sizes,
845 reductionDims, splitReductionIvs);
850     int64_t index, Value tiledResult, TilingInterface op,
859 return op.getResultTilePosition(rewriter, index, offsets, sizes,
860 resultOffset, resultSize);
862 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
865         op, "PartialReductionOuterReduction tiling strategy is only supported "
866         "for operations implementing PartialReductionOpInterface");
870 numThreads, givenTileSizes, reductionDims);
871 return redOp.getPartialResultTilePosition(
872 rewriter, index, reductionStrategy, offsets, sizes, reductionDims,
873 splitReductionIvs, resultOffset, resultSize);
876 static FailureOr<MergeResult>
882 "expected merge to be called for only partial reduction cases");
884 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
887         op, "PartialReductionOuterReduction tiling strategy is only "
888 "supported for operations "
889 "implementing PartialReductionOpInterface");
891 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
902 template <typename LoopType>
903 FailureOr<LoopLikeOpInterface>
912 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
919 auto inits = llvm::to_vector(loopOp.getInitArgs());
920 inits.append(newInitOperands.begin(), newInitOperands.end());
921 auto newLoop = scf::ForOp::create(
922 rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
924 loopOp.getUnsignedCmp());
927 Block *loopBody = loopOp.getBody();
928 Block *newLoopBody = newLoop.getBody();
930 loopBody, newLoopBody,
933   auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
939 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
940   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
941 newRegionIterArgs, tiledValues, resultOffsets,
948   for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
949 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
953 Value insert = tensor::InsertSliceOp::create(
954 rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset,
955 resultSize, resultStride);
956 newYieldValues.push_back(insert);
961 newLoop->getResults().take_front(loopOp.getNumResults()));
962 return cast<LoopLikeOpInterface>(newLoop.getOperation());
967 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
973 auto inits = llvm::to_vector(loopOp.getOutputs());
974 inits.append(newInitOperands.begin(), newInitOperands.end());
975 auto newLoop = scf::ForallOp::create(
976 rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
977 loopOp.getMixedStep(), inits, loopOp.getMapping(),
981 Block *loopBody = loopOp.getBody();
982 Block *newLoopBody = newLoop.getBody();
984 loopBody, newLoopBody,
987   auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
992 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
993   if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
994 regionIterArgs, tiledValues, resultOffsets,
998 "failed to get yielded tiled values");
1004   for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
1005 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
1008 tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(),
1009 tiledValue, iterArg, resultOffset,
1010 resultSize, resultStride);
1014 newLoop->getResults().take_front(loopOp.getNumResults()));
1015 return cast<LoopLikeOpInterface>(newLoop.getOperation());
1022     LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
1025 loopLikeOp.getOperation())
1026 .Case<scf::ForOp, scf::ForallOp>(
1027           [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1029 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
1031       .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1050   for (auto &loop : loops.drop_back()) {
1054 auto forLoop = cast<scf::ForOp>(loop.getOperation());
1058 newInits.append(newInitValues.begin(), newInitValues.end());
1059 auto newLoop = scf::ForOp::create(
1060 rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
1061 forLoop.getUpperBound(), forLoop.getStep(), newInits,
1063 forLoop.getUnsignedCmp());
1067 sourceBlockArgs.push_back(newLoop.getInductionVar());
1068 auto newRegionIterArgs = newLoop.getRegionIterArgs();
1069 sourceBlockArgs.append(
1070 newRegionIterArgs.begin(),
1071 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
1072     rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
1074 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
1076 ivs.push_back(newLoop.getInductionVar());
1077 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
1081 LoopLikeOpInterface innerMostLoop = loops.back();
1082 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
1084 getNewTiledYieldsFn);
1086   if (failed(newInnerMostLoop))
1087     return innerMostLoop.emitOpError("failed to return additional yields");
1088 loops.back() = newInnerMostLoop.value();
1092   for (auto [outerLoop, innerLoop] :
1093 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1095 auto outerForLoop = cast<scf::ForOp>(outerLoop);
1096 auto outerLoopYield =
1097 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
1099 llvm::to_vector(outerLoopYield.getOperands());
1101 innerLoop->getResults().take_back(newInitValues.size());
1102 newYields.append(additionalYields.begin(), additionalYields.end());
1111 FailureOr<scf::SCFTilingResult>
1126 std::tie(givenTileSizes, numThreads) =
1132 givenTileSizes, numThreads))) {
1143   if (!options.interchangeVector.empty()) {
1145 iterationDomain.size());
1147 "expected interchange vector to be a permutation");
1151 if (!numThreads.empty())
1155 FailureOr<TilingResult> tilingResult;
1169 if (!interchangeVector.empty()) {
1178 auto clonedOp = cast<TilingInterface>(
1185 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1195 regionIterArgs, tileOffsetsVec, tileSizesVec,
1196 ivs, numThreads, givenTileSizes, reductionDims);
1197       if (failed(tilingResult)) {
1199         return op.emitOpError("failed to tile operation");
1207     for (auto [index, tiledValue] :
1209       tiledResults.push_back(tiledValue);
1212               rewriter, options.reductionStrategy, index, tiledValue, op,
1213 tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes,
1214 reductionDims, resultOffset, resultSize))) {
1215         for (auto op : tilingResult->tiledOps) {
1219             op, "failed to get slice of result produced");
1221 resultOffsets.emplace_back(std::move(resultOffset));
1222 resultSizes.emplace_back(std::move(resultSize));
1230       rewriter, op, options.reductionStrategy, iterationDomain, numThreads,
1231       givenTileSizes, reductionDims);
1232   if (failed(maybeInits)) {
1234         op, "unable to create initial tensors for tiling");
1242       rewriter, op.getLoc(), options, iterationDomain, givenTileSizes,
1243       numThreads, initTensors, innerYieldTiledValuesFn);
1245     return op.emitOpError("failed to generate tiling loops");
1246 assert(succeeded(tilingResult) &&
1247 "expected tiling result to be computed after loop generation");
1248 std::swap(loops, loopsOr.value());
1251 if (loops.empty()) {
1257 tilingResult->tiledValues,
1258 tilingResult->generatedSlices,
1262 auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1268             tilingResult->tiledOps, initTensors, loops, loopResults,
1269 tilingResult->generatedSlices, {}};
1274       rewriter, op, options.reductionStrategy, reductionDims, loopResults);
1275   if (failed(mergeResult)) {
1277         op, "Failed to merge partial results from tiling");
1282 mergeResult->replacements,
1283 tilingResult->generatedSlices,
1284 mergeResult->mergeOps};
1287 FailureOr<scf::SCFTilingResult>
1289 PartialReductionOpInterface op,
1293 options.setReductionTilingStrategy(
1295 options.setTileSizes(tileSize);
1297   for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
1298 if (iteratorType == utils::IteratorType::reduction)
1299 reductionDims.push_back(index);
1300 options.setReductionDims(reductionDims);
1314 static std::tuple<OpResult, std::optional<OpOperand *>>
1317 std::optional<OpOperand *> destinationIterArg;
1318   assert(!loops.empty() && "expected non empty loops container");
1319 auto loopIt = loops.rbegin();
1320   while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
1321     auto iterArg = cast<BlockArgument>(source->get());
1322 auto loop = *loopIt;
1323 if (iterArg.getOwner()->getParentOp() != loop)
1325 source = loop.getTiedLoopInit(iterArg);
1328 if (loopIt == loops.rend())
1329 destinationIterArg = source;
1330   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1335 std::optional<scf::SCFFuseProducerOfSliceResult>
1337 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1341 auto [fusableProducer, destinationInitArg] =
1344 if (!fusableProducer)
1345 return std::nullopt;
1346 unsigned resultNumber = fusableProducer.getResultNumber();
1354 Operation *fusableProducerOp = fusableProducer.getOwner();
1355 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1357           rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1358 origDestinationTensors)))
1359 return std::nullopt;
1361 clonedOpDestinationTensors = origDestinationTensors;
1362 if (destinationInitArg &&
1363 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1367 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1371 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1376 llvm::to_vector(candidateSliceOp->getOperands());
1377   candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1378 tensor::ExtractSliceOp clonedCandidateSliceOp =
1380 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1383 FailureOr<TilingResult> tileAndFuseResult =
1385 rewriter, clonedCandidateSliceOp,
1386           clonedProducerOp->getResult(resultNumber));
1387   if (failed(tileAndFuseResult))
1388 return std::nullopt;
1392 tileAndFuseResult->tiledValues[0]);
1393   rewriter.eraseOp(clonedCandidateSliceOp);
1394   rewriter.eraseOp(clonedProducerOp);
1439 if (destinationInitArg &&
1440 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1442 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1443 .set(origDestinationTensors[resultNumber]);
1446 fusableProducer, tileAndFuseResult->tiledValues[0],
1447       tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1452 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1460                   *tiledOwner = fusedProducerInfo.tiledOps[0];
1465 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1467 : llvm::to_vector(yieldResultNumber);
1469   for (const auto &resultNumber : initNumberList) {
1471         rewriter, loc, originalOwner->getResult(resultNumber));
1472 if (succeeded(initValue)) {
1473 initValueList.push_back(initValue.value());
1489 sliceSizes = sliceOp.getMixedSizes();
1492   if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
1495 unsigned sliceResultNumber =
1498 auto tilableOp = cast<TilingInterface>(originalOwner);
1502 if (tilableOp->getNumResults() > 1 &&
1503 failed(tilableOp.getIterationDomainTileFromResultTile(
1504 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1505 iterDomainOffset, iterDomainSizes))) {
1520   for (const auto &resultNumber : initNumberList) {
1521 if (resultNumber == sliceResultNumber) {
1522 offsetList.push_back(sliceOffset);
1523 sizesList.push_back(sliceSizes);
1525 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1528       if (failed(tilableOp.getResultTilePosition(
1529 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1533 offsetList.push_back(offset);
1534 sizesList.push_back(sizes);
1540     if (auto tiledDestStyleOp =
1541             dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1543       for (const auto &&[index, newRegionArg] :
1545 auto destSlice = tensor::ExtractSliceOp::create(
1546 rewriter, loc, newRegionArg, offsetList[index], sizesList[index],
1549 generatedSlices.push_back(destSlice);
1550 unsigned resultNumber = initNumberList[index];
1552 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1561     for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1562 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1563 tiledOffset.emplace_back(offsetList[index]);
1564 tiledSizes.emplace_back(sizesList[index]);
1570 newYieldValuesFn))) {
1573 return generatedSlices;
1587   explicit SliceTrackingListener(
1588       std::optional<FrozenRewritePatternSet> patterns);
1589   SliceTrackingListener() = default;
1598   void notifyOperationInserted(Operation *op,
1605   void notifyOperationErased(Operation *op) override;
1612 std::deque<tensor::ExtractSliceOp> worklist;
1617   std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1620 SliceTrackingListener::SliceTrackingListener(
1621 std::optional<FrozenRewritePatternSet> p) {
1628     if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1629 worklist.push_back(slice);
1641 void SliceTrackingListener::notifyOperationInserted(
1643 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1646 worklist.push_back(slice);
1652 void SliceTrackingListener::removeOp(Operation *op) {
1653 if (!isa<tensor::ExtractSliceOp>(op))
1655 auto iter = worklist.begin();
1656 while (iter != worklist.end()) {
1661 if (iter == worklist.end())
1664 worklist.erase(iter);
1667 void SliceTrackingListener::notifyOperationErased(Operation *op) {
1671 void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1687 : ForwardingListener(listener), replacements(replacements) {}
1689   void updateReplacementValues(ValueRange origValues,
1693     for (auto &[key, val] : replacements) {
1694       for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1703 ForwardingListener::notifyOperationReplaced(op, newOp);
1708 ForwardingListener::notifyOperationReplaced(op, values);
1709   updateReplacementValues(op->getResults(), values);
1719 FailureOr<scf::SCFTileAndFuseResult>
1725 if (!consumer->getNumResults()) {
1727         consumer, "invalid pattern for op with no results");
1733 FailureOr<scf::SCFTilingResult> tilingResult =
1736   if (failed(tilingResult))
1738 tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1741   for (auto [origVal, replacement] :
1742 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1743 replacements[origVal] = replacement;
1747 auto &loops = tilingResult->loops;
1748 if (loops.empty()) {
1757 auto resetListener =
1758       llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1759 ReplacementListener replaceListener(replacements, previousListener);
1769 struct WorklistItem {
1770 tensor::ExtractSliceOp candidateSlice;
1774 SliceTrackingListener sliceTracker =
1775       SliceTrackingListener(options.cleanupPatterns);
1778 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1782 while (!sliceTracker.worklist.empty()) {
1783 auto candidateSlice = sliceTracker.worklist.front();
1784 sliceTracker.worklist.pop_front();
1786 auto [fusableProducer, destinationInitArg] =
1789 if (!fusableProducer)
1792 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1793 options.fusionControlFn(candidateSlice, fusableProducer,
1794 destinationInitArg.has_value());
1795 if (!controlFnResult)
1798 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1803 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1811 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1816 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1817 FailureOr<SmallVector<Operation *>> newSlices =
1819 worklistItem.candidateSlice,
1820 fusedResult.value(), loops);
1823             fusableProducerOp, "failed to get replacement value for this "
1824             "operation from within the tiled loop");
1826 worklistCandidates.append(newSlices.value());
1827       for (auto [index, result] :
1829 replacements[result] = loops.front()->getResult(
1830 loops.front()->getNumResults() -
1835 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1836 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1837 tiledAndFusedOps.insert(tiledAndFusedOp);
1840     if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1855 static LogicalResult
1857 Value result = candidateSliceOp.getResult();
1859 if (!llvm::hasSingleElement(uses)) {
1860     LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1863 OpOperand &operandUse = (*uses.begin());
1865 if (!isa<scf::YieldOp>(userOp)) {
1866 LLVM_DEBUG(llvm::dbgs()
1867                << "Expected scf.yield to be the only user, but got -> "
1872     LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1873                "be in the same block\n");
1882 if (!isa<LoopLikeOpInterface>(loopOp))
1901 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1904     if (loopOp->getBlock() != userOp->getBlock())
1908 firstUserOfLoop = userOp;
1910 return firstUserOfLoop;
1951 static FailureOr<llvm::SetVector<Operation *>>
1953 bool reorderOperations) {
1955   if (failed(firstUserOfLoop))
1961   options.omitBlockArguments = true;
1962   bool includeLoopOp = false;
1965         includeLoopOp = true;
1975 assert(result.succeeded() &&
"expected a backward slice");
1979 if (!slice.empty()) {
1989 if (includeLoopOp || !reorderOperations)
2001 unsigned resultNumber) {
2002 if (!isa<LoopLikeOpInterface>(loopOp))
2007 Operation *consumerOp = opOperand.getOwner();
2009 if (!isa<TilingInterface>(consumerOp) ||
2010 !isa<DestinationStyleOpInterface>(consumerOp)) {
2017   if (loopBlock != consumerOp->getBlock())
2024 FailureOr<llvm::SetVector<Operation *>> slice =
2030 if (!slice->empty()) {
2033     assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
2034     for (auto op : *slice) {
2049 static FailureOr<OpOperand *>
2051 tensor::InsertSliceOp candidateSliceOp,
2053   assert(!loops.empty() && "unexpected loops to be empty");
2056 if (containingOp != loops.back()) {
2059 "expected slice to be within body of inner-most loop");
2065         candidateSliceOp, "expected passed loops to be perfectly nested.");
2070 Value sliceResult = candidateSliceOp.getResult();
2076 scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
2083 static FailureOr<OpOperand *>
2085 tensor::ParallelInsertSliceOp candidateSliceOp,
2087   assert(!loops.empty() && "unexpected loops to be empty");
2089 if (loops.size() != 1) {
2091         candidateSliceOp, "expected single surrounding scf.forall");
2093   auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
2096         candidateSliceOp, "expected single surrounding scf.forall");
2100 Value sliceDest = candidateSliceOp.getDest();
2101 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
2104 if (iterArg.getOwner()->getParentOp() != forallOp)
2107 unsigned resultNumber =
2108 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
2119   assert(!loops.empty() && "unexpected empty loops");
2120   assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
2122   for (auto sliceOp : sliceOps) {
2123 FailureOr<OpOperand *> fusedOperand =
2125 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2132     if (failed(fusedOperand)) {
2135 if (!fusedOperands.empty() &&
2136 fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
2138 fusedOperand.value()->getOwner(),
2139 "all candidate slices must be to the same consumer");
2141 fusedOperands.push_back(fusedOperand.value());
2143 return fusedOperands;
2146 template <typename InsertSliceOpTy>
2148 InsertSliceOpTy sliceOp);
2151 tensor::InsertSliceOp
2152 cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
2153 tensor::InsertSliceOp insertSliceOp) {
2154 return cast<tensor::InsertSliceOp>(
2155       rewriter.clone(*insertSliceOp.getOperation()));
2159 tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
2160 RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
2161 return tensor::InsertSliceOp::create(
2162 rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(),
2163 insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
2164 insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2170 assert(!candidateSlices.empty() &&
2171 "unexpected empty list of slices to clone");
2173   for (auto sliceOp : candidateSlices) {
2175 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2178 clonedSlices.push_back(clonedOp);
2182               assert(0 && "unexpected slice type while cloning as insert slice");
2185 return clonedSlices;
2190 FailureOr<scf::SCFFuseConsumerOfSliceResult>
2194 if (candidateSlices.empty()) {
2197 "no candidate slices provided for consumer fusion");
2201 if (loops.empty()) {
2203 candidateSlices.front(),
2204 "cannot call tile and fuse consumer with an empty loop nest");
2207 if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2208 llvm::all_of(candidateSlices,
2209 llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2211 candidateSlices.front(),
2212 "candidate slices need to be all `tensor.insert_slice`s or "
2213 "`tensor.parallel_insert_slice`s");
2221 FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2223   if (failed(maybeConsumerOpOperand)) {
2225 "could not fetch consumer to fuse");
2227 std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
2228 consumerOp = consumerOpOperands.front()->getOwner();
2231 LoopLikeOpInterface outerMostLoop = loops.front();
2232 LoopLikeOpInterface innerMostLoop = loops.back();
2237         outerMostLoop, "the first user of loop should not dominate any define "
2238 "of consumer operand(s)");
2244 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2247 "consumer op is not DPS operation");
2248   if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
2249 return dstOp.isDpsInit(opOperand);
2253 "consumer op taking the result of scf.for as init is not supported");
2260   if (failed(firstUserOfLoop)) {
2262         outerMostLoop, "could not find the first user of outer most loop");
2264   rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2271 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
2272 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2282   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2284       llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
2288       llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2289 return &clonedConsumerOp->getOpOperand(operandNum);
2295   for (auto [operandToReplace, clonedSliceOp] :
2296 llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
2297 operandToReplace->set(clonedSliceOp.getResult());
2303 FailureOr<TilingResult> tileAndFuseResult =
2305 clonedOpFusedOperandsList);
2306   if (failed(tileAndFuseResult)) {
2310 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2311   for (auto [operandNum, clonedSliceOp] :
2312 llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
2314 clonedSliceOp.getSource());
2328   for (auto candidateSliceOp : clonedInsertSlices) {
2336           candidateSliceOp, "containingOp's result yield with stride");
2339 allOffsets.emplace_back(std::move(offsets));
2340 allSizes.emplace_back(std::move(sizes));
2350   if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
2351 rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
2352 iterDomainSizes))) {
2355 "can't get iter domain position from input position");
2361 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2363 totalNumResultsOfConsumer);
2365 totalNumResultsOfConsumer);
2366   for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2367     if (failed(tiledConsumerOp.getResultTilePosition(
2368 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2369 resultOffsets[idx], resultSizes[idx]))) {
2372 "can't get result domain position from iter domain position");
2378   if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2379           tiledConsumerOp.getOperation())) {
2381     for (const auto &&[index, newRegionArg] :
2383 auto destSlice = tensor::ExtractSliceOp::create(
2384 rewriter, loc, newRegionArg, resultOffsets[index],
2390 auto dstNumber = index;
2392 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2401     for (const auto &&[index, result] :
2403 tiledResult.push_back(result);
2404 tiledOffset.emplace_back(resultOffsets[index]);
2405 tiledSizes.emplace_back(resultSizes[index]);
2411 newYieldValuesFn))) {
2413 "unable to add new inits to nest loop");
2419   for (auto &&[oldResult, newResult] :
2421 loops.front()->getResults().take_back(newInits.size()))) {
2426   rewriter.eraseOp(clonedConsumerOp);
2429       llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2430 return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
2433 std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
2434 std::move(tileAndFuseResult->tiledOps)};
2441 FailureOr<SmallVector<scf::ForOp>>
2443 TilingInterface op) {
2445 if (op->getNumResults() > 0) {
2447         op, "unable to lower to loops operations with return values");
2454   for (auto loopRange : domain) {
2461 auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal,
2463 loops.push_back(loop);
2464 ivs.push_back(loop.getInductionVar());
2467   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset, tileSize * (numThreads - 1), is less than iterationSize.
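As an illustration (the constant-folded form with values chosen here, not taken from the file):

  // Hypothetical constant case: tileSize = 4, numThreads = 8, iterationSize = 30.
  // The largest tile offset is 4 * (8 - 1) = 28 < 30, so every thread's tile
  // starts in bounds and the offset bounds-check can be omitted; with
  // iterationSize = 28 the comparison fails and the check must stay.
  bool canOmit = tileSize * (numThreads - 1) < iterationSize;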
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
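A minimal sketch of the completion behaviour (an illustrative helper mirroring the listing above, not the file's code verbatim): a partial interchange vector is padded with the remaining loop indices in order and truncated to the iteration-domain rank.

  SmallVector<int64_t> completed(interchangeVector);
  for (int64_t i = completed.size(); i < int64_t(iterationDomainSize); ++i)
    completed.push_back(i);          // e.g. {1, 0} with rank 4 -> {1, 0, 2, 3}
  completed.resize(iterationDomainSize);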
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizesWithForAllOp(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represent the multi-dimensional offsets and sizes of the tile of the ...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes)
Compute the OpFoldResults that represent the multi-dimensional offsets and sizes of the tile of the ...
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
static LogicalResult checkTileSizes(TilingInterface op, scf::SCFTilingOptions::LoopType loopType, ReductionTilingStrategy reductionStrategy, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.for operation.
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
Typedef for function that allows returning additional yielded values during yieldTiledValuesAndReplac...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes)
Function to return the bounds of the loops to be generated.
static SmallVector< OpFoldResult > getSplitReductionIvs(RewriterBase &rewriter, Location loc, ReductionTilingStrategy reductionStrategy, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
For the case of ReductionTilingStrategy::PartialReductionOuterParallel the PartialReductionOpInterfac...
std::function< LogicalResult(RewriterBase &rewriter, Location Loc, ValueRange ivs, ArrayRef< OpFoldResult > tileOffsets, ArrayRef< OpFoldResult > tileSizes, ValueRange outerDestinationTensors, SmallVector< Value > &tiledResults, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> GenerateTiledBodyFn
Typedef for function that implements the body of a tiled loop.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using the loop construct specified in options.
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of consumerOp.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loop's existing init operands (or simi...
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, const SetVector< unsigned > &reductionDims, ValueRange partialResults)
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingCustomOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn, const scf::SCFTilingOptions::GenerateLoopTerminatorFn &generateLoopTerminatorFn, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using custom loop operation.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult givenTileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count (size - offset).
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static SetVector< unsigned > getSanitizedReductionDims(ArrayRef< OpFoldResult > givenTileSizes, const scf::SCFTilingOptions &options)
Get the reduction dims that are tiled.
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< SmallVector< OpOperand * > > getUntiledConsumerOperandsFromSlices(RewriterBase &rewriter, ArrayRef< Operation * > sliceOps, MutableArrayRef< LoopLikeOpInterface > loops)
A utility to fetch an untiled consumer of tensor.insert_slice/tensor.parallel_insert_slice.
static SmallVector< tensor::InsertSliceOp > cloneAsInsertSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices)
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, InsertSliceOpTy sliceOp)
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.forall operation.
static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents a saved insertion point.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)
Method to tile a reduction and generate a parallel op within a serial loop.
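A hedged usage sketch (the tile sizes and the surrounding helper are assumptions, not part of this file):

  LogicalResult splitReduction(RewriterBase &rewriter,
                               PartialReductionOpInterface op) {
    // Tile only the reduction dimension (here the second loop) by 64.
    SmallVector<OpFoldResult> tileSizes = {rewriter.getIndexAttr(0),
                                           rewriter.getIndexAttr(64)};
    FailureOr<scf::SCFTilingResult> result =
        scf::tileReductionUsingScf(rewriter, op, tileSizes);
    if (failed(result))
      return failure();
    // Replace the original reduction with the merged (final) values.
    rewriter.replaceOp(op, result->replacements);
    return success();
  }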
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
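A minimal driver sketch, assuming fixed tile sizes and the default scf.for loop nest (the helper name and sizes are illustrative):

  LogicalResult tileWithFixedSizes(RewriterBase &rewriter, TilingInterface op) {
    scf::SCFTilingOptions options;
    // Tile the first two loops by 8 and 32; a tile size of 0 leaves a loop untiled.
    options.setTileSizes({rewriter.getIndexAttr(8), rewriter.getIndexAttr(32),
                          rewriter.getIndexAttr(0)});
    FailureOr<scf::SCFTilingResult> result =
        scf::tileUsingSCF(rewriter, op, options);
    if (failed(result))
      return failure();
    // Uses of the original op are redirected to the values yielded by the loops.
    rewriter.replaceOp(op, result->replacements);
    return success();
  }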
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
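A hedged sketch of the call; as the listing above checks, only ops without results (e.g. ones writing through memrefs) are supported:

  FailureOr<SmallVector<scf::ForOp>> loops =
      scf::lowerToLoopsUsingSCFForOp(rewriter, op);
  if (failed(loops))
    return failure();
  // `loops` holds the generated scf.for nest, outermost loop first.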
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
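A hedged sketch of tiling a consumer and greedily fusing its producers (tile sizes and the helper name are assumptions):

  LogicalResult tileAndFuseGreedily(RewriterBase &rewriter,
                                    TilingInterface consumer) {
    scf::SCFTilingOptions tilingOptions;
    tilingOptions.setTileSizes({rewriter.getIndexAttr(16),
                                rewriter.getIndexAttr(16)});
    scf::SCFTileAndFuseOptions options;
    options.setTilingOptions(tilingOptions);
    FailureOr<scf::SCFTileAndFuseResult> result =
        scf::tileConsumerAndFuseProducersUsingSCF(rewriter, consumer, options);
    if (failed(result))
      return failure();
    // `replacements` maps each original result value to its tiled replacement.
    for (OpResult origResult : consumer->getResults())
      rewriter.replaceAllUsesWith(origResult,
                                  result->replacements.lookup(origResult));
    return success();
  }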
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
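A hedged sketch of fusing the producer of one tensor.extract_slice generated by a previous tiling step (`extractSliceOp` and `tilingLoops` are assumed to come from that step):

  std::optional<scf::SCFFuseProducerOfSliceResult> fused =
      scf::tileAndFuseProducerOfSlice(rewriter, extractSliceOp, tilingLoops);
  if (!fused)
    return failure(); // The slice source is not produced by a tilable op.
  // fused->tiledAndFusedProducer is the producer tile computed inside the loops.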
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
FailureOr< TilingResult > replaceInsertSlicesWithTiledConsumer(OpBuilder &builder, ArrayRef< tensor::InsertSliceOp > sliceOps, ArrayRef< OpOperand * > consumerOperands)
Method to swap tensor.insert_slices with their consumers when the consumer implements the TilingInter...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Include the generated interface declarations.
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a loop.
@ PartialReductionOuterReduction
@ PartialReductionOuterParallel
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Container for result values of tiling.
Fuse the consumer candidateSlices by computing the required slice of the consumer in-place.
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
SmallVector< Operation * > tiledOps
Control function to check if a slice needs to be fused or not. The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange tiledResults, ArrayRef< SmallVector< OpFoldResult > > resultOffsets, ArrayRef< SmallVector< OpFoldResult > > resultSizes, ValueRange destinationTensors)> GenerateLoopTerminatorFn
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
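A sketch of the callback-based alternative, where sizes are computed at the insertion point of the generated loops (the policy below is a hypothetical example):

  scf::SCFTilingOptions options;
  options.setTileSizeComputationFunction(
      [](OpBuilder &builder, Operation *op) -> SmallVector<OpFoldResult> {
        // Hypothetical policy: tile every loop of the op by 16.
        auto iters = cast<TilingInterface>(op).getLoopIteratorTypes();
        return SmallVector<OpFoldResult>(iters.size(), builder.getIndexAttr(16));
      });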
LoopType
Specify which loop construct to use for tile and fuse.
std::function< FailureOr< CustomLoopHeaderInfo >(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange destinationTensors)> GenerateLoopHeaderFn
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.