28 #include "llvm/ADT/ScopeExit.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/Debug.h"
33 #define DEBUG_TYPE "tile-using-interface"
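// Tiling (and tile-and-fuse) of operations implementing TilingInterface,
// generating scf.for / scf.forall loop nests.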
40 auto tileSizes = llvm::to_vector(ts);
49 assert(!numThreadsComputationFunction && "num tiles already set");
50 auto numThreads = llvm::to_vector(nt);
61 size_t iterationDomainSize) {
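// Pad the interchange vector with identity entries up to the iteration-domain
// size, and drop any excess entries.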
63 if (filledVector.size() < iterationDomainSize) {
64 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
65 filledVector.append(range.begin(), range.end());
67 if (filledVector.size() > iterationDomainSize)
68 filledVector.resize(iterationDomainSize);
80 if (options.numThreadsComputationFunction &&
83 loc, "number of threads can only be specified when loop type is "
84 "set to use `scf.forall`");
88 if (!options.interchangeVector.empty()) {
91 loc, "invalid interchange vector, not a permutation of the entire "
106 size_t numLoops = iterationDomain.size();
109 if (options.numThreadsComputationFunction) {
110 numThreads = options.numThreadsComputationFunction(rewriter, op);
111 numThreads.resize(numLoops, zero);
114 if (options.tileSizeComputationFunction) {
115 tileSizes = options.tileSizeComputationFunction(rewriter, op);
116 tileSizes.resize(numLoops, zero);
117 return {tileSizes, numThreads};
129 tileSizes.resize(numLoops, zero);
130 for (auto [index, range, nt] :
136 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
138 tileSizes.resize(numLoops, zero);
139 return {tileSizes, numThreads};
146 assert(options.tileSizeComputationFunction &&
147 "expected tile sizes to be specified");
148 tileSizes = options.tileSizeComputationFunction(rewriter, op);
149 tileSizes.resize(numLoops, zero);
151 return {tileSizes, numThreads};
160 auto iterators = op.getLoopIteratorTypes();
161 assert(iterators.size() == tileSizes.size() &&
162 "expected as many tile size values as number of loops");
163 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
164 "when specified, expected number of threads to use for each loop");
166 bool isParallelTiling = false;
167 for (auto [index, iterator, tileSize] :
170 isParallelTiling |= iterator == utils::IteratorType::parallel;
177 if (!numThreads.empty()) {
178 if (std::optional<int64_t> constNumThreads =
180 if (constNumThreads.value() > 1 &&
181 iterator != utils::IteratorType::parallel) {
182 op.emitWarning() << "tiling is not thread safe at axis #" << index;
188 if (std::optional<int64_t> constTileSize =
190 if (constTileSize.value() > 0 &&
191 iterator != utils::IteratorType::parallel) {
192 op.emitWarning() << "tiling is not thread safe at axis #" << index;
199 if (isParallelTiling) {
200 return op->emitOpError("tiling parallel dimensions is not supported with "
201 "partial reduction tiling strategies");
213 for (auto dim : options.reductionDims) {
216 reductionDims.insert(dim);
218 return reductionDims;
232 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
241 if (ts && ts.value() == 1)
268 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
270 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
284 int materializedLoopNum = 0;
286 if (!numThreads.empty()) {
291 offsetExpr = d0 + d1 * s0;
292 residualTileSizeExpr = s1 - (d0 + d1 * s0);
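// With numThreads specified, a thread's tile offset is
// loopRange.offset + iv * tileSize (offsetExpr, with d1 bound to the induction
// variable), and residualTileSizeExpr = loopRange.size -
// (loopRange.offset + nt * tileSize) measures what remains of the iteration
// space beyond the last full tile.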
294 for (auto [index, nt, tileSize, loopRange] :
300 offsets.push_back(loopRange.offset);
301 sizes.push_back(loopRange.size);
305 Value iv = ivs[materializedLoopNum++];
307 rewriter, loc, offsetExpr,
310 rewriter, loc, residualTileSizeExpr,
311 {loopRange.offset, nt, tileSize, loopRange.size});
317 {offset, loopRange.size});
321 {sizeMinusOffsetPerThread, tileSize});
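// If the last tile can run past the end of the range, clamp from below as
// well: size = max(0, min(loopRange.size - offset, tileSize)).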
337 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
340 offsets.push_back(offset);
341 sizes.push_back(size);
343 return {offsets, sizes};
345 for (auto [tileSize, loopRange] :
346 llvm::zip_equal(tileSizes, iterationDomain)) {
351 offsets.push_back(loopRange.offset);
352 sizes.push_back(loopRange.size);
356 Value iv = ivs[materializedLoopNum++];
358 offsets.push_back(offset);
361 sizes.push_back(size);
363 return {offsets, sizes};
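// Loop bounds for the generated tile loops: lower bound = range offset,
// upper bound = range size (exclusive), step = tile size.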
373 for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
377 lbs.push_back(loopRange.offset);
378 ubs.push_back(loopRange.size);
379 steps.push_back(tileSize);
381 return {lbs, ubs, steps};
411 if (newDestArgs.empty())
413 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
414 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
432 assert(!loopRanges.empty() && "unexpected empty loop ranges");
433 assert(loopRanges.size() == tileSizes.size() &&
434 "expected as many tile sizes as loop ranges");
438 std::tie(lbs, ubs, steps) =
448 for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
450 scf::ForOp::create(rewriter, loc, lb, ub, step, destinationTensors,
453 loops.push_back(loop);
454 ivs.push_back(loop.getInductionVar());
456 destinationTensors = loop.getRegionIterArgs();
461 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
462 tiledResults, resultOffsets, resultSizes))) {
464 loc, "failed to generate inner tile loop body");
469 assert(tiledResults.size() == destinationTensors.size() &&
470 "Number of results of body should be equal to number of iter args");
474 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
475 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
479 auto insertSlice = tensor::InsertSliceOp::create(
480 rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize,
482 yieldedValues.push_back(insertSlice);
484 scf::YieldOp::create(rewriter, loc, yieldedValues);
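// Propagate results outward: each enclosing scf.for yields the results of the
// loop immediately nested within it, threading the destination tensors
// through the whole nest.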
487 for (auto [outerLoop, innerLoop] :
491 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
492 scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
513 assert(!loopRanges.empty() && "unexpected empty loop ranges");
514 assert(loopRanges.size() == tileSizes.size() &&
515 "expected as many tile sizes as loop ranges");
518 std::optional<ArrayAttr> mappingAttr;
519 if (!mappingVector.empty())
522 scf::ForallOp forallOp;
523 bool useNumThreads = !numThreads.empty();
528 for (auto nt : numThreads) {
531 nonZeroNumThreads.push_back(nt);
533 forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads,
534 destinationTensors, mappingAttr);
537 std::tie(lbs, ubs, steps) =
539 forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
540 destinationTensors, mappingAttr);
542 loops.push_back(forallOp);
545 destinationTensors = forallOp.getRegionOutArgs();
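// Body of the scf.forall: invoke the tile body, then write each tiled result
// back into its shared output with tensor.parallel_insert_slice inside the
// terminator region.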
549 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
550 destinationTensors, tiledResults, resultOffsets,
555 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
556 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
561 tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue,
562 destinationTensor, resultOffset,
563 resultSize, resultStride);
589 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
590 tiledResults, resultOffsets, resultSizes);
594 destinationTensors, tiledBodyFn, loops);
598 rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector,
599 destinationTensors, tiledBodyFn, loops);
617 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
619 return op->emitOpError(
620 "PartialReductionOuterReduction tiling strategy is only supported for "
621 "operations implementing PartialReductionOpInterface");
626 AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
628 for (auto [index, domain, tileSize] :
630 if (!numThreads.empty()) {
634 rewriter, op.getLoc(), sizeExpr,
635 {domain.size, domain.offset, domain.stride});
638 sizes[index] = numThreads[index];
645 rewriter, op.getLoc(), sizeExpr,
646 {domain.size, domain.offset, domain.stride});
650 if (reductionStrategy ==
652 sizes[index] = tileSize;
656 assert(reductionStrategy ==
659 rewriter, op.getLoc(), sizeExpr,
660 {domain.size, domain.offset, domain.stride});
662 rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
664 return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
678 splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
683 if (reductionStrategy ==
686 if (!numThreads.empty()) {
687 splitReductionIvs[index] = ivs[ivIndex++];
691 rewriter, loc, divExpr,
695 return splitReductionIvs;
698 static FailureOr<TilingResult>
707 return op.getTiledImplementation(rewriter, offsets, sizes);
710 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
713 op,
"PartialReductionOuterReduction tiling strategy is only "
714 "supported for operations "
715 "implementing PartialReductionOpInterface");
720 numThreads, tileSizes, reductionDims);
721 return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
722 regionIterArg, offsets, sizes,
723 reductionDims, splitReductionIvs);
728 int64_t index, Value tiledResult, TilingInterface op,
736 return op.getResultTilePosition(rewriter, index, offsets, sizes,
737 resultOffset, resultSize);
739 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
742 op,
"PartialReductionOuterReduction tiling strategy is only supported"
743 "for operations implementing PartialReductionOpInterface");
747 numThreads, tileSizes, reductionDims);
748 return redOp.getPartialResultTilePosition(
749 rewriter, index, reductionStrategy, offsets, sizes, reductionDims,
750 splitReductionIvs, resultOffset, resultSize);
753 static FailureOr<MergeResult>
759 "expected merge to be called for only partial reduction cases");
761 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
764 op,
"PartialReductionOuterReduction tiling strategy is only "
765 "supported for operations "
766 "implementing PartialReductionOpInterface");
768 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
779 template <typename LoopType>
780 FailureOr<LoopLikeOpInterface>
789 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
796 auto inits = llvm::to_vector(loopOp.getInitArgs());
797 inits.append(newInitOperands.begin(), newInitOperands.end());
798 auto newLoop = scf::ForOp::create(
799 rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
801 loopOp.getUnsignedCmp());
804 Block *loopBody = loopOp.getBody();
805 Block *newLoopBody = newLoop.getBody();
807 loopBody, newLoopBody,
810 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
816 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
817 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
818 newRegionIterArgs, tiledValues, resultOffsets,
825 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
826 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
830 Value insert = tensor::InsertSliceOp::create(
831 rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset,
832 resultSize, resultStride);
833 newYieldValues.push_back(insert);
838 newLoop->getResults().take_front(loopOp.getNumResults()));
839 return cast<LoopLikeOpInterface>(newLoop.getOperation());
844 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
850 auto inits = llvm::to_vector(loopOp.getOutputs());
851 inits.append(newInitOperands.begin(), newInitOperands.end());
852 auto newLoop = scf::ForallOp::create(
853 rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
854 loopOp.getMixedStep(), inits, loopOp.getMapping(),
858 Block *loopBody = loopOp.getBody();
859 Block *newLoopBody = newLoop.getBody();
861 loopBody, newLoopBody,
864 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
869 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
870 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
871 regionIterArgs, tiledValues, resultOffsets,
875 "failed to get yielded tiled values");
881 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
882 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
885 tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(),
886 tiledValue, iterArg, resultOffset,
887 resultSize, resultStride);
891 newLoop->getResults().take_front(loopOp.getNumResults()));
892 return cast<LoopLikeOpInterface>(newLoop.getOperation());
902 loopLikeOp.getOperation())
903 .Case<scf::ForOp, scf::ForallOp>(
904 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
906 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
908 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
927 for (auto &loop : loops.drop_back()) {
931 auto forLoop = cast<scf::ForOp>(loop.getOperation());
935 newInits.append(newInitValues.begin(), newInitValues.end());
936 auto newLoop = scf::ForOp::create(
937 rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
938 forLoop.getUpperBound(), forLoop.getStep(), newInits,
940 forLoop.getUnsignedCmp());
944 sourceBlockArgs.push_back(newLoop.getInductionVar());
945 auto newRegionIterArgs = newLoop.getRegionIterArgs();
946 sourceBlockArgs.append(
947 newRegionIterArgs.begin(),
948 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
949 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
951 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
953 ivs.push_back(newLoop.getInductionVar());
954 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
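// The innermost loop is handled separately: yieldTiledValuesAndReplaceLoop
// appends the new init operands and asks the callback for the values (and
// slice positions) to yield for them.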
958 LoopLikeOpInterface innerMostLoop = loops.back();
959 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
961 getNewTiledYieldsFn);
963 if (failed(newInnerMostLoop))
964 return innerMostLoop.emitOpError("failed to return additional yields");
965 loops.back() = newInnerMostLoop.value();
969 for (auto [outerLoop, innerLoop] :
970 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
972 auto outerForLoop = cast<scf::ForOp>(outerLoop);
973 auto outerLoopYield =
974 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
976 llvm::to_vector(outerLoopYield.getOperands());
978 innerLoop->getResults().take_back(newInitValues.size());
979 newYields.append(additionalYields.begin(), additionalYields.end());
988 FailureOr<scf::SCFTilingResult>
1003 std::tie(tileSizes, numThreads) =
1009 tileSizes, numThreads))) {
1020 if (!options.interchangeVector.empty()) {
1022 iterationDomain.size());
1024 "expected interchange vector to be a permutation");
1028 if (!numThreads.empty())
1032 FailureOr<TilingResult> tilingResult;
1044 rewriter, loc, options.reductionStrategy, ivs, iterationDomain,
1045 tileSizes, numThreads, reductionDims);
1049 if (!interchangeVector.empty()) {
1058 auto clonedOp = cast<TilingInterface>(
1065 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1074 rewriter, clonedOp, options.reductionStrategy, regionIterArgs, offsets,
1075 sizes, ivs, numThreads, tileSizes, reductionDims);
1076 if (failed(tilingResult)) {
1078 return op.emitOpError("failed to tile operation");
1086 for (auto [index, tiledValue] :
1088 tiledResults.push_back(tiledValue);
1091 rewriter, options.reductionStrategy, index, tiledValue, op,
1092 offsets, sizes, ivs, numThreads, tileSizes, reductionDims,
1093 resultOffset, resultSize))) {
1094 for (auto op : tilingResult->tiledOps) {
1098 op, "failed to get slice of result produced");
1100 resultOffsets.emplace_back(std::move(resultOffset));
1101 resultSizes.emplace_back(std::move(resultSize));
1109 rewriter, op, options.reductionStrategy, iterationDomain, numThreads,
1110 tileSizes, reductionDims);
1111 if (failed(maybeInits)) {
1113 op, "unable to create initial tensors for tiling");
1120 iterationDomain, tileSizes, numThreads,
1121 initTensors, options.mappingVector,
1122 innerYieldTiledValuesFn, loops)))
1123 return op.emitOpError("failed to generate tiling loops");
1124 assert(succeeded(tilingResult) &&
1125 "expected tiling result to be computed after loop generation");
1127 if (loops.empty()) {
1133 tilingResult->tiledValues,
1134 tilingResult->generatedSlices,
1138 auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1144 tilingResult->tiledOps, initTensors, loops, loopResults,
1145 tilingResult->generatedSlices, {}};
1150 rewriter, op, options.reductionStrategy, reductionDims, loopResults);
1151 if (failed(mergeResult)) {
1153 op, "Failed to merge partial results from tiling");
1158 mergeResult->replacements,
1159 tilingResult->generatedSlices,
1160 mergeResult->mergeOps};
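// Illustrative driver (sketch, not part of this file): a typical client fixes
// the tile sizes and loop construct on SCFTilingOptions and then calls
// scf::tileUsingSCF; `rewriter`, `ctx`, and `target` are assumed to be a
// RewriterBase, an MLIRContext *, and an op implementing TilingInterface.
//
//   scf::SCFTilingOptions tilingOptions;
//   tilingOptions.setTileSizes(getAsIndexOpFoldResult(ctx, {32, 64}));
//   tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
//   FailureOr<scf::SCFTilingResult> tiled =
//       scf::tileUsingSCF(rewriter, target, tilingOptions);
//   if (succeeded(tiled))
//     rewriter.replaceOp(target, tiled->replacements);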
1163 FailureOr<scf::SCFTilingResult>
1165 PartialReductionOpInterface op,
1169 options.setReductionTilingStrategy(
1171 options.setTileSizes(tileSize);
1173 for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
1174 if (iteratorType == utils::IteratorType::reduction)
1175 reductionDims.push_back(index);
1176 options.setReductionDims(reductionDims);
1190 static std::tuple<OpResult, std::optional<OpOperand *>>
1193 std::optional<OpOperand *> destinationIterArg;
1194 assert(!loops.empty() && "expected non empty loops container");
1195 auto loopIt = loops.rbegin();
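// Walk from the innermost loop outward: while the slice source is a region
// iter_arg of the enclosing loop, step to the corresponding init operand; if
// the walk exits the outermost loop, that init operand is also reported.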
1196 while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
1197 auto iterArg = cast<BlockArgument>(source->get());
1198 auto loop = *loopIt;
1199 if (iterArg.getOwner()->getParentOp() != loop)
1201 source = loop.getTiedLoopInit(iterArg);
1204 if (loopIt == loops.rend())
1205 destinationIterArg = source;
1206 return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1211 std::optional<scf::SCFFuseProducerOfSliceResult>
1213 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1217 auto [fusableProducer, destinationInitArg] =
1220 if (!fusableProducer)
1221 return std::nullopt;
1222 unsigned resultNumber = fusableProducer.getResultNumber();
1230 Operation *fusableProducerOp = fusableProducer.getOwner();
1231 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1233 rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1234 origDestinationTensors)))
1235 return std::nullopt;
1237 clonedOpDestinationTensors = origDestinationTensors;
1238 if (destinationInitArg &&
1239 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1243 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1247 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1252 llvm::to_vector(candidateSliceOp->getOperands());
1253 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1254 tensor::ExtractSliceOp clonedCandidateSliceOp =
1256 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
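// Swap the cloned extract_slice with the tiled implementation of the cloned
// producer; both clones are erased below once the fused tile has been produced.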
1259 FailureOr<TilingResult> tileAndFuseResult =
1261 rewriter, clonedCandidateSliceOp,
1262 clonedProducerOp->getResult(resultNumber));
1263 if (failed(tileAndFuseResult))
1264 return std::nullopt;
1268 tileAndFuseResult->tiledValues[0]);
1269 rewriter.eraseOp(clonedCandidateSliceOp);
1270 rewriter.eraseOp(clonedProducerOp);
1315 if (destinationInitArg &&
1316 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1318 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1319 .set(origDestinationTensors[resultNumber]);
1322 fusableProducer, tileAndFuseResult->tiledValues[0],
1323 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1328 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1336 *tiledOwner = fusedProducerInfo.tiledOps[0];
1341 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1343 : llvm::to_vector(yieldResultNumber);
1345 for (const auto &resultNumber : initNumberList) {
1347 rewriter, loc, originalOwner->getResult(resultNumber));
1348 if (succeeded(initValue)) {
1349 initValueList.push_back(initValue.value());
1365 sliceSizes = sliceOp.getMixedSizes();
1368 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
1371 unsigned sliceResultNumber =
1374 auto tilableOp = cast<TilingInterface>(originalOwner);
1378 if (tilableOp->getNumResults() > 1 &&
1379 failed(tilableOp.getIterationDomainTileFromResultTile(
1380 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1381 iterDomainOffset, iterDomainSizes))) {
1396 for (const auto &resultNumber : initNumberList) {
1397 if (resultNumber == sliceResultNumber) {
1398 offsetList.push_back(sliceOffset);
1399 sizesList.push_back(sliceSizes);
1401 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1404 if (failed(tilableOp.getResultTilePosition(
1405 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1409 offsetList.push_back(offset);
1410 sizesList.push_back(sizes);
1416 if (auto tiledDestStyleOp =
1417 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1419 for (const auto &&[index, newRegionArg] :
1421 auto destSlice = tensor::ExtractSliceOp::create(
1422 rewriter, loc, newRegionArg, offsetList[index], sizesList[index],
1425 generatedSlices.push_back(destSlice);
1426 unsigned resultNumber = initNumberList[index];
1428 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1437 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1438 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1439 tiledOffset.emplace_back(offsetList[index]);
1440 tiledSizes.emplace_back(sizesList[index]);
1446 newYieldValuesFn))) {
1449 return generatedSlices;
1463 explicit SliceTrackingListener(
1464 std::optional<FrozenRewritePatternSet> patterns);
1465 SliceTrackingListener() = default;
1474 void notifyOperationInserted(Operation *op,
1481 void notifyOperationErased(Operation *op) override;
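// Worklist of candidate tensor.extract_slice ops kept in insertion order;
// `patterns`, when present, holds cleanup patterns applied to newly inserted
// slices.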
1488 std::deque<tensor::ExtractSliceOp> worklist;
1493 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1496 SliceTrackingListener::SliceTrackingListener(
1497 std::optional<FrozenRewritePatternSet> p) {
1504 if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1505 worklist.push_back(slice);
1517 void SliceTrackingListener::notifyOperationInserted(
1519 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1522 worklist.push_back(slice);
1528 void SliceTrackingListener::removeOp(Operation *op) {
1529 if (!isa<tensor::ExtractSliceOp>(op))
1531 auto iter = worklist.begin();
1532 while (iter != worklist.end()) {
1537 if (iter == worklist.end())
1540 worklist.erase(iter);
1543 void SliceTrackingListener::notifyOperationErased(Operation *op) {
1547 void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1563 : ForwardingListener(listener), replacements(replacements) {}
1565 void updateReplacementValues(ValueRange origValues,
1569 for (auto &[key, val] : replacements) {
1570 for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1579 ForwardingListener::notifyOperationReplaced(op, newOp);
1584 ForwardingListener::notifyOperationReplaced(op, values);
1585 updateReplacementValues(op->getResults(), values);
1595 FailureOr<scf::SCFTileAndFuseResult>
1601 if (!consumer->getNumResults()) {
1603 consumer, "invalid pattern for op with no results");
1609 FailureOr<scf::SCFTilingResult> tilingResult =
1612 if (failed(tilingResult))
1614 tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1617 for (auto [origVal, replacement] :
1618 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1619 replacements[origVal] = replacement;
1623 auto &loops = tilingResult->loops;
1624 if (loops.empty()) {
1633 auto resetListener =
1634 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1635 ReplacementListener replaceListener(replacements, previousListener);
1645 struct WorklistItem {
1646 tensor::ExtractSliceOp candidateSlice;
1650 SliceTrackingListener sliceTracker =
1651 SliceTrackingListener(options.cleanupPatterns);
1654 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
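// Worklist-driven fusion: pop a candidate slice, let the control function
// decide whether to fuse its producer, fuse it, and push any newly generated
// slices back onto the worklist.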
1658 while (!sliceTracker.worklist.empty()) {
1659 auto candidateSlice = sliceTracker.worklist.front();
1660 sliceTracker.worklist.pop_front();
1662 auto [fusableProducer, destinationInitArg] =
1665 if (!fusableProducer)
1668 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1669 options.fusionControlFn(candidateSlice, fusableProducer,
1670 destinationInitArg.has_value());
1671 if (!controlFnResult)
1674 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1679 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1687 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1692 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1693 FailureOr<SmallVector<Operation *>> newSlices =
1695 worklistItem.candidateSlice,
1696 fusedResult.value(), loops);
1699 fusableProducerOp, "failed to get replacement value for this "
1700 "operation from within the tiled loop");
1702 worklistCandidates.append(newSlices.value());
1703 for (auto [index, result] :
1705 replacements[result] = loops.front()->getResult(
1706 loops.front()->getNumResults() -
1711 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1712 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1713 tiledAndFusedOps.insert(tiledAndFusedOp);
1716 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1731 static LogicalResult
1733 Value result = candidateSliceOp.getResult();
1735 if (!llvm::hasSingleElement(uses)) {
1736 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1739 OpOperand &operandUse = (*uses.begin());
1741 if (!isa<scf::YieldOp>(userOp)) {
1742 LLVM_DEBUG(llvm::dbgs()
1743 << "Expected scf.yield to be the only user, but got -> "
1748 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1749 "be in the same block\n");
1758 if (!isa<LoopLikeOpInterface>(loopOp))
1777 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1780 if (loopOp->getBlock() != userOp->getBlock())
1784 firstUserOfLoop = userOp;
1786 return firstUserOfLoop;
1827 static FailureOr<llvm::SetVector<Operation *>>
1829 bool reorderOperations) {
1831 if (failed(firstUserOfLoop))
1837 options.omitBlockArguments = true;
1838 bool includeLoopOp = false;
1841 includeLoopOp = true;
1851 assert(result.succeeded() && "expected a backward slice");
1855 if (!slice.empty()) {
1865 if (includeLoopOp || !reorderOperations)
1877 unsigned resultNumber) {
1878 if (!isa<LoopLikeOpInterface>(loopOp))
1883 Operation *consumerOp = opOperand.getOwner();
1885 if (!isa<TilingInterface>(consumerOp) ||
1886 !isa<DestinationStyleOpInterface>(consumerOp)) {
1893 if (loopBlock != consumerOp->getBlock())
1900 FailureOr<llvm::SetVector<Operation *>> slice =
1906 if (!slice->empty()) {
1909 assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1910 for (auto op : *slice) {
1925 static FailureOr<OpOperand *>
1927 tensor::InsertSliceOp candidateSliceOp,
1929 assert(!loops.empty() && "unexpected loops to be empty");
1932 if (containingOp != loops.back()) {
1935 "expected slice to be within body of inner-most loop");
1941 candidateSliceOp, "expected passed loops to be perfectly nested.");
1946 Value sliceResult = candidateSliceOp.getResult();
1952 scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
1959 static FailureOr<OpOperand *>
1961 tensor::ParallelInsertSliceOp candidateSliceOp,
1963 assert(!loops.empty() && "unexpected loops to be empty");
1965 if (loops.size() != 1) {
1967 candidateSliceOp, "expected single surrounding scf.forall");
1969 auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
1972 candidateSliceOp, "expected single surrounding scf.forall");
1976 Value sliceDest = candidateSliceOp.getDest();
1977 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1980 if (iterArg.getOwner()->getParentOp() != forallOp)
1983 unsigned resultNumber =
1984 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1995 assert(!loops.empty() && "unexpected empty loops");
1996 assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
1998 for (auto sliceOp : sliceOps) {
1999 FailureOr<OpOperand *> fusedOperand =
2001 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2008 if (failed(fusedOperand)) {
2011 if (!fusedOperands.empty() &&
2012 fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
2014 fusedOperand.value()->getOwner(),
2015 "all candidate slices must be to the same consumer");
2017 fusedOperands.push_back(fusedOperand.value());
2019 return fusedOperands;
2022 template <typename InsertSliceOpTy>
2024 InsertSliceOpTy sliceOp);
2027 tensor::InsertSliceOp
2028 cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
2029 tensor::InsertSliceOp insertSliceOp) {
2030 return cast<tensor::InsertSliceOp>(
2031 rewriter.clone(*insertSliceOp.getOperation()));
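// The tensor.parallel_insert_slice specialization below rebuilds the slice as
// a plain tensor.insert_slice so both slice kinds can be treated uniformly
// when tiling the consumer.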
2035 tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
2036 RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
2037 return tensor::InsertSliceOp::create(
2038 rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(),
2039 insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
2040 insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2046 assert(!candidateSlices.empty() &&
2047 "unexpected empty list of slices to clone");
2049 for (auto sliceOp : candidateSlices) {
2051 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2054 clonedSlices.push_back(clonedOp);
2058 assert(0 && "unexpected slice type while cloning as insert slice");
2061 return clonedSlices;
2066 FailureOr<scf::SCFFuseConsumerOfSliceResult>
2070 if (candidateSlices.empty()) {
2073 "no candidate slices provided for consumer fusion");
2077 if (loops.empty()) {
2079 candidateSlices.front(),
2080 "cannot call tile and fuse consumer with an empty loop nest");
2083 if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2084 llvm::all_of(candidateSlices,
2085 llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2087 candidateSlices.front(),
2088 "candidates slices need to be all `tensor.extract_slice`s or "
2089 "`tensor.parallel_insert_slice`s");
2097 FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2099 if (failed(maybeConsumerOpOperand)) {
2101 "could not fetch consumer to fuse");
2103 std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
2104 consumerOp = consumerOpOperands.front()->getOwner();
2107 LoopLikeOpInterface outerMostLoop = loops.front();
2108 LoopLikeOpInterface innerMostLoop = loops.back();
2113 outerMostLoop, "the first user of loop should not dominate any define "
2114 "of consumer operand(s)");
2120 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2123 "consumer op is not DPS operation");
2124 if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
2125 return dstOp.isDpsInit(opOperand);
2129 "consumer op taking the result of scf.for as init is not supported");
2136 if (failed(firstUserOfLoop)) {
2138 outerMostLoop, "could not find the first user of outer most loop");
2140 rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2147 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
2148 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2158 auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2160 llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
2164 llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2165 return &clonedConsumerOp->getOpOperand(operandNum);
2171 for (auto [operandToReplace, clonedSliceOp] :
2172 llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
2173 operandToReplace->set(clonedSliceOp.getResult());
2179 FailureOr<TilingResult> tileAndFuseResult =
2181 clonedOpFusedOperandsList);
2182 if (failed(tileAndFuseResult)) {
2186 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2187 for (auto [operandNum, clonedSliceOp] :
2188 llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
2190 clonedSliceOp.getSource());
2204 for (auto candidateSliceOp : clonedInsertSlices) {
2212 candidateSliceOp, "containingOp's result yield with stride");
2215 allOffsets.emplace_back(std::move(offsets));
2216 allSizes.emplace_back(std::move(sizes));
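// Map the per-operand tile (offsets/sizes) back to the consumer's iteration
// domain, then project that domain tile onto each consumer result to build
// the destination slices.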
2226 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
2227 rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
2228 iterDomainSizes))) {
2231 "can't get iter domain position from input position");
2237 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2239 totalNumResultsOfConsumer);
2241 totalNumResultsOfConsumer);
2242 for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2243 if (failed(tiledConsumerOp.getResultTilePosition(
2244 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2245 resultOffsets[idx], resultSizes[idx]))) {
2248 "can't get result domain position from iter domain position");
2254 if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2255 tiledConsumerOp.getOperation())) {
2257 for (const auto &&[index, newRegionArg] :
2259 auto destSlice = tensor::ExtractSliceOp::create(
2260 rewriter, loc, newRegionArg, resultOffsets[index],
2266 auto dstNumber = index;
2268 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2277 for (const auto &&[index, result] :
2279 tiledResult.push_back(result);
2280 tiledOffset.emplace_back(resultOffsets[index]);
2281 tiledSizes.emplace_back(resultSizes[index]);
2287 newYieldValuesFn))) {
2289 "unable to add new inits to nest loop");
2295 for (auto &&[oldResult, newResult] :
2297 loops.front()->getResults().take_back(newInits.size()))) {
2302 rewriter.eraseOp(clonedConsumerOp);
2305 llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2306 return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
2309 std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
2310 std::move(tileAndFuseResult->tiledOps)};
2317 FailureOr<SmallVector<scf::ForOp>>
2319 TilingInterface op) {
2321 if (op->getNumResults() > 0) {
2323 op, "unable to lower to loops operations with return values");
2330 for (auto loopRange : domain) {
2337 auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal,
2339 loops.push_back(loop);
2340 ivs.push_back(loop.getInductionVar());
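// One scf.for has been materialized per iteration-domain dimension; the
// collected induction variables are handed to the op's scalar implementation.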
2343 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {