31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
36 #define DEBUG_TYPE "tile-using-interface"
43 auto tileSizes = llvm::to_vector(ts);
52 assert(!numThreadsComputationFunction && "num threads already set");
53 auto numThreads = llvm::to_vector(nt);
64 size_t iterationDomainSize) {
66 if (filledVector.size() < iterationDomainSize) {
67 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
68 filledVector.append(range.begin(), range.end());
70 if (filledVector.size() > iterationDomainSize)
71 filledVector.resize(iterationDomainSize);
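// Editorial example (not part of the original file): for an iteration domain
// of size 4, an interchange vector [1, 0] is filled to [1, 0, 2, 3] by the
// logic above, i.e. unspecified trailing loops keep their original positions,
// while a longer vector is truncated to the first 4 entries.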
84 if (options.numThreadsComputationFunction &&
87 loc, "number of threads can only be specified when loop type is "
88 "set to use `scf.forall`");
92 if (!options.interchangeVector.empty()) {
95 loc, "invalid interchange vector, not a permutation of the entire "
110 size_t numLoops = iterationDomain.size();
113 if (options.numThreadsComputationFunction) {
114 numThreads = options.numThreadsComputationFunction(rewriter, op);
115 numThreads.resize(numLoops, zero);
118 if (options.tileSizeComputationFunction) {
119 tileSizes = options.tileSizeComputationFunction(rewriter, op);
120 tileSizes.resize(numLoops, zero);
121 return {tileSizes, numThreads};
133 tileSizes.resize(numLoops, zero);
134 for (auto [index, range, nt] :
140 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
142 tileSizes.resize(numLoops, zero);
143 return {tileSizes, numThreads};
150 assert(options.tileSizeComputationFunction &&
151 "expected tile sizes to be specified");
152 tileSizes = options.tileSizeComputationFunction(rewriter, op);
153 tileSizes.resize(numLoops, zero);
155 return {tileSizes, numThreads};
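// Editorial worked example: when only numThreads is provided, the branch
// above derives a tile size per loop as roughly ceil(range / numThreads),
// e.g. a loop of size 100 distributed over 8 threads gets tile size
// ceil(100 / 8) = 13, and the last thread handles the 9-iteration remainder.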
162 auto iterators = op.getLoopIteratorTypes();
163 assert(iterators.size() == tileSizes.size() &&
164 "expected as many tile size values as number of loops");
165 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
166 "when specified, expected number of threads to use for each loop");
168 for (auto [index, iterator, tileSize] :
172 if (!numThreads.empty()) {
173 if (std::optional<int64_t> constNumThreads =
175 if (constNumThreads.value() > 1 &&
176 iterator != utils::IteratorType::parallel) {
177 op.emitWarning() << "tiling is not thread safe at axis #" << index;
184 if (constTileSize.value() > 0 &&
185 iterator != utils::IteratorType::parallel) {
186 op.emitWarning() << "tiling is not thread safe at axis #" << index;
203 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
212 if (ts && ts.value() == 1)
239 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
241 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
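// Editorial worked example: with tileSize = 10, numThreads = 8 and an
// iteration size of 100, the largest tile offset is 10 * (8 - 1) = 70 < 100,
// so every thread starts inside the iteration domain and the offset bound
// check can be omitted; for an iteration size of 64 it cannot.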
253 int materializedLoopNum = 0;
255 if (!numThreads.empty()) {
260 offsetExpr = d0 + d1 * s0;
261 residualTileSizeExpr = s1 - (d0 + d1 * s0);
263 for (auto [nt, tileSize, loopRange] :
264 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
269 offsets.push_back(loopRange.offset);
270 sizes.push_back(loopRange.size);
274 Value iv = ivs[materializedLoopNum++];
276 rewriter, loc, offsetExpr,
279 rewriter, loc, residualTileSizeExpr,
280 {loopRange.offset, nt, tileSize, loopRange.size});
286 {offset, loopRange.size});
290 {sizeMinusOffsetPerThread, tileSize});
306 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
309 offsets.push_back(offset);
310 sizes.push_back(size);
312 return {offsets, sizes};
314 for (auto [tileSize, loopRange] :
315 llvm::zip_equal(tileSizes, iterationDomain)) {
320 offsets.push_back(loopRange.offset);
321 sizes.push_back(loopRange.size);
325 Value iv = ivs[materializedLoopNum++];
327 offsets.push_back(offset);
330 sizes.push_back(size);
332 return {offsets, sizes};
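// Editorial worked example for the tile-size path above: the tile offset is
// the loop induction variable itself (the loop steps by the tile size) and
// the tile size is clamped at the upper bound, e.g. for a loop of size 100
// tiled by 32 the last tile starts at iv = 96 and gets the bounded size
// min(32, 100 - 96) = 4.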
342 for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
346 lbs.push_back(loopRange.offset);
347 ubs.push_back(loopRange.size);
348 steps.push_back(tileSize);
350 return {lbs, ubs, steps};
380 if (newDestArgs.empty())
382 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
383 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
401 assert(!loopRanges.empty() && "unexpected empty loop ranges");
402 assert(loopRanges.size() == tileSizes.size() &&
403 "expected as many tile sizes as loop ranges");
407 std::tie(lbs, ubs, steps) =
417 for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
419 rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
422 loops.push_back(loop);
423 ivs.push_back(loop.getInductionVar());
425 destinationTensors = loop.getRegionIterArgs();
430 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
431 tiledResults, resultOffsets, resultSizes))) {
433 loc, "failed to generate inner tile loop body");
438 assert(tiledResults.size() == destinationTensors.size() &&
439 "Number of results of body should be equal to number of iter args");
443 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
444 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
448 auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
449 loc, tiledValue, destinationTensor, resultOffset, resultSize,
451 yieldedValues.push_back(insertSlice);
453 rewriter.create<scf::YieldOp>(loc, yieldedValues);
456 for (auto [outerLoop, innerLoop] :
460 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
461 rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
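// Editorial sketch of the generated IR (illustrative only): for two tiled
// loops the nest built above has the shape
//   %r = scf.for %iv0 = %lb0 to %ub0 step %ts0 iter_args(%out0 = %init) {
//     %0 = scf.for %iv1 = %lb1 to %ub1 step %ts1 iter_args(%out1 = %out0) {
//       %tile = <tiled computation>
//       %ins = tensor.insert_slice %tile into %out1[...] [...] [...]
//       scf.yield %ins
//     }
//     scf.yield %0
//   }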
482 assert(!loopRanges.empty() && "unexpected empty loop ranges");
483 assert(loopRanges.size() == tileSizes.size() &&
484 "expected as many tile sizes as loop ranges");
487 sizes(loopRanges.size());
489 std::optional<ArrayAttr> mappingAttr;
490 if (!mappingVector.empty())
493 scf::ForallOp forallOp;
494 bool useNumThreads = !numThreads.empty();
499 for (auto nt : numThreads) {
502 nonZeroNumThreads.push_back(nt);
504 forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
505 destinationTensors, mappingAttr);
508 std::tie(lbs, ubs, steps) =
510 forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
511 destinationTensors, mappingAttr);
513 loops.push_back(forallOp);
516 destinationTensors = forallOp.getRegionOutArgs();
520 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
521 destinationTensors, tiledResults, resultOffsets,
526 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
527 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
532 rewriter.create<tensor::ParallelInsertSliceOp>(
533 loc, tiledValue, destinationTensor, resultOffset, resultSize,
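// Editorial sketch of the generated IR (illustrative only): the `scf.forall`
// form built above has the shape
//   %r = scf.forall (%t0, %t1) in (%nt0, %nt1) shared_outs(%out = %init) {
//     %tile = <tiled computation>
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %tile into %out[...] [...] [...]
//     }
//   }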
559 return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
560 tiledResults, resultOffsets, resultSizes);
564 destinationTensors, tiledBodyFn, loops);
568 rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
569 destinationTensors, tiledBodyFn, loops);
574 static FailureOr<SmallVector<Value>>
580 switch (options.reductionStrategy) {
587 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
590 op, "PartialReductionOuterReduction tiling strategy is only supported "
591 "for operations implementing PartialReductionOpInterface");
597 for (auto [idx, iteratorType] :
599 if (iteratorType == utils::IteratorType::reduction)
600 reductionDims.push_back(idx);
602 return redOp.generateInitialTensorForPartialReduction(
603 rewriter, loc, tileSizes, reductionDims);
607 "unhandled reduction tiling strategy");
611 static FailureOr<TilingResult>
616 switch (options.reductionStrategy) {
618 return op.getTiledImplementation(rewriter, offsets, sizes);
621 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
624 op, "PartialReductionOuterReduction tiling strategy is only "
625 "supported for operations "
626 "implementing PartialReductionOpInterface");
632 for (auto [idx, iteratorType] :
634 if (iteratorType == utils::IteratorType::reduction)
635 reductionDims.push_back(idx);
637 return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
638 offsets, sizes, reductionDims);
642 "unhandled reduction tiling strategy");
654 switch (options.reductionStrategy) {
656 return op.getResultTilePosition(rewriter, index, offsets, sizes,
657 resultOffset, resultSize);
660 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
663 op, "PartialReductionOuterReduction tiling strategy is only supported "
664 "for operations implementing PartialReductionOpInterface");
670 for (auto [idx, iteratorType] :
672 if (iteratorType == utils::IteratorType::reduction)
673 reductionDims.push_back(idx);
675 return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
676 resultOffset, resultSize,
681 "unhandled reduction tiling strategy");
685 static FailureOr<MergeResult>
689 switch (options.reductionStrategy) {
695 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
698 op, "PartialReductionOuterReduction tiling strategy is only "
699 "supported for operations "
700 "implementing PartialReductionOpInterface");
706 for (auto [idx, iteratorType] :
708 if (iteratorType == utils::IteratorType::reduction)
709 reductionDims.push_back(idx);
711 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
716 "unhandled reduction tiling strategy");
727 template <typename LoopType>
728 FailureOr<LoopLikeOpInterface>
737 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
744 auto inits = llvm::to_vector(loopOp.getInitArgs());
745 inits.append(newInitOperands.begin(), newInitOperands.end());
746 auto newLoop = rewriter.create<scf::ForOp>(
747 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
751 Block *loopBody = loopOp.getBody();
752 Block *newLoopBody = newLoop.getBody();
754 loopBody, newLoopBody,
755 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
757 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
763 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
764 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
765 newRegionIterArgs, tiledValues, resultOffsets,
772 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
773 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
777 Value insert = rewriter.create<tensor::InsertSliceOp>(
778 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
780 newYieldValues.push_back(insert);
785 newLoop->getResults().take_front(loopOp.getNumResults()));
786 return cast<LoopLikeOpInterface>(newLoop.getOperation());
791 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
797 auto inits = llvm::to_vector(loopOp.getOutputs());
798 inits.append(newInitOperands.begin(), newInitOperands.end());
799 auto newLoop = rewriter.create<scf::ForallOp>(
800 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
801 loopOp.getMixedStep(), inits, loopOp.getMapping(),
805 Block *loopBody = loopOp.getBody();
806 Block *newLoopBody = newLoop.getBody();
808 loopBody, newLoopBody,
809 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
811 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
816 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
817 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
818 regionIterArgs, tiledValues, resultOffsets,
822 "failed to get yielded tiled values");
828 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
829 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
832 rewriter.create<tensor::ParallelInsertSliceOp>(
833 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
838 newLoop->getResults().take_front(loopOp.getNumResults()));
839 return cast<LoopLikeOpInterface>(newLoop.getOperation());
849 loopLikeOp.getOperation())
850 .Case<scf::ForOp, scf::ForallOp>(
851 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
853 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
855 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
875 for (auto &loop : loops.drop_back()) {
879 auto forLoop = cast<scf::ForOp>(loop.getOperation());
883 newInits.append(newInitValues.begin(), newInitValues.end());
884 auto newLoop = rewriter.create<scf::ForOp>(
885 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
886 forLoop.getStep(), newInits,
891 sourceBlockArgs.push_back(newLoop.getInductionVar());
892 auto newRegionIterArgs = newLoop.getRegionIterArgs();
893 sourceBlockArgs.append(
894 newRegionIterArgs.begin(),
895 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
896 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
898 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
900 ivs.push_back(newLoop.getInductionVar());
901 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
905 LoopLikeOpInterface innerMostLoop = loops.back();
906 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
908 getNewTiledYieldsFn);
910 if (failed(newInnerMostLoop))
911 return innerMostLoop.emitOpError("failed to return additional yields");
912 loops.back() = newInnerMostLoop.value();
916 for (auto [outerLoop, innerLoop] :
917 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
919 auto outerForLoop = cast<scf::ForOp>(outerLoop);
920 auto outerLoopYield =
921 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
923 llvm::to_vector(outerLoopYield.getOperands());
925 innerLoop->getResults().take_back(newInitValues.size());
926 newYields.append(additionalYields.begin(), additionalYields.end());
935 FailureOr<scf::SCFTilingResult>
950 std::tie(tileSizes, numThreads) =
962 if (!options.interchangeVector.empty()) {
964 iterationDomain.size());
966 "expected interchange vector to be a permutation");
970 if (!numThreads.empty())
974 FailureOr<TilingResult> tilingResult;
986 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
990 if (!interchangeVector.empty()) {
999 auto clonedOp = cast<TilingInterface>(
1006 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1016 if (failed(tilingResult)) {
1018 return op.emitOpError("failed to tile operation");
1026 for (auto [index, tiledValue] :
1028 tiledResults.push_back(tiledValue);
1031 sizes, resultOffset, resultSize,
1033 for (auto op : tilingResult->tiledOps) {
1037 op, "failed to get slice of result produced");
1039 resultOffsets.emplace_back(std::move(resultOffset));
1040 resultSizes.emplace_back(std::move(resultSize));
1047 FailureOr<SmallVector<Value>> maybeInits =
1049 if (failed(maybeInits)) {
1051 op, "unable to create initial tensors for tiling");
1058 tileSizes, numThreads, initTensors,
1059 innerYieldTiledValuesFn, loops)))
1060 return op.emitOpError("failed to generate tiling loops");
1061 assert(succeeded(tilingResult) &&
1062 "expected tiling result to be computed after loop generation");
1065 if (loops.empty()) {
1068 partialResults = tilingResult->tiledValues;
1070 partialResults = llvm::map_to_vector(loops.front()->getResults(),
1074 FailureOr<MergeResult> mergeResult =
1076 if (failed(mergeResult)) {
1078 op, "Failed to merge partial results from tiling");
1082 mergeResult.value(),
1083 tilingResult->generatedSlices};
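// Editorial usage sketch (not part of the original file): driving the
// tileUsingSCF entry point defined above. It assumes the
// getAsIndexOpFoldResult helper; the function name is hypothetical.
static LogicalResult tileWithFixedSizes(RewriterBase &rewriter,
                                        TilingInterface op) {
  scf::SCFTilingOptions tilingOptions;
  // Tile the first two loops by 64 and 128; a tile size of 0 leaves a loop
  // untiled.
  tilingOptions.setTileSizes(
      getAsIndexOpFoldResult(rewriter.getContext(), {64, 128}));
  FailureOr<scf::SCFTilingResult> tilingResult =
      scf::tileUsingSCF(rewriter, op, tilingOptions);
  if (failed(tilingResult))
    return failure();
  // Replace the original op with the values reconstructed from the tiles.
  rewriter.replaceOp(op, tilingResult->mergeResult.replacements);
  return success();
}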
1086 FailureOr<scf::SCFTilingResult>
1088 PartialReductionOpInterface op,
1093 PartialReductionOuterReduction);
1094 options.setTileSizes(tileSizes);
1096 TilingInterface tilingInterfaceOp =
1097 dyn_cast<TilingInterface>(op.getOperation());
1098 if (!tilingInterfaceOp) {
1101 "Operation implementing PartialReductionOpInterface should implement "
1118 static std::tuple<OpResult, std::optional<OpOperand *>>
1121 std::optional<OpOperand *> destinationIterArg;
1122 auto loopIt = loops.rbegin();
1123 while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
1124 auto loop = *loopIt;
1125 if (iterArg.getOwner()->getParentOp() != loop)
1127 source = loop.getTiedLoopInit(iterArg);
1130 if (loopIt == loops.rend())
1131 destinationIterArg = source;
1132 return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1137 std::optional<scf::SCFFuseProducerOfSliceResult>
1139 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1143 auto [fusableProducer, destinationInitArg] =
1146 if (!fusableProducer)
1147 return std::nullopt;
1148 unsigned resultNumber = fusableProducer.getResultNumber();
1156 Operation *fusableProducerOp = fusableProducer.getOwner();
1157 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1159 rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1160 origDestinationTensors)))
1161 return std::nullopt;
1163 clonedOpDestinationTensors = origDestinationTensors;
1164 if (destinationInitArg &&
1165 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1169 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1173 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1178 llvm::to_vector(candidateSliceOp->getOperands());
1179 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1180 tensor::ExtractSliceOp clonedCandidateSliceOp =
1182 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1185 FailureOr<TilingResult> tileAndFuseResult =
1187 rewriter, clonedCandidateSliceOp,
1188 clonedProducerOp->getResult(resultNumber));
1189 if (failed(tileAndFuseResult))
1190 return std::nullopt;
1194 tileAndFuseResult->tiledValues[0]);
1195 rewriter.eraseOp(clonedCandidateSliceOp);
1196 rewriter.eraseOp(clonedProducerOp);
1241 if (destinationInitArg &&
1242 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1244 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1245 .set(origDestinationTensors[resultNumber]);
1248 fusableProducer, tileAndFuseResult->tiledValues[0],
1249 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
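// Editorial before/after sketch (illustrative only): given
//   %p = <producer> : tensor<128x128xf32>
//   scf.for ... iter_args(...) {
//     %s = tensor.extract_slice %p[...] [...] [...]
//     <consumer tile uses %s>
//   }
// the fusion above clones the producer, tiles it to the slice that is
// actually needed inside the loop, and replaces %s with that tiled value;
// the untiled producer can later be removed if it has no other uses.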
1254 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1262 *tiledOwner = fusedProducerInfo.tiledOps[0];
1267 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1269 : llvm::to_vector(yieldResultNumber);
1271 for (const auto &resultNumber : initNumberList) {
1273 rewriter, loc, originalOwner->getResult(resultNumber));
1274 if (succeeded(initValue)) {
1275 initValueList.push_back(initValue.value());
1291 sliceSizes = sliceOp.getMixedSizes();
1294 if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1295 return !isConstantIntValue(ofr, 1);
1299 unsigned sliceResultNumber =
1302 auto tilableOp = cast<TilingInterface>(originalOwner);
1306 if (tilableOp->getNumResults() > 1 &&
1307 failed(tilableOp.getIterationDomainTileFromResultTile(
1308 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1309 iterDomainOffset, iterDomainSizes))) {
1324 for (const auto &resultNumber : initNumberList) {
1325 if (resultNumber == sliceResultNumber) {
1326 offsetList.push_back(sliceOffset);
1327 sizesList.push_back(sliceSizes);
1329 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1332 if (failed(tilableOp.getResultTilePosition(
1333 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1337 offsetList.push_back(offset);
1338 sizesList.push_back(sizes);
1344 if (auto tiledDestStyleOp =
1345 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1347 for (const auto &&[index, newRegionArg] :
1349 auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1350 loc, newRegionArg, offsetList[index], sizesList[index],
1353 generatedSlices.push_back(destSlice);
1354 unsigned resultNumber = initNumberList[index];
1356 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1365 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1366 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1367 tiledOffset.emplace_back(offsetList[index]);
1368 tiledSizes.emplace_back(sizesList[index]);
1374 newYieldValuesFn))) {
1377 return generatedSlices;
1391 explicit SliceTrackingListener(
1392 std::optional<FrozenRewritePatternSet> patterns);
1393 SliceTrackingListener() = default;
1402 void notifyOperationInserted(Operation *op,
1409 void notifyOperationErased(Operation *op) override;
1416 std::deque<tensor::ExtractSliceOp> worklist;
1421 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1424 SliceTrackingListener::SliceTrackingListener(
1425 std::optional<FrozenRewritePatternSet> p) {
1432 if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1433 worklist.push_back(slice);
1445 void SliceTrackingListener::notifyOperationInserted(
1447 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1450 worklist.push_back(slice);
1456 void SliceTrackingListener::removeOp(Operation *op) {
1457 if (!isa<tensor::ExtractSliceOp>(op))
1459 auto iter = worklist.begin();
1460 while (iter != worklist.end()) {
1465 if (iter == worklist.end())
1468 worklist.erase(iter);
1471 void SliceTrackingListener::notifyOperationErased(Operation *op) {
1475 void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1491 : ForwardingListener(listener), replacements(replacements) {}
1493 void updateReplacementValues(ValueRange origValues,
1497 for (auto &[key, val] : replacements) {
1498 for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1507 ForwardingListener::notifyOperationReplaced(op, newOp);
1512 ForwardingListener::notifyOperationReplaced(op, values);
1513 updateReplacementValues(op->getResults(), values);
1523 FailureOr<scf::SCFTileAndFuseResult>
1529 if (!consumer->getNumResults()) {
1531 consumer, "invalid pattern for op with no results");
1536 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1538 FailureOr<scf::SCFTilingResult> tilingResult =
1541 if (failed(tilingResult))
1543 for (auto *tiledOp : tilingResult->tiledOps)
1544 tiledAndFusedOps.insert(tiledOp);
1547 for (auto [origVal, replacement] : llvm::zip_equal(
1548 consumer->getResults(), tilingResult->mergeResult.replacements)) {
1549 replacements[origVal] = replacement;
1553 auto &loops = tilingResult->loops;
1554 if (loops.empty()) {
1563 auto resetListener =
1564 llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1565 ReplacementListener replaceListener(replacements, previousListener);
1575 struct WorklistItem {
1576 tensor::ExtractSliceOp candidateSlice;
1580 SliceTrackingListener sliceTracker =
1581 SliceTrackingListener(options.cleanupPatterns);
1584 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1588 while (!sliceTracker.worklist.empty()) {
1589 auto candidateSlice = sliceTracker.worklist.front();
1590 sliceTracker.worklist.pop_front();
1592 auto [fusableProducer, destinationInitArg] =
1595 if (!fusableProducer)
1598 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1599 options.fusionControlFn(candidateSlice, fusableProducer,
1600 destinationInitArg.has_value());
1601 if (!controlFnResult)
1604 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1609 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1617 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1622 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1623 FailureOr<SmallVector<Operation *>> newSlices =
1625 worklistItem.candidateSlice,
1626 fusedResult.value(), loops);
1627 if (failed(newSlices)) {
1629 fusableProducerOp, "failed to yield replacement value for this "
1630 "operation from within the tiled loop");
1632 worklistCandidates.append(newSlices.value());
1633 for (auto [index, result] :
1635 replacements[result] = loops.front()->getResult(
1636 loops.front()->getNumResults() -
1641 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1642 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1643 tiledAndFusedOps.insert(tiledAndFusedOp);
1646 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
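// Editorial usage sketch (not part of the original file): driving the
// tileConsumerAndFuseProducersUsingSCF entry point defined above. It assumes
// the getAsIndexOpFoldResult helper; the function name is hypothetical.
static LogicalResult tileAndFuseGreedily(RewriterBase &rewriter,
                                         TilingInterface consumer) {
  scf::SCFTileAndFuseOptions tileAndFuseOptions;
  tileAndFuseOptions.tilingOptions.setTileSizes(
      getAsIndexOpFoldResult(rewriter.getContext(), {32, 32}));
  FailureOr<scf::SCFTileAndFuseResult> result =
      scf::tileConsumerAndFuseProducersUsingSCF(rewriter, consumer,
                                                tileAndFuseOptions);
  if (failed(result))
    return failure();
  // Rewire uses of the original values to the results of the tiled loop nest.
  for (auto [origValue, replacement] : result->replacements)
    rewriter.replaceAllUsesWith(origValue, replacement);
  return success();
}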
1661 static LogicalResult
1663 Value result = candidateSliceOp.getResult();
1665 if (!llvm::hasSingleElement(uses)) {
1666 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1669 OpOperand &operandUse = (*uses.begin());
1671 if (!isa<scf::YieldOp>(userOp)) {
1672 LLVM_DEBUG(llvm::dbgs()
1673 << "Expected scf.yield to be the only user, but got -> "
1678 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1679 "be in the same block\n");
1688 if (!isa<LoopLikeOpInterface>(loopOp))
1707 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1710 if (loopOp->getBlock() != userOp->getBlock())
1714 firstUserOfLoop = userOp;
1716 return firstUserOfLoop;
1757 static FailureOr<llvm::SetVector<Operation *>>
1759 bool reorderOperations) {
1761 if (failed(firstUserOfLoop))
1767 options.omitBlockArguments = true;
1768 bool includeLoopOp = false;
1771 includeLoopOp = true;
1783 if (!slice.empty()) {
1793 if (includeLoopOp || !reorderOperations)
1805 unsigned resultNumber) {
1806 if (!isa<LoopLikeOpInterface>(loopOp))
1811 Operation *consumerOp = opOperand.getOwner();
1813 if (!isa<TilingInterface>(consumerOp) ||
1814 !isa<DestinationStyleOpInterface>(consumerOp)) {
1821 if (loopBlock != consumerOp->getBlock())
1828 FailureOr<llvm::SetVector<Operation *>> slice =
1834 if (!slice->empty()) {
1837 assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1838 for (auto op : *slice) {
1867 auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1870 auto isForOpYieldResultOfInnerLoop =
1871 [](scf::ForOp outerLoop) -> LogicalResult {
1872 Block *body = outerLoop.getBody();
1876 auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1880 return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1883 while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1884 nestLoops.push_back(outerLoop);
1885 outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1888 return {nestLoops.rbegin(), nestLoops.rend()};
1895 static FailureOr<OpOperand *>
1897 tensor::InsertSliceOp candidateSliceOp) {
1900 Value sliceResult = candidateSliceOp.getResult();
1906 auto forOp = dyn_cast<scf::ForOp>(containingOp);
1916 static FailureOr<OpOperand *>
1918 tensor::ParallelInsertSliceOp candidateSliceOp) {
1920 Value sliceDest = candidateSliceOp.getDest();
1921 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1928 auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1931 unsigned resultNumber =
1932 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1940 static FailureOr<OpOperand *>
1942 if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1944 } else if (auto parallelInsertSlice =
1945 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1954 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1957 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1961 bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1965 FailureOr<OpOperand *> maybeConsumerOpOperand =
1967 if (failed(maybeConsumerOpOperand)) {
1969 "could not fetch consumer to fuse");
1971 OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1974 unsigned resultNumber = 0;
1975 if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1976 resultNumber = producerResult.getResultNumber();
1979 consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1986 LoopLikeOpInterface innerMostLoop =
1989 if (isInsertSliceOp) {
1990 nestedLoops = llvm::map_to_vector(
1992 cast<scf::ForOp>(innerMostLoop.getOperation())),
1993 [](scf::ForOp forOp) {
1994 return cast<LoopLikeOpInterface>(forOp.getOperation());
1997 nestedLoops = {innerMostLoop};
2000 LoopLikeOpInterface outerMostLoop = nestedLoops.front();
2005 outerMostLoop, "the first user of loop should not dominate any definition "
2006 "of consumer operand(s)");
2012 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2015 "consumer op is not DPS operation");
2017 llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
2018 if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
2021 "consumer op taking the result of scf.for as init is not supported");
2025 Location loc = outerMostLoop->getLoc();
2030 if (failed(firstUserOfLoop)) {
2032 outerMostLoop, "could not find the first user of outer most loop");
2034 rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2040 tensor::InsertSliceOp clonedInsertSliceOp;
2042 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2043 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2045 clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
2046 loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2047 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2050 clonedInsertSliceOp =
2051 cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
2055 auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2059 OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2061 operandToReplace.set(clonedInsertSliceOp.getResult());
2067 cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
2068 FailureOr<TilingResult> tileAndFuseResult =
2070 rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2071 if (failed(tileAndFuseResult)) {
2074 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2076 clonedInsertSliceOp.getSource());
2097 candidateSliceOp, "containingOp's result yield with stride");
2107 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2108 rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2109 iterDomainSizes))) {
2112 "can't get iter domain position from input position");
2118 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2120 totalNumResultsOfConsumer);
2122 totalNumResultsOfConsumer);
2123 for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2124 if (failed(tiledConsumerOp.getResultTilePosition(
2125 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2126 resultOffsets[idx], resultSizes[idx]))) {
2129 "can't get result domain position from iter domain position");
2135 if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2136 tiledConsumerOp.getOperation())) {
2138 for (const auto &&[index, newRegionArg] :
2140 auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2141 loc, newRegionArg, resultOffsets[index], resultSizes[index],
2146 auto dstNumber = index;
2148 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2157 for (const auto &&[index, result] :
2159 tiledResult.push_back(result);
2160 tiledOffset.emplace_back(resultOffsets[index]);
2161 tiledSizes.emplace_back(resultSizes[index]);
2167 newYieldValuesFn))) {
2169 "unable to add new inits to nest loop");
2175 for (auto &&[oldResult, newResult] : llvm::zip(
2177 nestedLoops.front()->getResults().take_back(newInits.size()))) {
2182 rewriter.eraseOp(clonedConsumerOp);
2186 &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2187 tileAndFuseResult->tiledOps};
2194 FailureOr<SmallVector<scf::ForOp>>
2196 TilingInterface op) {
2198 if (op->getNumResults() > 0) {
2200 op, "unable to lower operations with return values to loops");
2207 for (auto loopRange : domain) {
2214 auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2216 loops.push_back(loop);
2217 ivs.push_back(loop.getInductionVar());
2220 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
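// Editorial usage sketch (not part of the original file): lowering a tileable
// op without results to scalar loops via the entry point defined above; the
// function name is hypothetical.
static LogicalResult lowerToScalarLoops(RewriterBase &rewriter,
                                        TilingInterface op) {
  FailureOr<SmallVector<scf::ForOp>> loops =
      scf::lowerToLoopsUsingSCFForOp(rewriter, op);
  // On success the scalar body has been emitted inside the innermost loop.
  return success(succeeded(loops));
}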