31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
36 #define DEBUG_TYPE "tile-using-interface"
43 auto tileSizes = llvm::to_vector(ts);
52 assert(!numThreadsComputationFunction &&
"num tiles already set");
53 auto numThreads = llvm::to_vector(nt);
64 size_t iterationDomainSize) {
66 if (filledVector.size() < iterationDomainSize) {
67 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
68 filledVector.append(range.begin(), range.end());
70 if (filledVector.size() > iterationDomainSize)
71 filledVector.resize(iterationDomainSize);
84 if (
options.numThreadsComputationFunction &&
87 loc,
"number of threads can only by specified when loop type is "
88 "set to use `scf.forall`");
92 if (!
options.interchangeVector.empty()) {
95 loc,
"invalid interchange vector, not a permutation of the entire "
110 size_t numLoops = iterationDomain.size();
113 if (
options.numThreadsComputationFunction) {
114 numThreads =
options.numThreadsComputationFunction(rewriter, op);
115 numThreads.resize(numLoops, zero);
118 if (
options.tileSizeComputationFunction) {
119 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
120 tileSizes.resize(numLoops, zero);
121 return {tileSizes, numThreads};
133 tileSizes.resize(numLoops, zero);
134 for (
auto [index, range, nt] :
140 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
142 tileSizes.resize(numLoops, zero);
143 return {tileSizes, numThreads};
150 assert(
options.tileSizeComputationFunction &&
151 "expected tile sizes to be specified");
152 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
153 tileSizes.resize(numLoops, zero);
155 return {tileSizes, numThreads};
162 auto iterators = op.getLoopIteratorTypes();
163 assert(iterators.size() == tileSizes.size() &&
164 "expected as many tile size values as number of loops");
165 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
166 "when specified, expected number of threads to use for each loop");
168 for (
auto [index, iterator, tileSize] :
172 if (!numThreads.empty()) {
173 if (std::optional<int64_t> constNumThreads =
175 if (constNumThreads.value() > 1 &&
176 iterator != utils::IteratorType::parallel) {
177 op.emitWarning() <<
"tiling is not thread safe at axis #" << index;
184 if (constTileSize.value() > 0 &&
185 iterator != utils::IteratorType::parallel) {
186 op.emitWarning() <<
"tiling is not thread safe at axis #" << index;
203 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
212 if (ts && ts.value() == 1)
239 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
241 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
253 int materializedLoopNum = 0;
255 if (!numThreads.empty()) {
260 offsetExpr = d0 + d1 * s0;
261 residualTileSizeExpr = s1 - (d0 + d1 * s0);
263 for (
auto [nt, tileSize, loopRange] :
264 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
269 offsets.push_back(loopRange.offset);
270 sizes.push_back(loopRange.size);
274 Value iv = ivs[materializedLoopNum++];
276 rewriter, loc, offsetExpr,
279 rewriter, loc, residualTileSizeExpr,
280 {loopRange.offset, nt, tileSize, loopRange.size});
286 {offset, loopRange.size});
290 {sizeMinusOffsetPerThread, tileSize});
306 rewriter, loc, maxMap, {rewriter.
getIndexAttr(0), size});
309 offsets.push_back(offset);
310 sizes.push_back(size);
312 return {offsets, sizes};
314 for (
auto [tileSize, loopRange] :
315 llvm::zip_equal(tileSizes, iterationDomain)) {
320 offsets.push_back(loopRange.offset);
321 sizes.push_back(loopRange.size);
325 Value iv = ivs[materializedLoopNum++];
327 offsets.push_back(offset);
330 sizes.push_back(size);
332 return {offsets, sizes};
342 for (
auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
346 lbs.push_back(loopRange.offset);
347 ubs.push_back(loopRange.size);
348 steps.push_back(tileSize);
350 return {lbs, ubs, steps};
380 if (newDestArgs.empty())
382 if (
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
383 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
401 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
402 assert(loopRanges.size() == tileSizes.size() &&
403 "expected as many tile sizes as loop ranges");
407 std::tie(lbs, ubs, steps) =
417 for (
auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
419 rewriter.
create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
422 loops.push_back(loop);
423 ivs.push_back(loop.getInductionVar());
425 destinationTensors = loop.getRegionIterArgs();
430 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
431 tiledResults, resultOffsets, resultSizes))) {
433 loc,
"failed to generate inner tile loop body");
438 assert(tiledResults.size() == destinationTensors.size() &&
439 "Number of results of body should be equal to number of iter args");
443 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
444 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
448 auto insertSlice = rewriter.
create<tensor::InsertSliceOp>(
449 loc, tiledValue, destinationTensor, resultOffset, resultSize,
451 yieldedValues.push_back(insertSlice);
453 rewriter.
create<scf::YieldOp>(loc, yieldedValues);
456 for (
auto [outerLoop, innerLoop] :
460 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
461 rewriter.
create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
482 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
483 assert(loopRanges.size() == tileSizes.size() &&
484 "expected as many tile sizes as loop ranges");
487 std::optional<ArrayAttr> mappingAttr;
488 if (!mappingVector.empty())
491 scf::ForallOp forallOp;
492 bool useNumThreads = !numThreads.empty();
497 for (
auto nt : numThreads) {
500 nonZeroNumThreads.push_back(nt);
502 forallOp = rewriter.
create<scf::ForallOp>(loc, nonZeroNumThreads,
503 destinationTensors, mappingAttr);
506 std::tie(lbs, ubs, steps) =
508 forallOp = rewriter.
create<scf::ForallOp>(loc, lbs, ubs, steps,
509 destinationTensors, mappingAttr);
511 loops.push_back(forallOp);
514 destinationTensors = forallOp.getRegionOutArgs();
518 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
519 destinationTensors, tiledResults, resultOffsets,
524 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
525 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
530 rewriter.
create<tensor::ParallelInsertSliceOp>(
531 loc, tiledValue, destinationTensor, resultOffset, resultSize,
557 return tiledBodyFn(rewriter, loc,
ValueRange{}, destinationTensors,
558 tiledResults, resultOffsets, resultSizes);
562 destinationTensors, tiledBodyFn, loops);
566 rewriter, loc, loopRanges, tileSizes, numThreads,
options.mappingVector,
567 destinationTensors, tiledBodyFn, loops);
572 static FailureOr<SmallVector<Value>>
578 switch (
options.reductionStrategy) {
585 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
588 op,
"PartialReductionOuterReduction tiling strategy is only supported"
589 "for operations implementing PartialReductionOpInterface");
595 for (
auto [idx, iteratorType] :
597 if (iteratorType == utils::IteratorType::reduction)
598 reductionDims.push_back(idx);
600 return redOp.generateInitialTensorForPartialReduction(
601 rewriter, loc, tileSizes, reductionDims);
605 "unhandled reduction tiling strategy");
609 static FailureOr<TilingResult>
614 switch (
options.reductionStrategy) {
616 return op.getTiledImplementation(rewriter, offsets, sizes);
619 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
622 op,
"PartialReductionOuterReduction tiling strategy is only "
623 "supported for operations "
624 "implementing PartialReductionOpInterface");
630 for (
auto [idx, iteratorType] :
632 if (iteratorType == utils::IteratorType::reduction)
633 reductionDims.push_back(idx);
635 return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
636 offsets, sizes, reductionDims);
640 "unhandled reduction tiling strategy");
652 switch (
options.reductionStrategy) {
654 return op.getResultTilePosition(rewriter, index, offsets, sizes,
655 resultOffset, resultSize);
658 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
661 op,
"PartialReductionOuterReduction tiling strategy is only supported"
662 "for operations implementing PartialReductionOpInterface");
668 for (
auto [idx, iteratorType] :
670 if (iteratorType == utils::IteratorType::reduction)
671 reductionDims.push_back(idx);
673 return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
674 resultOffset, resultSize,
679 "unhandled reduction tiling strategy");
683 static FailureOr<MergeResult>
687 switch (
options.reductionStrategy) {
693 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
696 op,
"PartialReductionOuterReduction tiling strategy is only "
697 "supported for operations "
698 "implementing PartialReductionOpInterface");
704 for (
auto [idx, iteratorType] :
706 if (iteratorType == utils::IteratorType::reduction)
707 reductionDims.push_back(idx);
709 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
714 "unhandled reduction tiling strategy");
725 template <
typename LoopType>
726 FailureOr<LoopLikeOpInterface>
735 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
742 auto inits = llvm::to_vector(loopOp.getInitArgs());
743 inits.append(newInitOperands.begin(), newInitOperands.end());
744 auto newLoop = rewriter.
create<scf::ForOp>(
745 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
749 Block *loopBody = loopOp.getBody();
750 Block *newLoopBody = newLoop.getBody();
752 loopBody, newLoopBody,
753 newLoopBody->
getArguments().take_front(loopBody->getNumArguments()));
755 auto yieldOp = cast<scf::YieldOp>(newLoopBody->
getTerminator());
761 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
762 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
763 newRegionIterArgs, tiledValues, resultOffsets,
770 for (
auto [tiledValue, regionIterArg, resultOffset, resultSize] :
771 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
775 Value insert = rewriter.
create<tensor::InsertSliceOp>(
776 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
778 newYieldValues.push_back(insert);
783 newLoop->getResults().take_front(loopOp.getNumResults()));
784 return cast<LoopLikeOpInterface>(newLoop.getOperation());
789 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
795 auto inits = llvm::to_vector(loopOp.getOutputs());
796 inits.append(newInitOperands.begin(), newInitOperands.end());
797 auto newLoop = rewriter.
create<scf::ForallOp>(
798 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
799 loopOp.getMixedStep(), inits, loopOp.getMapping(),
803 Block *loopBody = loopOp.getBody();
804 Block *newLoopBody = newLoop.getBody();
806 loopBody, newLoopBody,
807 newLoopBody->
getArguments().take_front(loopBody->getNumArguments()));
809 auto terminator = cast<scf::InParallelOp>(newLoopBody->
getTerminator());
814 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
815 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
816 regionIterArgs, tiledValues, resultOffsets,
820 "failed to get yielded tiled values");
826 for (
auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
827 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
830 rewriter.
create<tensor::ParallelInsertSliceOp>(
831 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
836 newLoop->getResults().take_front(loopOp.getNumResults()));
837 return cast<LoopLikeOpInterface>(newLoop.getOperation());
847 loopLikeOp.getOperation())
848 .Case<scf::ForOp, scf::ForallOp>(
849 [&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
851 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
853 .Default([&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
872 for (
auto &loop : loops.drop_back()) {
876 auto forLoop = cast<scf::ForOp>(loop.getOperation());
880 newInits.append(newInitValues.begin(), newInitValues.end());
881 auto newLoop = rewriter.
create<scf::ForOp>(
882 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
883 forLoop.getStep(), newInits,
888 sourceBlockArgs.push_back(newLoop.getInductionVar());
889 auto newRegionIterArgs = newLoop.getRegionIterArgs();
890 sourceBlockArgs.append(
891 newRegionIterArgs.begin(),
892 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
893 rewriter.
mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
895 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
897 ivs.push_back(newLoop.getInductionVar());
898 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
902 LoopLikeOpInterface innerMostLoop = loops.back();
903 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
905 getNewTiledYieldsFn);
907 if (failed(newInnerMostLoop))
908 return innerMostLoop.emitOpError(
"failed to return additional yields");
909 loops.back() = newInnerMostLoop.value();
913 for (
auto [outerLoop, innerLoop] :
914 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
916 auto outerForLoop = cast<scf::ForOp>(outerLoop);
917 auto outerLoopYield =
918 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
920 llvm::to_vector(outerLoopYield.getOperands());
922 innerLoop->getResults().take_back(newInitValues.size());
923 newYields.append(additionalYields.begin(), additionalYields.end());
932 FailureOr<scf::SCFTilingResult>
947 std::tie(tileSizes, numThreads) =
959 if (!
options.interchangeVector.empty()) {
961 iterationDomain.size());
963 "expected interchange vector to be a permutation");
967 if (!numThreads.empty())
971 FailureOr<TilingResult> tilingResult;
983 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
987 if (!interchangeVector.empty()) {
996 auto clonedOp = cast<TilingInterface>(
1003 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1013 if (failed(tilingResult)) {
1015 return op.emitOpError(
"faild to tile operation");
1023 for (
auto [index, tiledValue] :
1025 tiledResults.push_back(tiledValue);
1028 sizes, resultOffset, resultSize,
1030 for (
auto op : tilingResult->tiledOps) {
1034 op,
"failed to get slice of result produced");
1036 resultOffsets.emplace_back(std::move(resultOffset));
1037 resultSizes.emplace_back(std::move(resultSize));
1044 FailureOr<SmallVector<Value>> maybeInits =
1046 if (failed(maybeInits)) {
1048 op,
"unable to create initial tensors for tiling");
1055 tileSizes, numThreads, initTensors,
1056 innerYieldTiledValuesFn, loops)))
1057 return op.emitOpError(
"failed to generate tiling loops");
1058 assert(succeeded(tilingResult) &&
1059 "expected tiling result to be computed after loop generation");
1061 if (loops.empty()) {
1067 tilingResult->tiledValues,
1068 tilingResult->generatedSlices,
1072 auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1076 if (
options.reductionStrategy ==
1079 tilingResult->
tiledOps, initTensors, loops, loopResults,
1080 tilingResult->generatedSlices, {}};
1084 FailureOr<MergeResult> mergeResult =
1086 if (failed(mergeResult)) {
1088 op,
"Failed to merge partial results from tiling");
1093 mergeResult->replacements,
1094 tilingResult->generatedSlices,
1095 mergeResult->mergeOps};
1098 FailureOr<scf::SCFTilingResult>
1100 PartialReductionOpInterface op,
1104 options.setReductionTilingStrategy(
1106 PartialReductionOuterReduction);
1107 options.setTileSizes(tileSize);
1121 static std::tuple<OpResult, std::optional<OpOperand *>>
1124 std::optional<OpOperand *> destinationIterArg;
1125 assert(!loops.empty() &&
"expected non empty loops container");
1126 auto loopIt = loops.rbegin();
1127 while (loopIt != loops.rend() && isa<BlockArgument>(source->
get())) {
1128 auto iterArg = cast<BlockArgument>(source->
get());
1129 auto loop = *loopIt;
1130 if (iterArg.getOwner()->getParentOp() != loop)
1132 source = loop.getTiedLoopInit(iterArg);
1135 if (loopIt == loops.rend())
1136 destinationIterArg = source;
1137 return {dyn_cast<OpResult>(source->
get()), destinationIterArg};
1142 std::optional<scf::SCFFuseProducerOfSliceResult>
1144 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1148 auto [fusableProducer, destinationInitArg] =
1151 if (!fusableProducer)
1152 return std::nullopt;
1153 unsigned resultNumber = fusableProducer.getResultNumber();
1161 Operation *fusableProducerOp = fusableProducer.getOwner();
1162 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1164 rewriter, fusableProducerOp->
getLoc(), fusableProducerOp,
1165 origDestinationTensors)))
1166 return std::nullopt;
1168 clonedOpDestinationTensors = origDestinationTensors;
1169 if (destinationInitArg &&
1170 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1174 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1178 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1183 llvm::to_vector(candidateSliceOp->getOperands());
1184 candidateSliceOpOperands[0] = clonedProducerOp->
getResult(resultNumber);
1185 tensor::ExtractSliceOp clonedCandidateSliceOp =
1187 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1190 FailureOr<TilingResult> tileAndFuseResult =
1192 rewriter, clonedCandidateSliceOp,
1193 clonedProducerOp->
getResult(resultNumber));
1194 if (failed(tileAndFuseResult))
1195 return std::nullopt;
1199 tileAndFuseResult->tiledValues[0]);
1200 rewriter.
eraseOp(clonedCandidateSliceOp);
1201 rewriter.
eraseOp(clonedProducerOp);
1246 if (destinationInitArg &&
1247 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1249 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1250 .set(origDestinationTensors[resultNumber]);
1253 fusableProducer, tileAndFuseResult->tiledValues[0],
1254 tileAndFuseResult->
tiledOps, tileAndFuseResult->generatedSlices};
1259 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1267 *tiledOwner = fusedProducerInfo.
tiledOps[0];
1272 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1274 : llvm::to_vector(yieldResultNumber);
1276 for (
const auto &resultNumber : initNumberList) {
1278 rewriter, loc, originalOwner->
getResult(resultNumber));
1279 if (succeeded(initValue)) {
1280 initValueList.push_back(initValue.value());
1296 sliceSizes = sliceOp.getMixedSizes();
1299 if (!llvm::all_of(sliceOp.getMixedStrides(),
isOneInteger))
1302 unsigned sliceResultNumber =
1305 auto tilableOp = cast<TilingInterface>(originalOwner);
1309 if (tilableOp->getNumResults() > 1 &&
1310 failed(tilableOp.getIterationDomainTileFromResultTile(
1311 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1312 iterDomainOffset, iterDomainSizes))) {
1327 for (
const auto &resultNumber : initNumberList) {
1328 if (resultNumber == sliceResultNumber) {
1329 offsetList.push_back(sliceOffset);
1330 sizesList.push_back(sliceSizes);
1332 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1335 if (failed(tilableOp.getResultTilePosition(
1336 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1340 offsetList.push_back(offset);
1341 sizesList.push_back(sizes);
1347 if (
auto tiledDestStyleOp =
1348 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1350 for (
const auto &&[index, newRegionArg] :
1352 auto destSlice = rewriter.
create<tensor::ExtractSliceOp>(
1353 loc, newRegionArg, offsetList[index], sizesList[index],
1356 generatedSlices.push_back(destSlice);
1357 unsigned resultNumber = initNumberList[index];
1359 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1368 for (
const auto &&[index, resultNumber] :
llvm::enumerate(initNumberList)) {
1369 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1370 tiledOffset.emplace_back(offsetList[index]);
1371 tiledSizes.emplace_back(sizesList[index]);
1377 newYieldValuesFn))) {
1380 return generatedSlices;
1394 explicit SliceTrackingListener(
1395 std::optional<FrozenRewritePatternSet>
patterns);
1396 SliceTrackingListener() =
default;
1405 void notifyOperationInserted(
Operation *op,
1412 void notifyOperationErased(
Operation *op)
override;
1419 std::deque<tensor::ExtractSliceOp> worklist;
1424 std::optional<FrozenRewritePatternSet>
patterns = std::nullopt;
1427 SliceTrackingListener::SliceTrackingListener(
1428 std::optional<FrozenRewritePatternSet> p) {
1435 if (
auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1436 worklist.push_back(slice);
1448 void SliceTrackingListener::notifyOperationInserted(
1450 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1453 worklist.push_back(slice);
1459 void SliceTrackingListener::removeOp(
Operation *op) {
1460 if (!isa<tensor::ExtractSliceOp>(op))
1462 auto iter = worklist.begin();
1463 while (iter != worklist.end()) {
1468 if (iter == worklist.end())
1471 worklist.erase(iter);
1474 void SliceTrackingListener::notifyOperationErased(
Operation *op) {
1478 void SliceTrackingListener::notifyOperationReplaced(
Operation *op,
1494 : ForwardingListener(listener), replacements(replacements) {}
1496 void updateReplacementValues(
ValueRange origValues,
1500 for (
auto &[key, val] : replacements) {
1501 for (
auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1510 ForwardingListener::notifyOperationReplaced(op, newOp);
1515 ForwardingListener::notifyOperationReplaced(op, values);
1516 updateReplacementValues(op->
getResults(), values);
1526 FailureOr<scf::SCFTileAndFuseResult>
1532 if (!consumer->getNumResults()) {
1534 consumer,
"invalid pattern for op with no results");
1540 FailureOr<scf::SCFTilingResult> tilingResult =
1543 if (failed(tilingResult))
1545 tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1548 for (
auto [origVal, replacement] :
1549 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1550 replacements[origVal] = replacement;
1554 auto &loops = tilingResult->loops;
1555 if (loops.empty()) {
1564 auto resetListener =
1565 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
1566 ReplacementListener replaceListener(replacements, previousListener);
1576 struct WorklistItem {
1577 tensor::ExtractSliceOp candidateSlice;
1581 SliceTrackingListener sliceTracker =
1582 SliceTrackingListener(
options.cleanupPatterns);
1585 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1589 while (!sliceTracker.worklist.empty()) {
1590 auto candidateSlice = sliceTracker.worklist.front();
1591 sliceTracker.worklist.pop_front();
1593 auto [fusableProducer, destinationInitArg] =
1596 if (!fusableProducer)
1599 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1600 options.fusionControlFn(candidateSlice, fusableProducer,
1601 destinationInitArg.has_value());
1602 if (!controlFnResult)
1605 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1610 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1618 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1623 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1624 FailureOr<SmallVector<Operation *>> newSlices =
1626 worklistItem.candidateSlice,
1627 fusedResult.value(), loops);
1628 if (failed(newSlices)) {
1630 fusableProducerOp,
"failed to replacement value for this "
1631 "operation from within the tiled loop");
1633 worklistCandidates.append(newSlices.value());
1634 for (
auto [index, result] :
1636 replacements[result] = loops.front()->getResult(
1637 loops.front()->getNumResults() -
1642 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1643 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1644 tiledAndFusedOps.insert(tiledAndFusedOp);
1647 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1662 static LogicalResult
1664 Value result = candidateSliceOp.getResult();
1666 if (!llvm::hasSingleElement(uses)) {
1667 LLVM_DEBUG(llvm::dbgs() <<
"Too many uses of the candidate slice op\n");
1670 OpOperand &operandUse = (*uses.begin());
1672 if (!isa<scf::YieldOp>(userOp)) {
1673 LLVM_DEBUG(llvm::dbgs()
1674 <<
"Expected scf.yield to be the only user, but got -> "
1679 LLVM_DEBUG(llvm::dbgs() <<
"Expected tensor.insert_slice and scf.yield to "
1680 "be in the same block\n");
1689 if (!isa<LoopLikeOpInterface>(loopOp))
1708 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1711 if (loopOp->
getBlock() != userOp->getBlock())
1715 firstUserOfLoop = userOp;
1717 return firstUserOfLoop;
1758 static FailureOr<llvm::SetVector<Operation *>>
1760 bool reorderOperations) {
1762 if (failed(firstUserOfLoop))
1768 options.omitBlockArguments =
true;
1769 bool includeLoopOp =
false;
1772 includeLoopOp =
true;
1782 assert(result.succeeded() &&
"expected a backward slice");
1786 if (!slice.empty()) {
1796 if (includeLoopOp || !reorderOperations)
1808 unsigned resultNumber) {
1809 if (!isa<LoopLikeOpInterface>(loopOp))
1814 Operation *consumerOp = opOperand.getOwner();
1816 if (!isa<TilingInterface>(consumerOp) ||
1817 !isa<DestinationStyleOpInterface>(consumerOp)) {
1824 if (loopBlock != consumerOp->
getBlock())
1831 FailureOr<llvm::SetVector<Operation *>> slice =
1837 if (!slice->empty()) {
1840 assert(succeeded(firstUserOfLoop) &&
"First user of loop is not found");
1841 for (
auto op : *slice) {
1865 assert(!loops.empty() &&
"unexpected empty loop nest");
1866 if (loops.size() == 1) {
1867 return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1869 for (
auto [outerLoop, innerLoop] :
1870 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1871 auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1872 auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1873 if (!outerFor || !innerFor) {
1876 auto outerBBArgs = outerFor.getRegionIterArgs();
1877 auto innerIterArgs = innerFor.getInitArgs();
1878 if (outerBBArgs.size() != innerIterArgs.size()) {
1882 for (
auto [outerBBArg, innerIterArg] :
1883 llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1884 if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1885 innerIterArg != outerBBArg) {
1891 cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1892 ValueRange innerResults = innerFor.getResults();
1893 if (outerYields.size() != innerResults.size()) {
1896 for (
auto [outerYield, innerResult] :
1897 llvm::zip_equal(outerYields, innerResults)) {
1898 if (!llvm::hasSingleElement(innerResult.getUses()) ||
1899 outerYield != innerResult) {
1913 static FailureOr<OpOperand *>
1915 tensor::InsertSliceOp candidateSliceOp,
1917 assert(!loops.empty() &&
"unexpected loops to be empty");
1920 if (containingOp != loops.back()) {
1923 "expected slice to be within body of inner-most loop");
1929 candidateSliceOp,
"expected passed loops to be perfectly nested.");
1934 Value sliceResult = candidateSliceOp.getResult();
1940 scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
1947 static FailureOr<OpOperand *>
1949 tensor::ParallelInsertSliceOp candidateSliceOp,
1951 assert(!loops.empty() &&
"unexpected loops to be empty");
1953 if (loops.size() != 1) {
1955 candidateSliceOp,
"expected single surrounding scf.forall");
1957 auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
1960 candidateSliceOp,
"expected single surrounding scf.forall");
1964 Value sliceDest = candidateSliceOp.getDest();
1965 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1968 if (iterArg.getOwner()->getParentOp() != forallOp)
1971 unsigned resultNumber =
1972 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1980 static FailureOr<OpOperand *>
1983 assert(!loops.empty() &&
"unexpected empty loops");
1984 if (
auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1986 }
else if (
auto parallelInsertSlice =
1987 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1996 FailureOr<scf::SCFFuseConsumerOfSliceResult>
2002 if (loops.empty()) {
2004 "cannot call tile and fuse consumer with an empty loop nest");
2006 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2012 FailureOr<OpOperand *> maybeConsumerOpOperand =
2014 if (failed(maybeConsumerOpOperand)) {
2016 "could not fetch consumer to fuse");
2018 OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
2021 unsigned resultNumber = 0;
2022 if (
auto producerResult = dyn_cast<OpResult>(consumerOpOperand->
get())) {
2023 resultNumber = producerResult.getResultNumber();
2026 consumerOp,
"consumer op's operand doesn't seem to be an OpResult");
2029 LoopLikeOpInterface outerMostLoop = loops.front();
2030 LoopLikeOpInterface innerMostLoop = loops.back();
2035 outerMostLoop,
"the first user of loop should not dominate any define "
2036 "of consumer operand(s)");
2042 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2045 "consumer op is not DPS operation");
2047 llvm::map_to_vector(dstOp.getDpsInits(), [](
Value v) { return v; });
2048 if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
2051 "consumer op taking the result of scf.for as init is not supported");
2055 Location loc = outerMostLoop->getLoc();
2060 if (failed(firstUserOfLoop)) {
2062 outerMostLoop,
"could not find the first user of outer most loop");
2064 rewriter.
moveOpBefore(outerMostLoop, *firstUserOfLoop);
2070 tensor::InsertSliceOp clonedInsertSliceOp;
2072 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2073 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2075 clonedInsertSliceOp = rewriter.
create<tensor::InsertSliceOp>(
2076 loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2077 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2080 clonedInsertSliceOp =
2081 cast<tensor::InsertSliceOp>(rewriter.
clone(*candidateSliceOp));
2085 auto clonedConsumerOp = cast<TilingInterface>(rewriter.
clone(*consumerOp));
2089 OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2091 operandToReplace.
set(clonedInsertSliceOp.getResult());
2097 cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
2098 FailureOr<TilingResult> tileAndFuseResult =
2100 rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2101 if (failed(tileAndFuseResult)) {
2104 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2106 clonedInsertSliceOp.getSource());
2125 candidateSliceOp,
"containingOp's result yield with stride");
2135 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2136 rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2137 iterDomainSizes))) {
2140 "can't get iter domain position from input position");
2146 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2148 totalNumResultsOfConsumer);
2150 totalNumResultsOfConsumer);
2151 for (
auto [idx, v] :
llvm::enumerate(tiledConsumerOp->getResults())) {
2152 if (failed(tiledConsumerOp.getResultTilePosition(
2153 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2154 resultOffsets[idx], resultSizes[idx]))) {
2157 "can't get result domain position from iter domain position");
2163 if (
auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2164 tiledConsumerOp.getOperation())) {
2166 for (
const auto &&[index, newRegionArg] :
2168 auto destSlice = rewriter.
create<tensor::ExtractSliceOp>(
2169 loc, newRegionArg, resultOffsets[index], resultSizes[index],
2174 auto dstNumber = index;
2176 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2185 for (
const auto &&[index, result] :
2187 tiledResult.push_back(result);
2188 tiledOffset.emplace_back(resultOffsets[index]);
2189 tiledSizes.emplace_back(resultSizes[index]);
2195 newYieldValuesFn))) {
2197 "unable to add new inits to nest loop");
2203 for (
auto &&[oldResult, newResult] :
2205 loops.front()->getResults().take_back(newInits.size()))) {
2210 rewriter.
eraseOp(clonedConsumerOp);
2214 &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2215 tileAndFuseResult->tiledOps};
2222 FailureOr<SmallVector<scf::ForOp>>
2224 TilingInterface op) {
2226 if (op->getNumResults() > 0) {
2228 op,
"unable to lower to loops operations with return values");
2235 for (
auto loopRange : domain) {
2242 auto loop = rewriter.
create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2244 loops.push_back(loop);
2245 ivs.push_back(loop.getInductionVar());
2248 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
static bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check that the loop is perfectly nested.
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * (numThreads - 1) is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specified in options.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ValueRange partialResults, const scf::SCFTilingOptions &options)
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, const scf::SCFTilingOptions &options)
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
Base type for affine expression.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents a saved insertion point.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterf...
Include the generated interface declarations.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Container for the result of merge operation of tiling.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Container for result values of tiling.
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
SmallVector< Operation * > tiledOps
Control function to check if a slice needs to be fused or not. The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
ReductionTilingStrategy
Specify how reduction dimensions should be tiled.
@ PartialReductionOuterReduction
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.