31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
36 #define DEBUG_TYPE "tile-using-interface"
43 auto tileSizes = llvm::to_vector(ts);
52 assert(!numThreadsComputationFunction &&
"num tiles already set");
53 auto numThreads = llvm::to_vector(nt);
64 size_t iterationDomainSize) {
66 if (filledVector.size() < iterationDomainSize) {
67 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
68 filledVector.append(range.begin(), range.end());
70 if (filledVector.size() > iterationDomainSize)
71 filledVector.resize(iterationDomainSize);
84 if (
options.numThreadsComputationFunction &&
87 loc,
"number of threads can only by specified when loop type is "
88 "set to use `scf.forall`");
92 if (!
options.interchangeVector.empty()) {
95 loc,
"invalid interchange vector, not a permutation of the entire "
110 size_t numLoops = iterationDomain.size();
113 if (
options.numThreadsComputationFunction) {
114 numThreads =
options.numThreadsComputationFunction(rewriter, op);
115 numThreads.resize(numLoops, zero);
118 if (
options.tileSizeComputationFunction) {
119 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
120 tileSizes.resize(numLoops, zero);
121 return {tileSizes, numThreads};
133 tileSizes.resize(numLoops, zero);
134 for (
auto [index, range, nt] :
140 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
142 tileSizes.resize(numLoops, zero);
143 return {tileSizes, numThreads};
150 assert(
options.tileSizeComputationFunction &&
151 "expected tile sizes to be specified");
152 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
153 tileSizes.resize(numLoops, zero);
155 return {tileSizes, numThreads};
162 auto iterators = op.getLoopIteratorTypes();
163 assert(iterators.size() == tileSizes.size() &&
164 "expected as many tile size values as number of loops");
165 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
166 "when specified, expected number of threads to use for each loop");
168 for (
auto [index, iterator, tileSize] :
172 if (!numThreads.empty()) {
173 if (std::optional<int64_t> constNumThreads =
175 if (constNumThreads.value() > 1 &&
176 iterator != utils::IteratorType::parallel) {
177 op.emitWarning() <<
"tiling is not thread safe at axis #" << index;
184 if (constTileSize.value() > 0 &&
185 iterator != utils::IteratorType::parallel) {
186 op.emitWarning() <<
"tiling is not thread safe at axis #" << index;
203 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
212 if (ts && ts.value() == 1)
239 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
241 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
253 int materializedLoopNum = 0;
255 if (!numThreads.empty()) {
260 offsetExpr = d0 + d1 * s0;
261 residualTileSizeExpr = s1 - (d0 + d1 * s0);
263 for (
auto [nt, tileSize, loopRange] :
264 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
269 offsets.push_back(loopRange.offset);
270 sizes.push_back(loopRange.size);
274 Value iv = ivs[materializedLoopNum++];
276 rewriter, loc, offsetExpr,
279 rewriter, loc, residualTileSizeExpr,
280 {loopRange.offset, nt, tileSize, loopRange.size});
286 {offset, loopRange.size});
290 {sizeMinusOffsetPerThread, tileSize});
306 rewriter, loc, maxMap, {rewriter.
getIndexAttr(0), size});
309 offsets.push_back(offset);
310 sizes.push_back(size);
312 return {offsets, sizes};
314 for (
auto [tileSize, loopRange] :
315 llvm::zip_equal(tileSizes, iterationDomain)) {
320 offsets.push_back(loopRange.offset);
321 sizes.push_back(loopRange.size);
325 Value iv = ivs[materializedLoopNum++];
327 offsets.push_back(offset);
330 sizes.push_back(size);
332 return {offsets, sizes};
342 for (
auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
346 lbs.push_back(loopRange.offset);
347 ubs.push_back(loopRange.size);
348 steps.push_back(tileSize);
350 return {lbs, ubs, steps};
380 if (newDestArgs.empty())
382 if (
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
383 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
401 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
402 assert(loopRanges.size() == tileSizes.size() &&
403 "expected as many tile sizes as loop ranges");
407 std::tie(lbs, ubs, steps) =
417 for (
auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
419 rewriter.
create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
422 loops.push_back(loop);
423 ivs.push_back(loop.getInductionVar());
425 destinationTensors = loop.getRegionIterArgs();
430 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
431 tiledResults, resultOffsets, resultSizes))) {
433 loc,
"failed to generate inner tile loop body");
438 assert(tiledResults.size() == destinationTensors.size() &&
439 "Number of results of body should be equal to number of iter args");
443 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
444 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
448 auto insertSlice = rewriter.
create<tensor::InsertSliceOp>(
449 loc, tiledValue, destinationTensor, resultOffset, resultSize,
451 yieldedValues.push_back(insertSlice);
453 rewriter.
create<scf::YieldOp>(loc, yieldedValues);
456 for (
auto [outerLoop, innerLoop] :
460 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
461 rewriter.
create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
482 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
483 assert(loopRanges.size() == tileSizes.size() &&
484 "expected as many tile sizes as loop ranges");
487 std::optional<ArrayAttr> mappingAttr;
488 if (!mappingVector.empty())
491 scf::ForallOp forallOp;
492 bool useNumThreads = !numThreads.empty();
497 for (
auto nt : numThreads) {
500 nonZeroNumThreads.push_back(nt);
502 forallOp = rewriter.
create<scf::ForallOp>(loc, nonZeroNumThreads,
503 destinationTensors, mappingAttr);
506 std::tie(lbs, ubs, steps) =
508 forallOp = rewriter.
create<scf::ForallOp>(loc, lbs, ubs, steps,
509 destinationTensors, mappingAttr);
511 loops.push_back(forallOp);
514 destinationTensors = forallOp.getRegionOutArgs();
518 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
519 destinationTensors, tiledResults, resultOffsets,
524 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
525 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
530 rewriter.
create<tensor::ParallelInsertSliceOp>(
531 loc, tiledValue, destinationTensor, resultOffset, resultSize,
557 return tiledBodyFn(rewriter, loc,
ValueRange{}, destinationTensors,
558 tiledResults, resultOffsets, resultSizes);
562 destinationTensors, tiledBodyFn, loops);
566 rewriter, loc, loopRanges, tileSizes, numThreads,
options.mappingVector,
567 destinationTensors, tiledBodyFn, loops);
572 static FailureOr<SmallVector<Value>>
578 switch (
options.reductionStrategy) {
585 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
588 op,
"PartialReductionOuterReduction tiling strategy is only supported"
589 "for operations implementing PartialReductionOpInterface");
595 for (
auto [idx, iteratorType] :
597 if (iteratorType == utils::IteratorType::reduction)
598 reductionDims.push_back(idx);
600 return redOp.generateInitialTensorForPartialReduction(
601 rewriter, loc, tileSizes, reductionDims);
605 "unhandled reduction tiling strategy");
609 static FailureOr<TilingResult>
614 switch (
options.reductionStrategy) {
616 return op.getTiledImplementation(rewriter, offsets, sizes);
619 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
622 op,
"PartialReductionOuterReduction tiling strategy is only "
623 "supported for operations "
624 "implementing PartialReductionOpInterface");
630 for (
auto [idx, iteratorType] :
632 if (iteratorType == utils::IteratorType::reduction)
633 reductionDims.push_back(idx);
635 return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
636 offsets, sizes, reductionDims);
640 "unhandled reduction tiling strategy");
652 switch (
options.reductionStrategy) {
654 return op.getResultTilePosition(rewriter, index, offsets, sizes,
655 resultOffset, resultSize);
658 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
661 op,
"PartialReductionOuterReduction tiling strategy is only supported"
662 "for operations implementing PartialReductionOpInterface");
668 for (
auto [idx, iteratorType] :
670 if (iteratorType == utils::IteratorType::reduction)
671 reductionDims.push_back(idx);
673 return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
674 resultOffset, resultSize,
679 "unhandled reduction tiling strategy");
683 static FailureOr<MergeResult>
687 switch (
options.reductionStrategy) {
693 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
696 op,
"PartialReductionOuterReduction tiling strategy is only "
697 "supported for operations "
698 "implementing PartialReductionOpInterface");
704 for (
auto [idx, iteratorType] :
706 if (iteratorType == utils::IteratorType::reduction)
707 reductionDims.push_back(idx);
709 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
714 "unhandled reduction tiling strategy");
725 template <
typename LoopType>
726 FailureOr<LoopLikeOpInterface>
735 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
742 auto inits = llvm::to_vector(loopOp.getInitArgs());
743 inits.append(newInitOperands.begin(), newInitOperands.end());
744 auto newLoop = rewriter.
create<scf::ForOp>(
745 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
749 Block *loopBody = loopOp.getBody();
750 Block *newLoopBody = newLoop.getBody();
752 loopBody, newLoopBody,
753 newLoopBody->
getArguments().take_front(loopBody->getNumArguments()));
755 auto yieldOp = cast<scf::YieldOp>(newLoopBody->
getTerminator());
761 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
762 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
763 newRegionIterArgs, tiledValues, resultOffsets,
770 for (
auto [tiledValue, regionIterArg, resultOffset, resultSize] :
771 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
775 Value insert = rewriter.
create<tensor::InsertSliceOp>(
776 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
778 newYieldValues.push_back(insert);
783 newLoop->getResults().take_front(loopOp.getNumResults()));
784 return cast<LoopLikeOpInterface>(newLoop.getOperation());
789 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
795 auto inits = llvm::to_vector(loopOp.getOutputs());
796 inits.append(newInitOperands.begin(), newInitOperands.end());
797 auto newLoop = rewriter.
create<scf::ForallOp>(
798 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
799 loopOp.getMixedStep(), inits, loopOp.getMapping(),
803 Block *loopBody = loopOp.getBody();
804 Block *newLoopBody = newLoop.getBody();
806 loopBody, newLoopBody,
807 newLoopBody->
getArguments().take_front(loopBody->getNumArguments()));
809 auto terminator = cast<scf::InParallelOp>(newLoopBody->
getTerminator());
814 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
815 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
816 regionIterArgs, tiledValues, resultOffsets,
820 "failed to get yielded tiled values");
826 for (
auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
827 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
830 rewriter.
create<tensor::ParallelInsertSliceOp>(
831 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
836 newLoop->getResults().take_front(loopOp.getNumResults()));
837 return cast<LoopLikeOpInterface>(newLoop.getOperation());
847 loopLikeOp.getOperation())
848 .Case<scf::ForOp, scf::ForallOp>(
849 [&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
851 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
853 .Default([&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
872 for (
auto &loop : loops.drop_back()) {
876 auto forLoop = cast<scf::ForOp>(loop.getOperation());
880 newInits.append(newInitValues.begin(), newInitValues.end());
881 auto newLoop = rewriter.
create<scf::ForOp>(
882 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
883 forLoop.getStep(), newInits,
888 sourceBlockArgs.push_back(newLoop.getInductionVar());
889 auto newRegionIterArgs = newLoop.getRegionIterArgs();
890 sourceBlockArgs.append(
891 newRegionIterArgs.begin(),
892 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
893 rewriter.
mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
895 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
897 ivs.push_back(newLoop.getInductionVar());
898 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
902 LoopLikeOpInterface innerMostLoop = loops.back();
903 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
905 getNewTiledYieldsFn);
907 if (failed(newInnerMostLoop))
908 return innerMostLoop.emitOpError(
"failed to return additional yields");
909 loops.back() = newInnerMostLoop.value();
913 for (
auto [outerLoop, innerLoop] :
914 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
916 auto outerForLoop = cast<scf::ForOp>(outerLoop);
917 auto outerLoopYield =
918 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
920 llvm::to_vector(outerLoopYield.getOperands());
922 innerLoop->getResults().take_back(newInitValues.size());
923 newYields.append(additionalYields.begin(), additionalYields.end());
932 FailureOr<scf::SCFTilingResult>
947 std::tie(tileSizes, numThreads) =
959 if (!
options.interchangeVector.empty()) {
961 iterationDomain.size());
963 "expected interchange vector to be a permutation");
967 if (!numThreads.empty())
971 FailureOr<TilingResult> tilingResult;
983 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
987 if (!interchangeVector.empty()) {
996 auto clonedOp = cast<TilingInterface>(
1003 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1013 if (failed(tilingResult)) {
1015 return op.emitOpError(
"faild to tile operation");
1023 for (
auto [index, tiledValue] :
1025 tiledResults.push_back(tiledValue);
1028 sizes, resultOffset, resultSize,
1030 for (
auto op : tilingResult->tiledOps) {
1034 op,
"failed to get slice of result produced");
1036 resultOffsets.emplace_back(std::move(resultOffset));
1037 resultSizes.emplace_back(std::move(resultSize));
1044 FailureOr<SmallVector<Value>> maybeInits =
1046 if (failed(maybeInits)) {
1048 op,
"unable to create initial tensors for tiling");
1055 tileSizes, numThreads, initTensors,
1056 innerYieldTiledValuesFn, loops)))
1057 return op.emitOpError(
"failed to generate tiling loops");
1058 assert(succeeded(tilingResult) &&
1059 "expected tiling result to be computed after loop generation");
1062 if (loops.empty()) {
1065 partialResults = tilingResult->tiledValues;
1067 partialResults = llvm::map_to_vector(loops.front()->getResults(),
1071 FailureOr<MergeResult> mergeResult =
1073 if (failed(mergeResult)) {
1075 op,
"Failed to merge partial results from tiling");
1079 mergeResult.value(),
1080 tilingResult->generatedSlices};
1083 FailureOr<scf::SCFTilingResult>
1085 PartialReductionOpInterface op,
1090 PartialReductionOuterReduction);
1091 options.setTileSizes(tileSizes);
1093 TilingInterface tilingInterfaceOp =
1094 dyn_cast<TilingInterface>(op.getOperation());
1095 if (!tilingInterfaceOp) {
1098 "Operation implementing PartialReductionOpInterface should implement "
1115 static std::tuple<OpResult, std::optional<OpOperand *>>
1118 std::optional<OpOperand *> destinationIterArg;
1119 assert(!loops.empty() &&
"expected non empty loops container");
1120 auto loopIt = loops.rbegin();
1121 while (loopIt != loops.rend() && isa<BlockArgument>(source->
get())) {
1122 auto iterArg = cast<BlockArgument>(source->
get());
1123 auto loop = *loopIt;
1124 if (iterArg.getOwner()->getParentOp() != loop)
1126 source = loop.getTiedLoopInit(iterArg);
1129 if (loopIt == loops.rend())
1130 destinationIterArg = source;
1131 return {dyn_cast<OpResult>(source->
get()), destinationIterArg};
1136 std::optional<scf::SCFFuseProducerOfSliceResult>
1138 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1142 auto [fusableProducer, destinationInitArg] =
1145 if (!fusableProducer)
1146 return std::nullopt;
1147 unsigned resultNumber = fusableProducer.getResultNumber();
1155 Operation *fusableProducerOp = fusableProducer.getOwner();
1156 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1158 rewriter, fusableProducerOp->
getLoc(), fusableProducerOp,
1159 origDestinationTensors)))
1160 return std::nullopt;
1162 clonedOpDestinationTensors = origDestinationTensors;
1163 if (destinationInitArg &&
1164 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1168 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1172 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1177 llvm::to_vector(candidateSliceOp->getOperands());
1178 candidateSliceOpOperands[0] = clonedProducerOp->
getResult(resultNumber);
1179 tensor::ExtractSliceOp clonedCandidateSliceOp =
1181 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1184 FailureOr<TilingResult> tileAndFuseResult =
1186 rewriter, clonedCandidateSliceOp,
1187 clonedProducerOp->
getResult(resultNumber));
1188 if (failed(tileAndFuseResult))
1189 return std::nullopt;
1193 tileAndFuseResult->tiledValues[0]);
1194 rewriter.
eraseOp(clonedCandidateSliceOp);
1195 rewriter.
eraseOp(clonedProducerOp);
1240 if (destinationInitArg &&
1241 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1243 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1244 .set(origDestinationTensors[resultNumber]);
1247 fusableProducer, tileAndFuseResult->tiledValues[0],
1248 tileAndFuseResult->
tiledOps, tileAndFuseResult->generatedSlices};
1253 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1261 *tiledOwner = fusedProducerInfo.
tiledOps[0];
1266 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1268 : llvm::to_vector(yieldResultNumber);
1270 for (
const auto &resultNumber : initNumberList) {
1272 rewriter, loc, originalOwner->
getResult(resultNumber));
1273 if (succeeded(initValue)) {
1274 initValueList.push_back(initValue.value());
1290 sliceSizes = sliceOp.getMixedSizes();
1293 if (llvm::any_of(sliceOp.getMixedStrides(), [](
OpFoldResult ofr) {
1294 return !isConstantIntValue(ofr, 1);
1298 unsigned sliceResultNumber =
1301 auto tilableOp = cast<TilingInterface>(originalOwner);
1305 if (tilableOp->getNumResults() > 1 &&
1306 failed(tilableOp.getIterationDomainTileFromResultTile(
1307 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1308 iterDomainOffset, iterDomainSizes))) {
1323 for (
const auto &resultNumber : initNumberList) {
1324 if (resultNumber == sliceResultNumber) {
1325 offsetList.push_back(sliceOffset);
1326 sizesList.push_back(sliceSizes);
1328 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1331 if (failed(tilableOp.getResultTilePosition(
1332 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1336 offsetList.push_back(offset);
1337 sizesList.push_back(sizes);
1343 if (
auto tiledDestStyleOp =
1344 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1346 for (
const auto &&[index, newRegionArg] :
1348 auto destSlice = rewriter.
create<tensor::ExtractSliceOp>(
1349 loc, newRegionArg, offsetList[index], sizesList[index],
1352 generatedSlices.push_back(destSlice);
1353 unsigned resultNumber = initNumberList[index];
1355 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1364 for (
const auto &&[index, resultNumber] :
llvm::enumerate(initNumberList)) {
1365 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1366 tiledOffset.emplace_back(offsetList[index]);
1367 tiledSizes.emplace_back(sizesList[index]);
1373 newYieldValuesFn))) {
1376 return generatedSlices;
1390 explicit SliceTrackingListener(
1391 std::optional<FrozenRewritePatternSet>
patterns);
1392 SliceTrackingListener() =
default;
1401 void notifyOperationInserted(
Operation *op,
1408 void notifyOperationErased(
Operation *op)
override;
1415 std::deque<tensor::ExtractSliceOp> worklist;
1420 std::optional<FrozenRewritePatternSet>
patterns = std::nullopt;
1423 SliceTrackingListener::SliceTrackingListener(
1424 std::optional<FrozenRewritePatternSet> p) {
1431 if (
auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1432 worklist.push_back(slice);
1444 void SliceTrackingListener::notifyOperationInserted(
1446 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1449 worklist.push_back(slice);
1455 void SliceTrackingListener::removeOp(
Operation *op) {
1456 if (!isa<tensor::ExtractSliceOp>(op))
1458 auto iter = worklist.begin();
1459 while (iter != worklist.end()) {
1464 if (iter == worklist.end())
1467 worklist.erase(iter);
1470 void SliceTrackingListener::notifyOperationErased(
Operation *op) {
1474 void SliceTrackingListener::notifyOperationReplaced(
Operation *op,
1490 : ForwardingListener(listener), replacements(replacements) {}
1492 void updateReplacementValues(
ValueRange origValues,
1496 for (
auto &[key, val] : replacements) {
1497 for (
auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1506 ForwardingListener::notifyOperationReplaced(op, newOp);
1511 ForwardingListener::notifyOperationReplaced(op, values);
1512 updateReplacementValues(op->
getResults(), values);
1522 FailureOr<scf::SCFTileAndFuseResult>
1528 if (!consumer->getNumResults()) {
1530 consumer,
"invalid pattern for op with no results");
1536 FailureOr<scf::SCFTilingResult> tilingResult =
1539 if (failed(tilingResult))
1541 tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1544 for (
auto [origVal, replacement] : llvm::zip_equal(
1545 consumer->getResults(), tilingResult->mergeResult.replacements)) {
1546 replacements[origVal] = replacement;
1550 auto &loops = tilingResult->loops;
1551 if (loops.empty()) {
1560 auto resetListener =
1561 llvm::make_scope_exit([&]() { rewriter.
setListener(previousListener); });
1562 ReplacementListener replaceListener(replacements, previousListener);
1572 struct WorklistItem {
1573 tensor::ExtractSliceOp candidateSlice;
1577 SliceTrackingListener sliceTracker =
1578 SliceTrackingListener(
options.cleanupPatterns);
1581 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1585 while (!sliceTracker.worklist.empty()) {
1586 auto candidateSlice = sliceTracker.worklist.front();
1587 sliceTracker.worklist.pop_front();
1589 auto [fusableProducer, destinationInitArg] =
1592 if (!fusableProducer)
1595 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1596 options.fusionControlFn(candidateSlice, fusableProducer,
1597 destinationInitArg.has_value());
1598 if (!controlFnResult)
1601 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1606 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1614 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1619 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1620 FailureOr<SmallVector<Operation *>> newSlices =
1622 worklistItem.candidateSlice,
1623 fusedResult.value(), loops);
1624 if (failed(newSlices)) {
1626 fusableProducerOp,
"failed to replacement value for this "
1627 "operation from within the tiled loop");
1629 worklistCandidates.append(newSlices.value());
1630 for (
auto [index, result] :
1632 replacements[result] = loops.front()->getResult(
1633 loops.front()->getNumResults() -
1638 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1639 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1640 tiledAndFusedOps.insert(tiledAndFusedOp);
1643 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1658 static LogicalResult
1660 Value result = candidateSliceOp.getResult();
1662 if (!llvm::hasSingleElement(uses)) {
1663 LLVM_DEBUG(llvm::dbgs() <<
"Too many uses of the candidate slice op\n");
1666 OpOperand &operandUse = (*uses.begin());
1668 if (!isa<scf::YieldOp>(userOp)) {
1669 LLVM_DEBUG(llvm::dbgs()
1670 <<
"Expected scf.yield to be the only user, but got -> "
1675 LLVM_DEBUG(llvm::dbgs() <<
"Expected tensor.insert_slice and scf.yield to "
1676 "be in the same block\n");
1685 if (!isa<LoopLikeOpInterface>(loopOp))
1704 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1707 if (loopOp->
getBlock() != userOp->getBlock())
1711 firstUserOfLoop = userOp;
1713 return firstUserOfLoop;
1754 static FailureOr<llvm::SetVector<Operation *>>
1756 bool reorderOperations) {
1758 if (failed(firstUserOfLoop))
1764 options.omitBlockArguments =
true;
1765 bool includeLoopOp =
false;
1768 includeLoopOp =
true;
1780 if (!slice.empty()) {
1790 if (includeLoopOp || !reorderOperations)
1802 unsigned resultNumber) {
1803 if (!isa<LoopLikeOpInterface>(loopOp))
1808 Operation *consumerOp = opOperand.getOwner();
1810 if (!isa<TilingInterface>(consumerOp) ||
1811 !isa<DestinationStyleOpInterface>(consumerOp)) {
1818 if (loopBlock != consumerOp->
getBlock())
1825 FailureOr<llvm::SetVector<Operation *>> slice =
1831 if (!slice->empty()) {
1834 assert(succeeded(firstUserOfLoop) &&
"First user of loop is not found");
1835 for (
auto op : *slice) {
1859 assert(!loops.empty() &&
"unexpected empty loop nest");
1860 if (loops.size() == 1) {
1861 return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1863 for (
auto [outerLoop, innerLoop] :
1864 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1865 auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1866 auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1867 if (!outerFor || !innerFor) {
1870 auto outerBBArgs = outerFor.getRegionIterArgs();
1871 auto innerIterArgs = innerFor.getInitArgs();
1872 if (outerBBArgs.size() != innerIterArgs.size()) {
1876 for (
auto [outerBBArg, innerIterArg] :
1877 llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1878 if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1879 innerIterArg != outerBBArg) {
1885 cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1886 ValueRange innerResults = innerFor.getResults();
1887 if (outerYields.size() != innerResults.size()) {
1890 for (
auto [outerYield, innerResult] :
1891 llvm::zip_equal(outerYields, innerResults)) {
1892 if (!llvm::hasSingleElement(innerResult.getUses()) ||
1893 outerYield != innerResult) {
1907 static FailureOr<OpOperand *>
1909 tensor::InsertSliceOp candidateSliceOp,
1911 assert(!loops.empty() &&
"unexpected loops to be empty");
1914 if (containingOp != loops.back()) {
1917 "expected slice to be within body of inner-most loop");
1923 candidateSliceOp,
"expected passed loops to be perfectly nested.");
1928 Value sliceResult = candidateSliceOp.getResult();
1934 scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
1941 static FailureOr<OpOperand *>
1943 tensor::ParallelInsertSliceOp candidateSliceOp,
1945 assert(!loops.empty() &&
"unexpected loops to be empty");
1947 if (loops.size() != 1) {
1949 candidateSliceOp,
"expected single surrounding scf.forall");
1951 auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
1954 candidateSliceOp,
"expected single surrounding scf.forall");
1958 Value sliceDest = candidateSliceOp.getDest();
1959 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1962 if (iterArg.getOwner()->getParentOp() != forallOp)
1965 unsigned resultNumber =
1966 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1974 static FailureOr<OpOperand *>
1977 assert(!loops.empty() &&
"unexpected empty loops");
1978 if (
auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1980 }
else if (
auto parallelInsertSlice =
1981 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1990 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1996 if (loops.empty()) {
1998 "cannot call tile and fuse consumer with an empty loop nest");
2000 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2006 FailureOr<OpOperand *> maybeConsumerOpOperand =
2008 if (failed(maybeConsumerOpOperand)) {
2010 "could not fetch consumer to fuse");
2012 OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
2015 unsigned resultNumber = 0;
2016 if (
auto producerResult = dyn_cast<OpResult>(consumerOpOperand->
get())) {
2017 resultNumber = producerResult.getResultNumber();
2020 consumerOp,
"consumer op's operand doesn't seem to be an OpResult");
2023 LoopLikeOpInterface outerMostLoop = loops.front();
2024 LoopLikeOpInterface innerMostLoop = loops.back();
2029 outerMostLoop,
"the first user of loop should not dominate any define "
2030 "of consumer operand(s)");
2036 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2039 "consumer op is not DPS operation");
2041 llvm::map_to_vector(dstOp.getDpsInits(), [](
Value v) { return v; });
2042 if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
2045 "consumer op taking the result of scf.for as init is not supported");
2049 Location loc = outerMostLoop->getLoc();
2054 if (failed(firstUserOfLoop)) {
2056 outerMostLoop,
"could not find the first user of outer most loop");
2058 rewriter.
moveOpBefore(outerMostLoop, *firstUserOfLoop);
2064 tensor::InsertSliceOp clonedInsertSliceOp;
2066 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2067 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2069 clonedInsertSliceOp = rewriter.
create<tensor::InsertSliceOp>(
2070 loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2071 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2074 clonedInsertSliceOp =
2075 cast<tensor::InsertSliceOp>(rewriter.
clone(*candidateSliceOp));
2079 auto clonedConsumerOp = cast<TilingInterface>(rewriter.
clone(*consumerOp));
2083 OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2085 operandToReplace.
set(clonedInsertSliceOp.getResult());
2091 cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
2092 FailureOr<TilingResult> tileAndFuseResult =
2094 rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2095 if (failed(tileAndFuseResult)) {
2098 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2100 clonedInsertSliceOp.getSource());
2121 candidateSliceOp,
"containingOp's result yield with stride");
2131 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2132 rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2133 iterDomainSizes))) {
2136 "can't get iter domain position from input position");
2142 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2144 totalNumResultsOfConsumer);
2146 totalNumResultsOfConsumer);
2147 for (
auto [idx, v] :
llvm::enumerate(tiledConsumerOp->getResults())) {
2148 if (failed(tiledConsumerOp.getResultTilePosition(
2149 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2150 resultOffsets[idx], resultSizes[idx]))) {
2153 "can't get result domain position from iter domain position");
2159 if (
auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2160 tiledConsumerOp.getOperation())) {
2162 for (
const auto &&[index, newRegionArg] :
2164 auto destSlice = rewriter.
create<tensor::ExtractSliceOp>(
2165 loc, newRegionArg, resultOffsets[index], resultSizes[index],
2170 auto dstNumber = index;
2172 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2181 for (
const auto &&[index, result] :
2183 tiledResult.push_back(result);
2184 tiledOffset.emplace_back(resultOffsets[index]);
2185 tiledSizes.emplace_back(resultSizes[index]);
2191 newYieldValuesFn))) {
2193 "unable to add new inits to nest loop");
2199 for (
auto &&[oldResult, newResult] :
2201 loops.front()->getResults().take_back(newInits.size()))) {
2206 rewriter.
eraseOp(clonedConsumerOp);
2210 &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2211 tileAndFuseResult->tiledOps};
2218 FailureOr<SmallVector<scf::ForOp>>
2220 TilingInterface op) {
2222 if (op->getNumResults() > 0) {
2224 op,
"unable to lower to loops operations with return values");
2231 for (
auto loopRange : domain) {
2238 auto loop = rewriter.
create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2240 loops.push_back(loop);
2241 ivs.push_back(loop.getInductionVar());
2244 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
static bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check that the loop is perfectly nested.
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * (numThreads - 1) is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specified in options.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ValueRange partialResults, const scf::SCFTilingOptions &options)
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, const scf::SCFTilingOptions &options)
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
Base type for affine expression.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents a saved insertion point.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterf...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute with value 0.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Container for the result of merge operation of tiling.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A listener that forwards all notifications to another listener.
Container for result values of tiling.
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
SmallVector< Operation * > tiledOps
Control function to check if a slice needs to be fused or not. The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
ReductionTilingStrategy
Specify how reduction dimensions should be tiled.
@ PartialReductionOuterReduction
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.