31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
35 #define DEBUG_TYPE "tile-using-interface"
42 auto tileSizes = llvm::to_vector(ts);
51 assert(!numThreadsComputationFunction && "num tiles already set");
52 auto numThreads = llvm::to_vector(nt);
63 size_t iterationDomainSize) {
65 if (filledVector.size() < iterationDomainSize) {
66 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
67 filledVector.append(range.begin(), range.end());
69 if (filledVector.size() > iterationDomainSize)
70 filledVector.resize(iterationDomainSize);
83 if (options.numThreadsComputationFunction &&
86 loc, "number of threads can only be specified when loop type is "
87 "set to use `scf.forall`");
91 if (!options.interchangeVector.empty()) {
94 loc, "invalid interchange vector, not a permutation of the entire "
109 size_t numLoops = iterationDomain.size();
112 if (options.numThreadsComputationFunction) {
113 numThreads = options.numThreadsComputationFunction(rewriter, op);
114 numThreads.resize(numLoops, zero);
117 if (options.tileSizeComputationFunction) {
118 tileSizes = options.tileSizeComputationFunction(rewriter, op);
119 tileSizes.resize(numLoops, zero);
120 return {tileSizes, numThreads};
132 tileSizes.resize(numLoops, zero);
133 for (auto [index, range, nt] :
139 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
141 tileSizes.resize(numLoops, zero);
142 return {tileSizes, numThreads};
149 assert(options.tileSizeComputationFunction &&
150 "expected tile sizes to be specified");
151 tileSizes = options.tileSizeComputationFunction(rewriter, op);
152 tileSizes.resize(numLoops, zero);
154 return {tileSizes, numThreads};
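
// Illustrative sketch (editorial, not part of the upstream file): how a caller
// feeds this helper through the options API when distributing with
// `scf.forall`. The rewriter/context are assumed to be supplied by the caller;
// getAsIndexOpFoldResult is the existing helper from StaticValueUtils.h.
static void configureForallDistributionExample(scf::SCFTilingOptions &options,
                                                MLIRContext *ctx) {
  // With `scf.forall`, the number of threads may be set instead of tile sizes;
  // the per-loop tile size is then derived from the loop range and the thread
  // count by the helper above.
  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
  options.setNumThreads(getAsIndexOpFoldResult(ctx, {8, 4}));
}
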
161 auto iterators = op.getLoopIteratorTypes();
162 assert(iterators.size() == tileSizes.size() &&
163 "expected as many tile size values as number of loops");
164 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
165 "when specified, expected number of threads to use for each loop");
167 for (auto [index, iterator, tileSize] :
171 if (!numThreads.empty()) {
172 if (std::optional<int64_t> constNumThreads =
174 if (constNumThreads.value() > 1 &&
175 iterator != utils::IteratorType::parallel) {
176 op.emitWarning() << "tiling is not thread safe at axis #" << index;
183 if (constTileSize.value() > 0 &&
184 iterator != utils::IteratorType::parallel) {
185 op.emitWarning() << "tiling is not thread safe at axis #" << index;
202 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
211 if (ts && ts.value() == 1)
238 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
240 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
252 int materializedLoopNum = 0;
254 if (!numThreads.empty()) {
259 offsetExpr = d0 + d1 * s0;
260 residualTileSizeExpr = s1 - (d0 + d1 * s0);
262 for (auto [nt, tileSize, loopRange] :
263 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
268 offsets.push_back(loopRange.offset);
269 sizes.push_back(loopRange.size);
273 Value iv = ivs[materializedLoopNum++];
275 rewriter, loc, offsetExpr,
278 rewriter, loc, residualTileSizeExpr,
279 {loopRange.offset, nt, tileSize, loopRange.size});
285 {offset, loopRange.size});
289 {sizeMinusOffsetPerThread, tileSize});
305 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
308 offsets.push_back(offset);
309 sizes.push_back(size);
311 return {offsets, sizes};
313 for (auto [tileSize, loopRange] :
314 llvm::zip_equal(tileSizes, iterationDomain)) {
319 offsets.push_back(loopRange.offset);
320 sizes.push_back(loopRange.size);
324 Value iv = ivs[materializedLoopNum++];
326 offsets.push_back(offset);
329 sizes.push_back(size);
331 return {offsets, sizes};
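
// Note (editorial, illustrative): for a loop tiled by `ts`, the `size` pushed
// above is the bounded tile size min(ts, ub - iv) computed by
// getBoundedTileSize, where `ub` is the end of the loop range and `iv` the
// induction variable, so the trailing partial tile never runs out of bounds.
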
341 for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
345 lbs.push_back(loopRange.offset);
346 ubs.push_back(loopRange.size);
347 steps.push_back(tileSize);
349 return {lbs, ubs, steps};
379 if (newDestArgs.empty())
381 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
382 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
400 assert(!loopRanges.empty() && "unexpected empty loop ranges");
401 assert(loopRanges.size() == tileSizes.size() &&
402 "expected as many tile sizes as loop ranges");
406 std::tie(lbs, ubs, steps) =
416 for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
418 rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
421 loops.push_back(loop);
422 ivs.push_back(loop.getInductionVar());
424 destinationTensors = loop.getRegionIterArgs();
429 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
430 tiledResults, resultOffsets, resultSizes))) {
432 loc, "failed to generate inner tile loop body");
437 assert(tiledResults.size() == destinationTensors.size() &&
438 "Number of results of body should be equal to number of iter args");
442 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
443 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
447 auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
448 loc, tiledValue, destinationTensor, resultOffset, resultSize,
450 yieldedValues.push_back(insertSlice);
452 rewriter.create<scf::YieldOp>(loc, yieldedValues);
455 for (auto [outerLoop, innerLoop] :
459 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
460 rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
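
// For reference (editorial, illustrative only), the nest produced above for
// two tiled loops has this shape: one iter_args chain threads the destination
// tensors through every level, the innermost body inserts the tile and yields,
// and each outer loop simply yields the results of the loop it encloses:
//
//   %0 = scf.for %iv0 = %lb0 to %ub0 step %s0 iter_args(%d0 = %init) -> ... {
//     %1 = scf.for %iv1 = %lb1 to %ub1 step %s1 iter_args(%d1 = %d0) -> ... {
//       %tile = ...                      // tiled body from yieldTiledValuesFn
//       %ins = tensor.insert_slice %tile into %d1 [...] [...] [...]
//       scf.yield %ins
//     }
//     scf.yield %1
//   }
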
481 assert(!loopRanges.empty() && "unexpected empty loop ranges");
482 assert(loopRanges.size() == tileSizes.size() &&
483 "expected as many tile sizes as loop ranges");
486 sizes(loopRanges.size());
488 std::optional<ArrayAttr> mappingAttr;
489 if (!mappingVector.empty())
492 scf::ForallOp forallOp;
493 bool useNumThreads = !numThreads.empty();
498 for (auto nt : numThreads) {
501 nonZeroNumThreads.push_back(nt);
503 forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
504 destinationTensors, mappingAttr);
507 std::tie(lbs, ubs, steps) =
509 forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
510 destinationTensors, mappingAttr);
512 loops.push_back(forallOp);
515 destinationTensors = forallOp.getRegionOutArgs();
519 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
520 destinationTensors, tiledResults, resultOffsets,
525 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
526 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
531 rewriter.create<tensor::ParallelInsertSliceOp>(
532 loc, tiledValue, destinationTensor, resultOffset, resultSize,
558 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
559 tiledResults, resultOffsets, resultSizes);
563 destinationTensors, tiledBodyFn, loops);
567 rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
568 destinationTensors, tiledBodyFn, loops);
573 static FailureOr<SmallVector<Value>>
579 switch (options.reductionStrategy) {
586 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
589 op, "PartialReductionOuterReduction tiling strategy is only supported "
590 "for operations implementing PartialReductionOpInterface");
596 for (auto [idx, iteratorType] :
598 if (iteratorType == utils::IteratorType::reduction)
599 reductionDims.push_back(idx);
601 return redOp.generateInitialTensorForPartialReduction(
602 rewriter, loc, tileSizes, reductionDims);
606 "unhandled reduction tiling strategy");
610 static FailureOr<TilingResult>
615 switch (options.reductionStrategy) {
617 return op.getTiledImplementation(rewriter, offsets, sizes);
620 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
623 op, "PartialReductionOuterReduction tiling strategy is only "
624 "supported for operations "
625 "implementing PartialReductionOpInterface");
631 for (auto [idx, iteratorType] :
633 if (iteratorType == utils::IteratorType::reduction)
634 reductionDims.push_back(idx);
636 return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
637 offsets, sizes, reductionDims);
641 "unhandled reduction tiling strategy");
653 switch (options.reductionStrategy) {
655 return op.getResultTilePosition(rewriter, index, offsets, sizes,
656 resultOffset, resultSize);
664 for (size_t i = 0; i < offsets.size(); i++) {
665 resultSize.push_back(
671 "unhandled reduction tiling strategy");
676 static FailureOr<MergeResult>
680 switch (options.reductionStrategy) {
686 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
689 op, "PartialReductionOuterReduction tiling strategy is only "
690 "supported for operations "
691 "implementing PartialReductionOpInterface");
697 for (auto [idx, iteratorType] :
699 if (iteratorType == utils::IteratorType::reduction)
700 reductionDims.push_back(idx);
702 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
707 "unhandled reduction tiling strategy");
718 template <typename LoopType>
719 FailureOr<LoopLikeOpInterface>
728 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
735 auto inits = llvm::to_vector(loopOp.getInitArgs());
736 inits.append(newInitOperands.begin(), newInitOperands.end());
737 auto newLoop = rewriter.create<scf::ForOp>(
738 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
742 Block *loopBody = loopOp.getBody();
743 Block *newLoopBody = newLoop.getBody();
745 loopBody, newLoopBody,
746 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
748 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
754 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
755 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
756 newRegionIterArgs, tiledValues, resultOffsets,
763 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
764 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
768 Value insert = rewriter.create<tensor::InsertSliceOp>(
769 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
771 newYieldValues.push_back(insert);
776 newLoop->getResults().take_front(loopOp.getNumResults()));
777 return cast<LoopLikeOpInterface>(newLoop.getOperation());
782 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
788 auto inits = llvm::to_vector(loopOp.getOutputs());
789 inits.append(newInitOperands.begin(), newInitOperands.end());
790 auto newLoop = rewriter.create<scf::ForallOp>(
791 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
792 loopOp.getMixedStep(), inits, loopOp.getMapping(),
796 Block *loopBody = loopOp.getBody();
797 Block *newLoopBody = newLoop.getBody();
799 loopBody, newLoopBody,
800 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
802 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
807 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
808 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
809 regionIterArgs, tiledValues, resultOffsets,
813 "failed to get yielded tiled values");
819 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
820 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
823 rewriter.create<tensor::ParallelInsertSliceOp>(
824 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
829 newLoop->getResults().take_front(loopOp.getNumResults()));
830 return cast<LoopLikeOpInterface>(newLoop.getOperation());
840 loopLikeOp.getOperation())
841 .Case<scf::ForOp, scf::ForallOp>(
842 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
844 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
846 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
866 for (auto &loop : loops.drop_back()) {
870 auto forLoop = cast<scf::ForOp>(loop.getOperation());
874 newInits.append(newInitValues.begin(), newInitValues.end());
875 auto newLoop = rewriter.create<scf::ForOp>(
876 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
877 forLoop.getStep(), newInits,
882 sourceBlockArgs.push_back(newLoop.getInductionVar());
883 auto newRegionIterArgs = newLoop.getRegionIterArgs();
884 sourceBlockArgs.append(
885 newRegionIterArgs.begin(),
886 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
887 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
889 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
891 ivs.push_back(newLoop.getInductionVar());
892 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
896 LoopLikeOpInterface innerMostLoop = loops.back();
897 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
899 getNewTiledYieldsFn);
901 if (failed(newInnerMostLoop))
902 return innerMostLoop.emitOpError("failed to return additional yields");
903 loops.back() = newInnerMostLoop.value();
907 for (auto [outerLoop, innerLoop] :
908 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
910 auto outerForLoop = cast<scf::ForOp>(outerLoop);
911 auto outerLoopYield =
912 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
914 llvm::to_vector(outerLoopYield.getOperands());
916 innerLoop->getResults().take_back(newInitValues.size());
917 newYields.append(additionalYields.begin(), additionalYields.end());
926 FailureOr<scf::SCFTilingResult>
941 std::tie(tileSizes, numThreads) =
953 if (!options.interchangeVector.empty()) {
955 iterationDomain.size());
957 "expected interchange vector to be a permutation");
961 if (!numThreads.empty())
965 FailureOr<TilingResult> tilingResult;
977 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
981 if (!interchangeVector.empty()) {
990 auto clonedOp = cast<TilingInterface>(
997 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1007 if (failed(tilingResult)) {
1009 return op.emitOpError("failed to tile operation");
1017 for (auto [index, tiledValue] :
1019 tiledResults.push_back(tiledValue);
1022 sizes, resultOffset, resultSize,
1024 for (auto op : tilingResult->tiledOps) {
1028 op, "failed to get slice of result produced");
1030 resultOffsets.emplace_back(std::move(resultOffset));
1031 resultSizes.emplace_back(std::move(resultSize));
1038 FailureOr<SmallVector<Value>> maybeInits =
1040 if (failed(maybeInits)) {
1042 op, "unable to create initial tensors for tiling");
1049 tileSizes, numThreads, initTensors,
1050 innerYieldTiledValuesFn, loops)))
1051 return op.emitOpError("failed to generate tiling loops");
1052 assert(succeeded(tilingResult) &&
1053 "expected tiling result to be computed after loop generation");
1056 if (loops.empty()) {
1059 partialResults = tilingResult->tiledValues;
1061 partialResults = llvm::map_to_vector(loops.front()->getResults(),
1065 FailureOr<MergeResult> mergeResult =
1067 if (failed(mergeResult)) {
1069 op, "Failed to merge partial results from tiling");
1073 mergeResult.value(),
1074 tilingResult->generatedSlices};
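
// A minimal caller-side sketch (editorial; `rewriter` and `tilableOp` are
// assumed to be provided by the surrounding pass or transform): tiling an op
// that implements TilingInterface with fixed tile sizes through the entry
// point above.
static FailureOr<scf::SCFTilingResult>
tileWithFixedSizesExample(RewriterBase &rewriter, TilingInterface tilableOp) {
  scf::SCFTilingOptions options;
  // Tile the two outermost loops by 32 and 64; a tile size of 0 leaves the
  // corresponding loop untiled, and missing trailing entries are padded with
  // zeros by getUserTileSizesAndNumThreads.
  options.setTileSizes(
      getAsIndexOpFoldResult(rewriter.getContext(), {32, 64, 0}));
  return scf::tileUsingSCF(rewriter, tilableOp, options);
}
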
1077 FailureOr<scf::SCFTilingResult>
1079 PartialReductionOpInterface op,
1084 PartialReductionOuterReduction);
1085 options.setTileSizes(tileSizes);
1087 TilingInterface tilingInterfaceOp =
1088 dyn_cast<TilingInterface>(op.getOperation());
1089 if (!tilingInterfaceOp) {
1092 "Operation implementing PartialReductionOpInterface should implement "
1109 static std::tuple<OpResult, std::optional<OpOperand *>>
1112 std::optional<OpOperand *> destinationIterArg;
1113 auto loopIt = loops.rbegin();
1114 while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
1115 auto loop = *loopIt;
1116 if (iterArg.getOwner()->getParentOp() != loop)
1118 source = loop.getTiedLoopInit(iterArg);
1121 if (loopIt == loops.rend())
1122 destinationIterArg = source;
1123 return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1128 std::optional<scf::SCFFuseProducerOfSliceResult>
1130 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1134 auto [fusableProducer, destinationInitArg] =
1137 if (!fusableProducer)
1138 return std::nullopt;
1139 unsigned resultNumber = fusableProducer.getResultNumber();
1147 Operation *fusableProducerOp = fusableProducer.getOwner();
1148 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1150 rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1151 origDestinationTensors)))
1152 return std::nullopt;
1154 clonedOpDestinationTensors = origDestinationTensors;
1155 if (destinationInitArg &&
1156 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1160 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1164 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1169 llvm::to_vector(candidateSliceOp->getOperands());
1170 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1171 tensor::ExtractSliceOp clonedCandidateSliceOp =
1173 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1176 FailureOr<TilingResult> tileAndFuseResult =
1178 rewriter, clonedCandidateSliceOp,
1179 clonedProducerOp->getResult(resultNumber));
1180 if (failed(tileAndFuseResult))
1181 return std::nullopt;
1185 tileAndFuseResult->tiledValues[0]);
1186 rewriter.eraseOp(clonedCandidateSliceOp);
1187 rewriter.eraseOp(clonedProducerOp);
1232 if (destinationInitArg &&
1233 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1235 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1236 .set(origDestinationTensors[resultNumber]);
1239 fusableProducer, tileAndFuseResult->tiledValues[0],
1240 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
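
// Hedged usage sketch (editorial): after tiling, a tensor.extract_slice that
// feeds the tiled op can be handed back to this function to pull its producer
// into the loop nest. `extractSliceOp` and `tilingResult` are illustrative
// names assumed to come from a previous tileUsingSCF call.
//   std::optional<scf::SCFFuseProducerOfSliceResult> fused =
//       scf::tileAndFuseProducerOfSlice(rewriter, extractSliceOp,
//                                       tilingResult->loops);
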
1245 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1253 *tiledOwner = fusedProducerInfo.tiledOps[0];
1258 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1260 : llvm::to_vector(yieldResultNumber);
1262 for (const auto &resultNumber : initNumberList) {
1264 rewriter, loc, originalOwner->getResult(resultNumber));
1265 if (succeeded(initValue)) {
1266 initValueList.push_back(initValue.value());
1282 sliceSizes = sliceOp.getMixedSizes();
1285 if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1286 return !isConstantIntValue(ofr, 1);
1290 unsigned sliceResultNumber =
1293 auto tilableOp = cast<TilingInterface>(originalOwner);
1297 if (tilableOp->getNumResults() > 1 &&
1298 failed(tilableOp.getIterationDomainTileFromResultTile(
1299 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1300 iterDomainOffset, iterDomainSizes))) {
1315 for (const auto &resultNumber : initNumberList) {
1316 if (resultNumber == sliceResultNumber) {
1317 offsetList.push_back(sliceOffset);
1318 sizesList.push_back(sliceSizes);
1320 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1323 if (failed(tilableOp.getResultTilePosition(
1324 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1328 offsetList.push_back(offset);
1329 sizesList.push_back(sizes);
1335 if (auto tiledDestStyleOp =
1336 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1338 for (const auto &&[index, newRegionArg] :
1340 auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1341 loc, newRegionArg, offsetList[index], sizesList[index],
1344 generatedSlices.push_back(destSlice);
1345 unsigned resultNumber = initNumberList[index];
1347 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1356 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1357 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1358 tiledOffset.emplace_back(offsetList[index]);
1359 tiledSizes.emplace_back(sizesList[index]);
1365 newYieldValuesFn))) {
1368 return generatedSlices;
1382 explicit SliceTrackingListener(
1383 std::optional<FrozenRewritePatternSet> patterns);
1384 SliceTrackingListener() = default;
1393 void notifyOperationInserted(Operation *op,
1400 void notifyOperationErased(Operation *op) override;
1407 std::deque<tensor::ExtractSliceOp> worklist;
1412 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1415 SliceTrackingListener::SliceTrackingListener(
1416 std::optional<FrozenRewritePatternSet> p) {
1423 if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1424 worklist.push_back(slice);
1436 void SliceTrackingListener::notifyOperationInserted(
1438 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1441 worklist.push_back(slice);
1447 void SliceTrackingListener::removeOp(Operation *op) {
1448 if (!isa<tensor::ExtractSliceOp>(op))
1450 auto iter = worklist.begin();
1451 while (iter != worklist.end()) {
1456 if (iter == worklist.end())
1459 worklist.erase(iter);
1462 void SliceTrackingListener::notifyOperationErased(Operation *op) {
1466 void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1473 FailureOr<scf::SCFTileAndFuseResult>
1479 if (!consumer->getNumResults()) {
1481 consumer, "invalid pattern for op with no results");
1486 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1488 FailureOr<scf::SCFTilingResult> tilingResult =
1491 if (failed(tilingResult))
1493 for (auto *tiledOp : tilingResult->tiledOps)
1494 tiledAndFusedOps.insert(tiledOp);
1497 auto &loops = tilingResult->loops;
1498 if (loops.empty()) {
1500 for (auto [origVal, replacement] : llvm::zip_equal(
1501 consumer->getResults(), tilingResult->mergeResult.replacements)) {
1502 replacements[origVal] = replacement;
1513 for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1514 origValToResultNumber[result] = index;
1524 struct WorklistItem {
1525 tensor::ExtractSliceOp candidateSlice;
1529 SliceTrackingListener sliceTracker =
1530 SliceTrackingListener(options.cleanupPatterns);
1533 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1537 while (!sliceTracker.worklist.empty()) {
1538 auto candidateSlice = sliceTracker.worklist.front();
1539 sliceTracker.worklist.pop_front();
1541 auto [fusableProducer, destinationInitArg] =
1544 if (!fusableProducer)
1547 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1548 options.fusionControlFn(candidateSlice, fusableProducer,
1549 destinationInitArg.has_value());
1550 if (!controlFnResult)
1553 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1558 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1566 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1571 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1572 FailureOr<SmallVector<Operation *>> newSlices =
1574 worklistItem.candidateSlice,
1575 fusedResult.value(), loops);
1576 if (failed(newSlices)) {
1578 fusableProducerOp, "failed to get replacement value for this "
1579 "operation from within the tiled loop");
1581 worklistCandidates.append(newSlices.value());
1582 for (auto [index, result] :
1584 origValToResultNumber[result] = loops.front()->getNumResults() -
1590 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1591 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1592 tiledAndFusedOps.insert(tiledAndFusedOp);
1595 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1601 for (auto [origVal, resultNumber] : origValToResultNumber) {
1602 replacements[origVal] = loops.front()->getResult(resultNumber);
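
// Caller-side sketch (editorial; `rewriter` and `consumer` are assumed to be
// supplied by the caller): tiling a consumer and greedily fusing the producers
// of the generated slices via the driver above.
static FailureOr<scf::SCFTileAndFuseResult>
tileAndFuseExample(RewriterBase &rewriter, TilingInterface consumer) {
  scf::SCFTileAndFuseOptions options;
  // Tile sizes for the consumer; producer fusion then walks the worklist of
  // tensor.extract_slice ops generated while tiling.
  options.tilingOptions.setTileSizes(
      getAsIndexOpFoldResult(rewriter.getContext(), {16, 16}));
  return scf::tileConsumerAndFuseProducersUsingSCF(rewriter, consumer, options);
}
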
1615 static LogicalResult
1617 Value result = candidateSliceOp.getResult();
1619 if (!llvm::hasSingleElement(uses)) {
1620 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1623 OpOperand &operandUse = (*uses.begin());
1625 if (!isa<scf::YieldOp>(userOp)) {
1626 LLVM_DEBUG(llvm::dbgs()
1627 << "Expected scf.yield to be the only user, but got -> "
1632 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1633 "be in the same block\n");
1642 if (!isa<LoopLikeOpInterface>(loopOp))
1661 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1664 if (loopOp->getBlock() != userOp->getBlock())
1668 firstUserOfLoop = userOp;
1670 return firstUserOfLoop;
1711 static FailureOr<llvm::SetVector<Operation *>>
1713 bool reorderOperations) {
1715 if (failed(firstUserOfLoop))
1721 options.omitBlockArguments = true;
1722 bool includeLoopOp = false;
1725 includeLoopOp = true;
1737 if (!slice.empty()) {
1747 if (includeLoopOp || !reorderOperations)
1759 unsigned resultNumber) {
1760 if (!isa<LoopLikeOpInterface>(loopOp))
1765 Operation *consumerOp = opOperand.getOwner();
1767 if (!isa<TilingInterface>(consumerOp) ||
1768 !isa<DestinationStyleOpInterface>(consumerOp)) {
1775 if (loopBlock != consumerOp->getBlock())
1782 FailureOr<llvm::SetVector<Operation *>> slice =
1788 if (!slice->empty()) {
1791 assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1792 for (auto op : *slice) {
1821 auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1824 auto isForOpYieldResultOfInnerLoop =
1825 [](scf::ForOp outerLoop) -> LogicalResult {
1826 Block *body = outerLoop.getBody();
1830 auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1834 return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1837 while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1838 nestLoops.push_back(outerLoop);
1839 outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1842 return {nestLoops.rbegin(), nestLoops.rend()};
1849 static FailureOr<OpOperand *>
1851 tensor::InsertSliceOp candidateSliceOp) {
1854 Value sliceResult = candidateSliceOp.getResult();
1860 auto forOp = dyn_cast<scf::ForOp>(containingOp);
1870 static FailureOr<OpOperand *>
1872 tensor::ParallelInsertSliceOp candidateSliceOp) {
1874 Value sliceDest = candidateSliceOp.getDest();
1875 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1882 auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1885 unsigned resultNumber =
1886 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1894 static FailureOr<OpOperand *>
1896 if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1898 } else if (auto parallelInsertSlice =
1899 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1908 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1911 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1915 bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1919 FailureOr<OpOperand *> maybeConsumerOpOperand =
1921 if (failed(maybeConsumerOpOperand)) {
1923 "could not fetch consumer to fuse");
1925 OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1928 unsigned resultNumber = 0;
1929 if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1930 resultNumber = producerResult.getResultNumber();
1933 consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1940 LoopLikeOpInterface innerMostLoop =
1943 if (isInsertSliceOp) {
1944 nestedLoops = llvm::map_to_vector(
1946 cast<scf::ForOp>(innerMostLoop.getOperation())),
1947 [](scf::ForOp forOp) {
1948 return cast<LoopLikeOpInterface>(forOp.getOperation());
1951 nestedLoops = {innerMostLoop};
1954 LoopLikeOpInterface outerMostLoop = nestedLoops.front();
1959 outerMostLoop, "the first user of loop should not dominate any define "
1960 "of consumer operand(s)");
1966 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1969 "consumer op is not DPS operation");
1971 llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1972 if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
1975 "consumer op taking the result of scf.for as init is not supported");
1979 Location loc = outerMostLoop->getLoc();
1984 if (failed(firstUserOfLoop)) {
1986 outerMostLoop, "could not find the first user of outer most loop");
1988 rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
1994 tensor::InsertSliceOp clonedInsertSliceOp;
1996 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1997 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
1999 clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
2000 loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2001 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2004 clonedInsertSliceOp =
2005 cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
2009 auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2013 OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2015 operandToReplace.set(clonedInsertSliceOp.getResult());
2021 cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
2022 FailureOr<TilingResult> tileAndFuseResult =
2024 rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2025 if (failed(tileAndFuseResult)) {
2028 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2030 clonedInsertSliceOp.getSource());
2051 candidateSliceOp, "containingOp's result yield with stride");
2061 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2062 rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2063 iterDomainSizes))) {
2066 "can't get iter domain position from input position");
2072 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2074 totalNumResultsOfConsumer);
2076 totalNumResultsOfConsumer);
2077 for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2078 if (failed(tiledConsumerOp.getResultTilePosition(
2079 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2080 resultOffsets[idx], resultSizes[idx]))) {
2083 "can't get result domain position from iter domain position");
2089 if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2090 tiledConsumerOp.getOperation())) {
2092 for (const auto &&[index, newRegionArg] :
2094 auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2095 loc, newRegionArg, resultOffsets[index], resultSizes[index],
2100 auto dstNumber = index;
2102 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2111 for (const auto &&[index, result] :
2113 tiledResult.push_back(result);
2114 tiledOffset.emplace_back(resultOffsets[index]);
2115 tiledSizes.emplace_back(resultSizes[index]);
2121 newYieldValuesFn))) {
2123 "unable to add new inits to nest loop");
2129 for (auto &&[oldResult, newResult] : llvm::zip(
2131 nestedLoops.front()->getResults().take_back(newInits.size()))) {
2136 rewriter.eraseOp(clonedConsumerOp);
2140 &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2141 tileAndFuseResult->tiledOps};
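
// Hedged usage sketch (editorial): the tensor.insert_slice (or
// tensor.parallel_insert_slice) that yields a tile of the loop nest can be
// handed to this function to fuse the untiled consumer of that result.
// `insertSliceOp` is an illustrative name for such a slice op.
//   FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedConsumer =
//       scf::tileAndFuseConsumerOfSlice(rewriter, insertSliceOp);
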
2148 FailureOr<SmallVector<scf::ForOp>>
2150 TilingInterface op) {
2152 if (op->getNumResults() > 0) {
2154 op, "unable to lower to loops operations with return values");
2161 for (auto loopRange : domain) {
2168 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2170 loops.push_back(loop);
2171 ivs.push_back(loop.getInductionVar());
2174 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {