31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
35 #define DEBUG_TYPE "tile-using-interface"
42 auto tileSizes = llvm::to_vector(ts);
51 assert(!numThreadsComputationFunction &&
"num tiles already set");
52 auto numThreads = llvm::to_vector(nt);
63 size_t iterationDomainSize) {
65 if (filledVector.size() < iterationDomainSize) {
66 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
67 filledVector.append(range.begin(), range.end());
69 if (filledVector.size() > iterationDomainSize)
70 filledVector.resize(iterationDomainSize);
83 if (
options.numThreadsComputationFunction &&
86 loc,
"number of threads can only by specified when loop type is "
87 "set to use `scf.forall`");
91 if (!
options.interchangeVector.empty()) {
94 loc,
"invalid interchange vector, not a permutation of the entire "
109 size_t numLoops = iterationDomain.size();
112 if (
options.numThreadsComputationFunction) {
113 numThreads =
options.numThreadsComputationFunction(rewriter, op);
114 numThreads.resize(numLoops, zero);
117 if (
options.tileSizeComputationFunction) {
118 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
119 tileSizes.resize(numLoops, zero);
120 return {tileSizes, numThreads};
132 tileSizes.resize(numLoops, zero);
133 for (
auto [index, range, nt] :
139 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
141 tileSizes.resize(numLoops, zero);
142 return {tileSizes, numThreads};
149 assert(
options.tileSizeComputationFunction &&
150 "expected tile sizes to be specified");
151 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
152 tileSizes.resize(numLoops, zero);
154 return {tileSizes, numThreads};
161 auto iterators = op.getLoopIteratorTypes();
162 assert(iterators.size() == tileSizes.size() &&
163 "expected as many tile size values as number of loops");
164 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
165 "when specified, expected number of threads to use for each loop");
167 for (
auto [index, iterator, tileSize] :
171 if (!numThreads.empty()) {
172 if (std::optional<int64_t> constNumThreads =
174 if (constNumThreads.value() > 1 &&
175 iterator != utils::IteratorType::parallel) {
176 op.emitWarning() <<
"tiling is not thread safe at axis #" << index;
183 if (constTileSize.value() > 0 &&
184 iterator != utils::IteratorType::parallel) {
185 op.emitWarning() <<
"tiling is not thread safe at axis #" << index;
202 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
211 if (ts && ts.value() == 1)
238 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
240 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
252 int materializedLoopNum = 0;
254 if (!numThreads.empty()) {
259 offsetExpr = d0 + d1 * s0;
260 residualTileSizeExpr = s1 - (d0 + d1 * s0);
262 for (
auto [nt, tileSize, loopRange] :
263 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
268 offsets.push_back(loopRange.offset);
269 sizes.push_back(loopRange.size);
273 Value iv = ivs[materializedLoopNum++];
275 rewriter, loc, offsetExpr,
278 rewriter, loc, residualTileSizeExpr,
279 {loopRange.offset, nt, tileSize, loopRange.size});
285 {offset, loopRange.size});
289 {sizeMinusOffsetPerThread, tileSize});
305 rewriter, loc, maxMap, {rewriter.
getIndexAttr(0), size});
308 offsets.push_back(offset);
309 sizes.push_back(size);
311 return {offsets, sizes};
313 for (
auto [tileSize, loopRange] :
314 llvm::zip_equal(tileSizes, iterationDomain)) {
319 offsets.push_back(loopRange.offset);
320 sizes.push_back(loopRange.size);
324 Value iv = ivs[materializedLoopNum++];
326 offsets.push_back(offset);
329 sizes.push_back(size);
331 return {offsets, sizes};
341 for (
auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
345 lbs.push_back(loopRange.offset);
346 ubs.push_back(loopRange.size);
347 steps.push_back(tileSize);
349 return {lbs, ubs, steps};
379 if (newDestArgs.empty())
381 if (
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
382 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
400 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
401 assert(loopRanges.size() == tileSizes.size() &&
402 "expected as many tile sizes as loop ranges");
406 std::tie(lbs, ubs, steps) =
416 for (
auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
418 rewriter.
create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
421 loops.push_back(loop);
422 ivs.push_back(loop.getInductionVar());
424 destinationTensors = loop.getRegionIterArgs();
429 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
430 tiledResults, resultOffsets, resultSizes))) {
432 loc,
"failed to generate inner tile loop body");
437 assert(tiledResults.size() == destinationTensors.size() &&
438 "Number of results of body should be equal to number of iter args");
442 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
443 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
447 auto insertSlice = rewriter.
create<tensor::InsertSliceOp>(
448 loc, tiledValue, destinationTensor, resultOffset, resultSize,
450 yieldedValues.push_back(insertSlice);
452 rewriter.
create<scf::YieldOp>(loc, yieldedValues);
455 for (
auto [outerLoop, innerLoop] :
459 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
460 rewriter.
create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
481 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
482 assert(loopRanges.size() == tileSizes.size() &&
483 "expected as many tile sizes as loop ranges");
486 sizes(loopRanges.size());
488 std::optional<ArrayAttr> mappingAttr;
489 if (!mappingVector.empty())
492 scf::ForallOp forallOp;
493 bool useNumThreads = !numThreads.empty();
498 for (
auto nt : numThreads) {
501 nonZeroNumThreads.push_back(nt);
503 forallOp = rewriter.
create<scf::ForallOp>(loc, nonZeroNumThreads,
504 destinationTensors, mappingAttr);
507 std::tie(lbs, ubs, steps) =
509 forallOp = rewriter.
create<scf::ForallOp>(loc, lbs, ubs, steps,
510 destinationTensors, mappingAttr);
512 loops.push_back(forallOp);
515 destinationTensors = forallOp.getRegionOutArgs();
519 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
520 destinationTensors, tiledResults, resultOffsets,
525 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
526 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
531 rewriter.
create<tensor::ParallelInsertSliceOp>(
532 loc, tiledValue, destinationTensor, resultOffset, resultSize,
558 return tiledBodyFn(rewriter, loc,
ValueRange{}, destinationTensors,
559 tiledResults, resultOffsets, resultSizes);
563 destinationTensors, tiledBodyFn, loops);
567 rewriter, loc, loopRanges, tileSizes, numThreads,
options.mappingVector,
568 destinationTensors, tiledBodyFn, loops);
580 template <
typename LoopType>
581 FailureOr<LoopLikeOpInterface>
590 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
597 auto inits = llvm::to_vector(loopOp.getInitArgs());
598 inits.append(newInitOperands.begin(), newInitOperands.end());
599 auto newLoop = rewriter.
create<scf::ForOp>(
600 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
604 Block *loopBody = loopOp.getBody();
605 Block *newLoopBody = newLoop.getBody();
607 loopBody, newLoopBody,
608 newLoopBody->
getArguments().take_front(loopBody->getNumArguments()));
610 auto yieldOp = cast<scf::YieldOp>(newLoopBody->
getTerminator());
616 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
617 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
618 newRegionIterArgs, tiledValues, resultOffsets,
625 for (
auto [tiledValue, regionIterArg, resultOffset, resultSize] :
626 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
630 Value insert = rewriter.
create<tensor::InsertSliceOp>(
631 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
633 newYieldValues.push_back(insert);
638 newLoop->getResults().take_front(loopOp.getNumResults()));
639 return cast<LoopLikeOpInterface>(newLoop.getOperation());
644 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
650 auto inits = llvm::to_vector(loopOp.getOutputs());
651 inits.append(newInitOperands.begin(), newInitOperands.end());
652 auto newLoop = rewriter.
create<scf::ForallOp>(
653 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
654 loopOp.getMixedStep(), inits, loopOp.getMapping(),
658 Block *loopBody = loopOp.getBody();
659 Block *newLoopBody = newLoop.getBody();
661 loopBody, newLoopBody,
662 newLoopBody->
getArguments().take_front(loopBody->getNumArguments()));
664 auto terminator = cast<scf::InParallelOp>(newLoopBody->
getTerminator());
669 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
670 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
671 regionIterArgs, tiledValues, resultOffsets,
675 "failed to get yielded tiled values");
681 for (
auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
682 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
685 rewriter.
create<tensor::ParallelInsertSliceOp>(
686 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
691 newLoop->getResults().take_front(loopOp.getNumResults()));
692 return cast<LoopLikeOpInterface>(newLoop.getOperation());
702 loopLikeOp.getOperation())
703 .Case<scf::ForOp, scf::ForallOp>(
704 [&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
706 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
708 .Default([&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
728 for (
auto &loop : loops.drop_back()) {
732 auto forLoop = cast<scf::ForOp>(loop.getOperation());
736 newInits.append(newInitValues.begin(), newInitValues.end());
737 auto newLoop = rewriter.
create<scf::ForOp>(
738 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
739 forLoop.getStep(), newInits,
744 sourceBlockArgs.push_back(newLoop.getInductionVar());
745 auto newRegionIterArgs = newLoop.getRegionIterArgs();
746 sourceBlockArgs.append(
747 newRegionIterArgs.begin(),
748 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
749 rewriter.
mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
751 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
753 ivs.push_back(newLoop.getInductionVar());
754 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
758 LoopLikeOpInterface innerMostLoop = loops.back();
759 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
761 getNewTiledYieldsFn);
763 if (failed(newInnerMostLoop))
764 return innerMostLoop.emitOpError(
"failed to return additional yields");
765 loops.back() = newInnerMostLoop.value();
769 for (
auto [outerLoop, innerLoop] :
770 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
772 auto outerForLoop = cast<scf::ForOp>(outerLoop);
773 auto outerLoopYield =
774 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
776 llvm::to_vector(outerLoopYield.getOperands());
778 innerLoop->getResults().take_back(newInitValues.size());
779 newYields.append(additionalYields.begin(), additionalYields.end());
788 FailureOr<scf::SCFTilingResult>
803 std::tie(tileSizes, numThreads) =
815 if (!
options.interchangeVector.empty()) {
817 iterationDomain.size());
819 "expected interchange vector to be a permutation");
823 if (!numThreads.empty())
827 FailureOr<TilingResult> tilingResult;
839 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
843 if (!interchangeVector.empty()) {
852 auto clonedOp = cast<TilingInterface>(
859 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
867 tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
868 if (failed(tilingResult)) {
870 return op.emitOpError(
"faild to tile operation");
878 for (
auto [index, tiledValue] :
880 tiledResults.push_back(tiledValue);
882 if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
883 resultOffset, resultSize))) {
884 for (
auto op : tilingResult->tiledOps) {
888 op,
"failed to get slice of result produced");
890 resultOffsets.emplace_back(std::move(resultOffset));
891 resultSizes.emplace_back(std::move(resultSize));
900 destinationTensors))) {
902 "unable to create destination tensors");
908 tileSizes, numThreads, destinationTensors,
909 innerYieldTiledValuesFn, loops)))
910 return op.emitOpError(
"failed to generate tiling loops");
911 assert(succeeded(tilingResult) &&
912 "expected tiling result to be computed after loop generation");
918 tilingResult->tiledValues,
919 tilingResult->generatedSlices};
923 loops.front()->getResults(), [](
OpResult r) ->
Value { return r; });
925 tilingResult->generatedSlices};
928 FailureOr<scf::SCFReductionTilingResult>
930 PartialReductionOpInterface op,
935 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
937 auto tileSizesVector = llvm::to_vector(tileSizes);
938 if (tileSizesVector.size() < iterationDomain.size()) {
940 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
944 tilingInterfaceOp.getLoopIteratorTypes();
947 for (
auto [idx, iteratorType] :
949 if (iteratorType == utils::IteratorType::reduction)
950 reductionDims.push_back(idx);
954 FailureOr<SmallVector<Value>> maybeInitTensors =
955 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
957 if (failed(maybeInitTensors)) {
964 auto innerYieldTiledValuesFn =
972 int materializedLoopNum = 0;
973 for (
auto [tileSize, loopRange] :
974 llvm::zip_equal(tileSizesVector, iterationDomain)) {
976 offsets.push_back(loopRange.offset);
977 sizes.push_back(loopRange.size);
980 Value iv = ivs[materializedLoopNum++];
981 offsets.push_back(iv);
989 auto clonedOp = cast<PartialReductionOpInterface>(
993 FailureOr<TilingResult> partialTilingResult =
994 clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
995 sizes, reductionDims);
996 if (failed(partialTilingResult)) {
999 std::swap(parallelTiledOps, partialTilingResult->tiledOps);
1000 std::swap(tiledResult, partialTilingResult->tiledValues);
1008 for (
auto result : tiledResult) {
1010 resultOffsets.emplace_back(std::move(outOffsets));
1013 for (
size_t i = 0; i < offsets.size(); i++) {
1016 resultSizes.emplace_back(std::move(outSizes));
1027 initTensors, innerYieldTiledValuesFn, loops)))
1031 loops.front()->getResults(), [](
OpResult r) ->
Value { return r; });
1035 FailureOr<MergeResult> mergeResult =
1036 op.mergeReductions(b, loc, replacements, reductionDims);
1037 if (failed(mergeResult)) {
1040 b.
replaceOp(op, mergeResult->replacements);
1044 std::swap(reductionTilingResult.
mergeOps, mergeResult->mergeOps);
1045 std::swap(reductionTilingResult.
initialValues, initTensors);
1046 std::swap(reductionTilingResult.
loops, loops);
1047 std::swap(reductionTilingResult.
replacements, mergeResult->replacements);
1049 return reductionTilingResult;
1061 static std::tuple<OpResult, std::optional<OpOperand *>>
1064 std::optional<OpOperand *> destinationIterArg;
1065 auto loopIt = loops.rbegin();
1066 while (
auto iterArg = dyn_cast<BlockArgument>(source->
get())) {
1067 auto loop = *loopIt;
1068 if (iterArg.getOwner()->getParentOp() != loop)
1070 source = loop.getTiedLoopInit(iterArg);
1073 if (loopIt == loops.rend())
1074 destinationIterArg = source;
1075 return {dyn_cast<OpResult>(source->
get()), destinationIterArg};
1080 std::optional<scf::SCFFuseProducerOfSliceResult>
1082 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1086 auto [fusableProducer, destinationInitArg] =
1089 if (!fusableProducer)
1090 return std::nullopt;
1091 unsigned resultNumber = fusableProducer.getResultNumber();
1099 Operation *fusableProducerOp = fusableProducer.getOwner();
1100 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1102 rewriter, fusableProducerOp->
getLoc(), fusableProducerOp,
1103 origDestinationTensors)))
1104 return std::nullopt;
1106 clonedOpDestinationTensors = origDestinationTensors;
1107 if (destinationInitArg &&
1108 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1112 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1116 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1121 llvm::to_vector(candidateSliceOp->getOperands());
1122 candidateSliceOpOperands[0] = clonedProducerOp->
getResult(resultNumber);
1123 tensor::ExtractSliceOp clonedCandidateSliceOp =
1125 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1128 FailureOr<TilingResult> tileAndFuseResult =
1130 rewriter, clonedCandidateSliceOp,
1131 clonedProducerOp->
getResult(resultNumber));
1132 if (failed(tileAndFuseResult))
1133 return std::nullopt;
1137 tileAndFuseResult->tiledValues[0]);
1138 rewriter.
eraseOp(clonedCandidateSliceOp);
1139 rewriter.
eraseOp(clonedProducerOp);
1184 if (destinationInitArg &&
1185 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1187 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1188 .set(origDestinationTensors[resultNumber]);
1191 fusableProducer, tileAndFuseResult->tiledValues[0],
1192 tileAndFuseResult->
tiledOps, tileAndFuseResult->generatedSlices};
1197 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1205 *tiledOwner = fusedProducerInfo.
tiledOps[0];
1210 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1212 : llvm::to_vector(yieldResultNumber);
1214 for (
const auto &resultNumber : initNumberList) {
1216 rewriter, loc, originalOwner->
getResult(resultNumber));
1217 if (succeeded(initValue)) {
1218 initValueList.push_back(initValue.value());
1234 sliceSizes = sliceOp.getMixedSizes();
1237 if (llvm::any_of(sliceOp.getMixedStrides(), [](
OpFoldResult ofr) {
1238 return !isConstantIntValue(ofr, 1);
1242 unsigned sliceResultNumber =
1245 auto tilableOp = cast<TilingInterface>(originalOwner);
1249 if (tilableOp->getNumResults() > 1 &&
1250 failed(tilableOp.getIterationDomainTileFromResultTile(
1251 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1252 iterDomainOffset, iterDomainSizes))) {
1267 for (
const auto &resultNumber : initNumberList) {
1268 if (resultNumber == sliceResultNumber) {
1269 offsetList.push_back(sliceOffset);
1270 sizesList.push_back(sliceSizes);
1272 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1275 if (failed(tilableOp.getResultTilePosition(
1276 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1280 offsetList.push_back(offset);
1281 sizesList.push_back(sizes);
1286 if (
auto tiledDestStyleOp =
1287 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1289 for (
const auto &&[index, newRegionArg] :
1291 auto destSlice = rewriter.
create<tensor::ExtractSliceOp>(
1292 loc, newRegionArg, offsetList[index], sizesList[index],
1295 generatedSlices.push_back(destSlice);
1296 unsigned resultNumber = initNumberList[index];
1298 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1307 for (
const auto &&[index, resultNumber] :
llvm::enumerate(initNumberList)) {
1308 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1309 tiledOffset.emplace_back(offsetList[index]);
1310 tiledSizes.emplace_back(sizesList[index]);
1316 newYieldValuesFn))) {
1319 return generatedSlices;
1333 explicit SliceTrackingListener(
1334 std::optional<FrozenRewritePatternSet> patterns);
1335 SliceTrackingListener() =
default;
1343 void notifyOperationInserted(
Operation *op,
1350 void notifyOperationErased(
Operation *op)
override;
1357 std::deque<tensor::ExtractSliceOp> worklist;
1361 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1364 SliceTrackingListener::SliceTrackingListener(
1365 std::optional<FrozenRewritePatternSet> p) {
1366 patterns = std::move(p);
1372 if (
auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1373 worklist.push_back(slice);
1385 void SliceTrackingListener::notifyOperationInserted(
1387 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1390 worklist.push_back(slice);
1395 void SliceTrackingListener::removeOp(
Operation *op) {
1396 if (!isa<tensor::ExtractSliceOp>(op))
1398 auto iter = worklist.begin();
1399 while (iter != worklist.end()) {
1404 if (iter == worklist.end())
1407 worklist.erase(iter);
1410 void SliceTrackingListener::notifyOperationErased(
Operation *op) {
1414 void SliceTrackingListener::notifyOperationReplaced(
Operation *op,
1421 FailureOr<scf::SCFTileAndFuseResult>
1427 if (!consumer->getNumResults()) {
1429 consumer,
"invalid pattern for op with no results");
1434 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1436 FailureOr<scf::SCFTilingResult> tilingResult =
1439 if (failed(tilingResult))
1441 for (
auto *tiledOp : tilingResult->tiledOps)
1442 tiledAndFusedOps.insert(tiledOp);
1445 auto &loops = tilingResult->loops;
1446 if (loops.empty()) {
1448 for (
auto [origVal, replacement] :
1449 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1450 replacements[origVal] = replacement;
1460 for (
auto [index, result] :
llvm::enumerate(consumer->getResults())) {
1461 origValToResultNumber[result] = index;
1471 struct WorklistItem {
1472 tensor::ExtractSliceOp candidateSlice;
1476 SliceTrackingListener sliceTracker =
1477 SliceTrackingListener(
options.cleanupPatterns);
1480 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1484 while (!sliceTracker.worklist.empty()) {
1485 auto candidateSlice = sliceTracker.worklist.front();
1486 sliceTracker.worklist.pop_front();
1488 auto [fusableProducer, destinationInitArg] =
1491 if (!fusableProducer)
1494 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1495 options.fusionControlFn(candidateSlice, fusableProducer,
1496 destinationInitArg.has_value());
1497 if (!controlFnResult)
1500 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1505 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1513 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1517 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1518 FailureOr<SmallVector<Operation *>> newSlices =
1520 worklistItem.candidateSlice,
1521 fusedResult.value(), loops);
1522 if (failed(newSlices)) {
1524 fusableProducerOp,
"failed to replacement value for this "
1525 "operation from within the tiled loop");
1527 worklistCandidates.append(newSlices.value());
1528 for (
auto [index, result] :
1530 origValToResultNumber[result] = loops.front()->getNumResults() -
1536 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1537 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1538 tiledAndFusedOps.insert(tiledAndFusedOp);
1541 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1547 for (
auto [origVal, resultNumber] : origValToResultNumber) {
1548 replacements[origVal] = loops.front()->getResult(resultNumber);
1561 static LogicalResult
1563 Value result = candidateSliceOp.getResult();
1565 if (!llvm::hasSingleElement(uses)) {
1566 LLVM_DEBUG(llvm::dbgs() <<
"Too many uses of the candidate slice op\n");
1569 OpOperand &operandUse = (*uses.begin());
1571 if (!isa<scf::YieldOp>(userOp)) {
1572 LLVM_DEBUG(llvm::dbgs()
1573 <<
"Expected scf.yield to be the only user, but got -> "
1578 LLVM_DEBUG(llvm::dbgs() <<
"Expected tensor.insert_slice and scf.yield to "
1579 "be in the same block\n");
1588 if (!isa<LoopLikeOpInterface>(loopOp))
1607 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1610 if (loopOp->
getBlock() != userOp->getBlock())
1614 firstUserOfLoop = userOp;
1616 return firstUserOfLoop;
1657 static FailureOr<llvm::SetVector<Operation *>>
1659 bool reorderOperations) {
1661 if (failed(firstUserOfLoop))
1667 options.omitBlockArguments =
true;
1668 bool includeLoopOp =
false;
1671 includeLoopOp =
true;
1683 if (!slice.empty()) {
1693 if (includeLoopOp || !reorderOperations)
1705 unsigned resultNumber) {
1706 if (!isa<LoopLikeOpInterface>(loopOp))
1711 Operation *consumerOp = opOperand.getOwner();
1713 if (!isa<TilingInterface>(consumerOp) ||
1714 !isa<DestinationStyleOpInterface>(consumerOp)) {
1721 if (loopBlock != consumerOp->
getBlock())
1728 FailureOr<llvm::SetVector<Operation *>> slice =
1733 if (!slice->empty()) {
1736 assert(succeeded(firstUserOfLoop) &&
"First user of loop is not found");
1737 for (
auto op : *slice) {
1766 auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1769 auto isForOpYieldResultOfInnerLoop =
1770 [](scf::ForOp outerLoop) -> LogicalResult {
1771 Block *body = outerLoop.getBody();
1775 auto innerForOp = dyn_cast<scf::ForOp>(body->
front());
1779 return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1782 while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1783 nestLoops.push_back(outerLoop);
1784 outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1787 return {nestLoops.rbegin(), nestLoops.rend()};
1794 static FailureOr<OpOperand *>
1796 tensor::InsertSliceOp candidateSliceOp) {
1799 Value sliceResult = candidateSliceOp.getResult();
1805 auto forOp = dyn_cast<scf::ForOp>(containingOp);
1815 static FailureOr<OpOperand *>
1817 tensor::ParallelInsertSliceOp candidateSliceOp) {
1819 Value sliceDest = candidateSliceOp.getDest();
1820 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1827 auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1830 unsigned resultNumber =
1831 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1839 static FailureOr<OpOperand *>
1841 if (
auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1843 }
else if (
auto parallelInsertSlice =
1844 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1853 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1856 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1860 bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1864 FailureOr<OpOperand *> maybeConsumerOpOperand =
1866 if (failed(maybeConsumerOpOperand)) {
1868 "could not fetch consumer to fuse");
1870 OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1873 unsigned resultNumber = 0;
1874 if (
auto producerResult = dyn_cast<OpResult>(consumerOpOperand->
get())) {
1875 resultNumber = producerResult.getResultNumber();
1878 consumerOp,
"consumer op's operand doesn't seem to be an OpResult");
1885 LoopLikeOpInterface innerMostLoop =
1888 if (isInsertSliceOp) {
1889 nestedLoops = llvm::map_to_vector(
1891 cast<scf::ForOp>(innerMostLoop.getOperation())),
1892 [](scf::ForOp forOp) {
1893 return cast<LoopLikeOpInterface>(forOp.getOperation());
1896 nestedLoops = {innerMostLoop};
1899 LoopLikeOpInterface outerMostLoop = nestedLoops.front();
1904 outerMostLoop,
"the first user of loop should not dominate any define "
1905 "of consumer operand(s)");
1911 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1914 "consumer op is not DPS operation");
1916 llvm::map_to_vector(dstOp.getDpsInits(), [](
Value v) { return v; });
1917 if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
1920 "consumer op taking the result of scf.for as init is not supported");
1924 Location loc = outerMostLoop->getLoc();
1929 if (failed(firstUserOfLoop)) {
1931 outerMostLoop,
"could not find the first user of outer most loop");
1933 rewriter.
moveOpBefore(outerMostLoop, *firstUserOfLoop);
1939 tensor::InsertSliceOp clonedInsertSliceOp;
1941 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1942 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
1944 clonedInsertSliceOp = rewriter.
create<tensor::InsertSliceOp>(
1945 loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1946 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1949 clonedInsertSliceOp =
1950 cast<tensor::InsertSliceOp>(rewriter.
clone(*candidateSliceOp));
1954 auto clonedConsumerOp = cast<TilingInterface>(rewriter.
clone(*consumerOp));
1958 OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
1960 operandToReplace.
set(clonedInsertSliceOp.getResult());
1966 cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
1967 FailureOr<TilingResult> tileAndFuseResult =
1969 rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
1970 if (failed(tileAndFuseResult)) {
1973 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
1975 clonedInsertSliceOp.getSource());
1996 candidateSliceOp,
"containingOp's result yield with stride");
2005 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2006 rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2007 iterDomainSizes))) {
2010 "can't get iter domain position from input position");
2016 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2018 totalNumResultsOfConsumer);
2020 totalNumResultsOfConsumer);
2021 for (
auto [idx, v] :
llvm::enumerate(tiledConsumerOp->getResults())) {
2022 if (failed(tiledConsumerOp.getResultTilePosition(
2023 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2024 resultOffsets[idx], resultSizes[idx]))) {
2027 "can't get result domain position from iter domain position");
2033 if (
auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2034 tiledConsumerOp.getOperation())) {
2036 for (
const auto &&[index, newRegionArg] :
2038 auto destSlice = rewriter.
create<tensor::ExtractSliceOp>(
2039 loc, newRegionArg, resultOffsets[index], resultSizes[index],
2044 auto dstNumber = index;
2046 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2055 for (
const auto &&[index, result] :
2057 tiledResult.push_back(result);
2058 tiledOffset.emplace_back(resultOffsets[index]);
2059 tiledSizes.emplace_back(resultSizes[index]);
2065 newYieldValuesFn))) {
2067 "unable to add new inits to nest loop");
2072 for (
auto &&[oldResult, newResult] : llvm::zip(
2074 nestedLoops.front()->getResults().take_back(newInits.size()))) {
2079 rewriter.
eraseOp(clonedConsumerOp);
2083 &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2084 tileAndFuseResult->tiledOps};
2091 FailureOr<SmallVector<scf::ForOp>>
2093 TilingInterface op) {
2095 if (op->getNumResults() > 0) {
2097 op,
"unable to lower to loops operations with return values");
2104 for (
auto loopRange : domain) {
2111 auto loop = rewriter.
create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2113 loops.push_back(loop);
2114 ivs.push_back(loop.getInductionVar());
2117 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * (numThreads - 1) is less than iterationSize.
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loop's existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp)
Fetch the untiled consumer of a scf.for's result which is yielded by a tensor.insert_slice.
static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specified in options.
static SmallVector< scf::ForOp > getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop)
Find the perfectly nested loops outside of the given loop (inclusive), sorted from outer to inner.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
Base type for affine expression.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
A class for computing basic dominance information.
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents a saved insertion point.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterf...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute with value 0.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. all the transitive defs of op), without including that operation.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Container for result values of tiling.
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
SmallVector< Operation * > tiledOps
Transformation information returned after reduction tiling.
SmallVector< Value > replacements
The replacements to use for the results of the tiled operation.
SmallVector< Value > initialValues
Initial values used for reduction.
SmallVector< Operation * > parallelTiledOps
The partial reduction tiled op generated.
SmallVector< LoopLikeOpInterface > loops
The loop operations that iterate over the tiles.
SmallVector< Operation * > mergeOps
The final reduction operation merging all the partial reductions.
Control function to check if a slice needs to be fused or not. The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.