29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/Debug.h"
33 #define DEBUG_TYPE "tile-using-interface"
40 auto tileSizes = llvm::to_vector(ts);
49 assert(!numThreadsComputationFunction &&
"num tiles already set");
50 auto numThreads = llvm::to_vector(nt);
61 size_t iterationDomainSize) {
63 if (filledVector.size() < iterationDomainSize) {
64 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
65 filledVector.append(range.begin(), range.end());
67 if (filledVector.size() > iterationDomainSize)
68 filledVector.resize(iterationDomainSize);
81 if (
options.numThreadsComputationFunction &&
84 loc,
"number of threads can only by specified when loop type is "
85 "set to use `scf.forall`");
89 if (!
options.interchangeVector.empty()) {
92 loc,
"invalid interchange vector, not a permutation of the entire "
107 size_t numLoops = iterationDomain.size();
110 if (
options.numThreadsComputationFunction) {
111 numThreads =
options.numThreadsComputationFunction(rewriter, op);
112 numThreads.resize(numLoops, zero);
115 if (
options.tileSizeComputationFunction) {
116 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
117 tileSizes.resize(numLoops, zero);
118 return {tileSizes, numThreads};
130 tileSizes.resize(numLoops, zero);
131 for (
auto [index, range, nt] :
137 rewriter, op.
getLoc(), tileSizeExpr, {range.offset, range.size, nt});
139 tileSizes.resize(numLoops, zero);
140 return {tileSizes, numThreads};
147 assert(
options.tileSizeComputationFunction &&
148 "expected tile sizes to be specified");
149 tileSizes =
options.tileSizeComputationFunction(rewriter, op);
150 tileSizes.resize(numLoops, zero);
152 return {tileSizes, numThreads};
159 auto iterators = op.getLoopIteratorTypes();
160 assert(iterators.size() == tileSizes.size() &&
161 "expected as many tile size values as number of loops");
162 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
163 "when specified, expected number of threads to use for each loop");
165 for (
auto [index, iterator, tileSize] :
169 if (!numThreads.empty()) {
170 if (std::optional<int64_t> constNumThreads =
172 if (constNumThreads.value() > 1 &&
173 iterator != utils::IteratorType::parallel) {
174 op.
emitWarning() <<
"tiling is not thread safe at axis #" << index;
181 if (constTileSize.value() > 0 &&
182 iterator != utils::IteratorType::parallel) {
183 op.
emitWarning() <<
"tiling is not thread safe at axis #" << index;
200 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
209 if (ts && ts.value() == 1)
236 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
238 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
250 int materializedLoopNum = 0;
252 if (!numThreads.empty()) {
257 offsetExpr = d0 + d1 * s0;
258 residualTileSizeExpr = s1 - (d0 + d1 * s0);
260 for (
auto [nt, tileSize, loopRange] :
261 llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
266 offsets.push_back(loopRange.offset);
267 sizes.push_back(loopRange.size);
271 Value iv = ivs[materializedLoopNum++];
273 rewriter, loc, offsetExpr,
276 rewriter, loc, residualTileSizeExpr,
277 {loopRange.offset, nt, tileSize, loopRange.size});
283 {offset, loopRange.size});
287 {sizeMinusOffsetPerThread, tileSize});
303 rewriter, loc, maxMap, {rewriter.
getIndexAttr(0), size});
306 offsets.push_back(offset);
307 sizes.push_back(size);
309 return {offsets, sizes};
311 for (
auto [tileSize, loopRange] :
312 llvm::zip_equal(tileSizes, iterationDomain)) {
317 offsets.push_back(loopRange.offset);
318 sizes.push_back(loopRange.size);
322 Value iv = ivs[materializedLoopNum++];
324 offsets.push_back(offset);
327 sizes.push_back(size);
329 return {offsets, sizes};
339 for (
auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
343 lbs.push_back(loopRange.offset);
344 ubs.push_back(loopRange.size);
345 steps.push_back(tileSize);
347 return {lbs, ubs, steps};
377 if (newDestArgs.empty())
379 if (
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
380 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
398 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
399 assert(loopRanges.size() == tileSizes.size() &&
400 "expected as many tile sizes as loop ranges");
404 std::tie(lbs, ubs, steps) =
414 for (
auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
416 rewriter.
create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
419 loops.push_back(loop);
420 ivs.push_back(loop.getInductionVar());
422 destinationTensors = loop.getRegionIterArgs();
427 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
428 tiledResults, resultOffsets, resultSizes))) {
430 loc,
"failed to generate inner tile loop body");
435 assert(tiledResults.size() == destinationTensors.size() &&
436 "Number of results of body should be equal to number of iter args");
440 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
441 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
445 auto insertSlice = rewriter.
create<tensor::InsertSliceOp>(
446 loc, tiledValue, destinationTensor, resultOffset, resultSize,
448 yieldedValues.push_back(insertSlice);
450 rewriter.
create<scf::YieldOp>(loc, yieldedValues);
453 for (
auto [outerLoop, innerLoop] :
457 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
458 rewriter.
create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
479 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
480 assert(loopRanges.size() == tileSizes.size() &&
481 "expected as many tile sizes as loop ranges");
484 sizes(loopRanges.size());
486 std::optional<ArrayAttr> mappingAttr;
487 if (!mappingVector.empty())
490 scf::ForallOp forallOp;
491 bool useNumThreads = !numThreads.empty();
496 for (
auto nt : numThreads) {
499 nonZeroNumThreads.push_back(nt);
501 forallOp = rewriter.
create<scf::ForallOp>(loc, nonZeroNumThreads,
502 destinationTensors, mappingAttr);
505 std::tie(lbs, ubs, steps) =
507 forallOp = rewriter.
create<scf::ForallOp>(loc, lbs, ubs, steps,
508 destinationTensors, mappingAttr);
510 loops.push_back(forallOp);
513 destinationTensors = forallOp.getRegionOutArgs();
517 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
518 destinationTensors, tiledResults, resultOffsets,
523 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
524 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
529 rewriter.
create<tensor::ParallelInsertSliceOp>(
530 loc, tiledValue, destinationTensor, resultOffset, resultSize,
556 return tiledBodyFn(rewriter, loc,
ValueRange{}, destinationTensors,
557 tiledResults, resultOffsets, resultSizes);
561 destinationTensors, tiledBodyFn, loops);
565 rewriter, loc, loopRanges, tileSizes, numThreads,
options.mappingVector,
566 destinationTensors, tiledBodyFn, loops);
578 template <
typename LoopType>
579 FailureOr<LoopLikeOpInterface>
588 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
595 auto inits = llvm::to_vector(loopOp.getInitArgs());
596 inits.append(newInitOperands.begin(), newInitOperands.end());
597 auto newLoop = rewriter.
create<scf::ForOp>(
598 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
602 Block *loopBody = loopOp.getBody();
603 Block *newLoopBody = newLoop.getBody();
605 loopBody, newLoopBody,
606 newLoopBody->
getArguments().take_front(loopBody->getNumArguments()));
608 auto yieldOp = cast<scf::YieldOp>(newLoopBody->
getTerminator());
614 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
615 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
616 newRegionIterArgs, tiledValues, resultOffsets,
623 for (
auto [tiledValue, regionIterArg, resultOffset, resultSize] :
624 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
628 Value insert = rewriter.
create<tensor::InsertSliceOp>(
629 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
631 newYieldValues.push_back(insert);
636 newLoop->getResults().take_front(loopOp.getNumResults()));
637 return cast<LoopLikeOpInterface>(newLoop.getOperation());
642 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
648 auto inits = llvm::to_vector(loopOp.getOutputs());
649 inits.append(newInitOperands.begin(), newInitOperands.end());
650 auto newLoop = rewriter.
create<scf::ForallOp>(
651 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
652 loopOp.getMixedStep(), inits, loopOp.getMapping(),
656 Block *loopBody = loopOp.getBody();
657 Block *newLoopBody = newLoop.getBody();
659 loopBody, newLoopBody,
660 newLoopBody->
getArguments().take_front(loopBody->getNumArguments()));
662 auto terminator = cast<scf::InParallelOp>(newLoopBody->
getTerminator());
667 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
668 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
669 regionIterArgs, tiledValues, resultOffsets,
673 "failed to get yielded tiled values");
679 for (
auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
680 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
683 rewriter.
create<tensor::ParallelInsertSliceOp>(
684 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
689 newLoop->getResults().take_front(loopOp.getNumResults()));
690 return cast<LoopLikeOpInterface>(newLoop.getOperation());
700 loopLikeOp.getOperation())
701 .Case<scf::ForOp, scf::ForallOp>(
702 [&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
704 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
706 .Default([&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
726 for (
auto &loop : loops.drop_back()) {
730 auto forLoop = cast<scf::ForOp>(loop.getOperation());
734 newInits.append(newInitValues.begin(), newInitValues.end());
735 auto newLoop = rewriter.
create<scf::ForOp>(
736 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
737 forLoop.getStep(), newInits,
742 sourceBlockArgs.push_back(newLoop.getInductionVar());
743 auto newRegionIterArgs = newLoop.getRegionIterArgs();
744 sourceBlockArgs.append(
745 newRegionIterArgs.begin(),
746 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
747 rewriter.
mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
749 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
751 ivs.push_back(newLoop.getInductionVar());
752 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
756 LoopLikeOpInterface innerMostLoop = loops.back();
757 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
759 getNewTiledYieldsFn);
761 if (failed(newInnerMostLoop))
762 return innerMostLoop.emitOpError(
"failed to return additional yields");
763 loops.back() = newInnerMostLoop.value();
767 for (
auto [outerLoop, innerLoop] :
768 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
770 auto outerForLoop = cast<scf::ForOp>(outerLoop);
771 auto outerLoopYield =
772 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
774 llvm::to_vector(outerLoopYield.getOperands());
776 innerLoop->getResults().take_back(newInitValues.size());
777 newYields.append(additionalYields.begin(), additionalYields.end());
786 FailureOr<scf::SCFTilingResult>
801 std::tie(tileSizes, numThreads) =
813 if (!
options.interchangeVector.empty()) {
815 iterationDomain.size());
817 "expected interchange vector to be a permutation");
821 if (!numThreads.empty())
825 FailureOr<TilingResult> tilingResult;
837 rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
841 if (!interchangeVector.empty()) {
850 auto clonedOp = cast<TilingInterface>(
857 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
865 tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
866 if (failed(tilingResult)) {
876 for (
auto [index, tiledValue] :
878 tiledResults.push_back(tiledValue);
880 if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
881 resultOffset, resultSize))) {
882 for (
auto op : tilingResult->tiledOps) {
886 op,
"failed to get slice of result produced");
888 resultOffsets.emplace_back(std::move(resultOffset));
889 resultSizes.emplace_back(std::move(resultSize));
898 destinationTensors))) {
900 "unable to create destination tensors");
906 tileSizes, numThreads, destinationTensors,
907 innerYieldTiledValuesFn, loops)))
908 return op.
emitOpError(
"failed to generate tiling loops");
909 assert(succeeded(tilingResult) &&
910 "expected tiling result to be computed after loop generation");
916 tilingResult->tiledValues,
917 tilingResult->generatedSlices};
921 loops.front()->getResults(), [](
OpResult r) ->
Value { return r; });
923 tilingResult->generatedSlices};
926 FailureOr<scf::SCFReductionTilingResult>
928 PartialReductionOpInterface op,
933 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
935 auto tileSizesVector = llvm::to_vector(tileSizes);
936 if (tileSizesVector.size() < iterationDomain.size()) {
938 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
942 tilingInterfaceOp.getLoopIteratorTypes();
945 for (
auto [idx, iteratorType] :
947 if (iteratorType == utils::IteratorType::reduction)
948 reductionDims.push_back(idx);
952 FailureOr<SmallVector<Value>> maybeInitTensors =
953 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
955 if (failed(maybeInitTensors)) {
962 auto innerYieldTiledValuesFn =
970 int materializedLoopNum = 0;
971 for (
auto [tileSize, loopRange] :
972 llvm::zip_equal(tileSizesVector, iterationDomain)) {
974 offsets.push_back(loopRange.offset);
975 sizes.push_back(loopRange.size);
978 Value iv = ivs[materializedLoopNum++];
979 offsets.push_back(iv);
987 auto clonedOp = cast<PartialReductionOpInterface>(
991 FailureOr<TilingResult> partialTilingResult =
992 clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
993 sizes, reductionDims);
994 if (failed(partialTilingResult)) {
997 std::swap(parallelTiledOps, partialTilingResult->tiledOps);
998 std::swap(tiledResult, partialTilingResult->tiledValues);
1006 for (
auto result : tiledResult) {
1008 resultOffsets.emplace_back(std::move(outOffsets));
1011 for (
size_t i = 0; i < offsets.size(); i++) {
1014 resultSizes.emplace_back(std::move(outSizes));
1025 initTensors, innerYieldTiledValuesFn, loops)))
1029 loops.front()->getResults(), [](
OpResult r) ->
Value { return r; });
1033 FailureOr<MergeResult> mergeResult =
1034 op.mergeReductions(b, loc, replacements, reductionDims);
1035 if (failed(mergeResult)) {
1038 b.
replaceOp(op, mergeResult->replacements);
1042 std::swap(reductionTilingResult.
mergeOps, mergeResult->mergeOps);
1043 std::swap(reductionTilingResult.
initialValues, initTensors);
1044 std::swap(reductionTilingResult.
loops, loops);
1045 std::swap(reductionTilingResult.
replacements, mergeResult->replacements);
1047 return reductionTilingResult;
1059 static std::tuple<OpResult, std::optional<OpOperand *>>
1062 std::optional<OpOperand *> destinationIterArg;
1063 auto loopIt = loops.rbegin();
1064 while (
auto iterArg = dyn_cast<BlockArgument>(source->
get())) {
1065 auto loop = *loopIt;
1066 if (iterArg.getOwner()->getParentOp() != loop)
1068 source = loop.getTiedLoopInit(iterArg);
1071 if (loopIt == loops.rend())
1072 destinationIterArg = source;
1073 return {dyn_cast<OpResult>(source->
get()), destinationIterArg};
1078 std::optional<scf::SCFFuseProducerOfSliceResult>
1080 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1084 auto [fusableProducer, destinationInitArg] =
1087 if (!fusableProducer)
1088 return std::nullopt;
1089 unsigned resultNumber = fusableProducer.getResultNumber();
1097 Operation *fusableProducerOp = fusableProducer.getOwner();
1098 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1100 rewriter, fusableProducerOp->
getLoc(), fusableProducerOp,
1101 origDestinationTensors)))
1102 return std::nullopt;
1104 clonedOpDestinationTensors = origDestinationTensors;
1105 if (destinationInitArg &&
1106 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1110 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1114 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1119 llvm::to_vector(candidateSliceOp->getOperands());
1120 candidateSliceOpOperands[0] = clonedProducerOp->
getResult(resultNumber);
1121 tensor::ExtractSliceOp clonedCandidateSliceOp =
1123 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1126 FailureOr<TilingResult> tileAndFuseResult =
1128 rewriter, clonedCandidateSliceOp,
1129 clonedProducerOp->
getResult(resultNumber));
1130 if (failed(tileAndFuseResult))
1131 return std::nullopt;
1135 tileAndFuseResult->tiledValues[0]);
1136 rewriter.
eraseOp(clonedCandidateSliceOp);
1137 rewriter.
eraseOp(clonedProducerOp);
1182 if (destinationInitArg &&
1183 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1185 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1186 .set(origDestinationTensors[resultNumber]);
1189 fusableProducer, tileAndFuseResult->tiledValues[0],
1190 tileAndFuseResult->
tiledOps, tileAndFuseResult->generatedSlices};
1195 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1203 *tiledOwner = fusedProducerInfo.
tiledOps[0];
1208 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1210 : llvm::to_vector(yieldResultNumber);
1212 for (
const auto &resultNumber : initNumberList) {
1214 rewriter, loc, originalOwner->
getResult(resultNumber));
1215 if (succeeded(initValue)) {
1216 initValueList.push_back(initValue.value());
1232 sliceSizes = sliceOp.getMixedSizes();
1235 if (llvm::any_of(sliceOp.getMixedStrides(), [](
OpFoldResult ofr) {
1236 return !isConstantIntValue(ofr, 1);
1240 unsigned sliceResultNumber =
1243 auto tilableOp = cast<TilingInterface>(originalOwner);
1247 if (tilableOp->getNumResults() > 1 &&
1248 failed(tilableOp.getIterationDomainTileFromResultTile(
1249 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1250 iterDomainOffset, iterDomainSizes))) {
1265 for (
const auto &resultNumber : initNumberList) {
1266 if (resultNumber == sliceResultNumber) {
1267 offsetList.push_back(sliceOffset);
1268 sizesList.push_back(sliceSizes);
1270 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1273 if (failed(tilableOp.getResultTilePosition(
1274 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1278 offsetList.push_back(offset);
1279 sizesList.push_back(sizes);
1284 if (
auto tiledDestStyleOp =
1285 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1287 for (
const auto &&[index, newRegionArg] :
1289 auto destSlice = rewriter.
create<tensor::ExtractSliceOp>(
1290 loc, newRegionArg, offsetList[index], sizesList[index],
1293 generatedSlices.push_back(destSlice);
1294 unsigned resultNumber = initNumberList[index];
1296 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1305 for (
const auto &&[index, resultNumber] :
llvm::enumerate(initNumberList)) {
1306 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1307 tiledOffset.emplace_back(offsetList[index]);
1308 tiledSizes.emplace_back(sizesList[index]);
1314 newYieldValuesFn))) {
1317 return generatedSlices;
1331 explicit SliceTrackingListener(
1332 std::optional<FrozenRewritePatternSet> patterns);
1333 SliceTrackingListener() =
default;
1341 void notifyOperationInserted(
Operation *op,
1348 void notifyOperationErased(
Operation *op)
override;
1355 std::deque<tensor::ExtractSliceOp> worklist;
1359 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1362 SliceTrackingListener::SliceTrackingListener(
1363 std::optional<FrozenRewritePatternSet> p) {
1364 patterns = std::move(p);
1370 if (
auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1371 worklist.push_back(slice);
1383 void SliceTrackingListener::notifyOperationInserted(
1385 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1388 worklist.push_back(slice);
1393 void SliceTrackingListener::removeOp(
Operation *op) {
1394 if (!isa<tensor::ExtractSliceOp>(op))
1396 auto iter = worklist.begin();
1397 while (iter != worklist.end()) {
1402 if (iter == worklist.end())
1405 worklist.erase(iter);
1408 void SliceTrackingListener::notifyOperationErased(
Operation *op) {
1412 void SliceTrackingListener::notifyOperationReplaced(
Operation *op,
1419 FailureOr<scf::SCFTileAndFuseResult>
1425 if (!consumer->getNumResults()) {
1427 consumer,
"invalid pattern for op with no results");
1432 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1434 FailureOr<scf::SCFTilingResult> tilingResult =
1437 if (failed(tilingResult))
1439 for (
auto *tiledOp : tilingResult->tiledOps)
1440 tiledAndFusedOps.insert(tiledOp);
1443 auto &loops = tilingResult->loops;
1444 if (loops.empty()) {
1446 for (
auto [origVal, replacement] :
1447 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1448 replacements[origVal] = replacement;
1458 for (
auto [index, result] :
llvm::enumerate(consumer->getResults())) {
1459 origValToResultNumber[result] = index;
1469 struct WorklistItem {
1470 tensor::ExtractSliceOp candidateSlice;
1474 SliceTrackingListener sliceTracker =
1475 SliceTrackingListener(
options.cleanupPatterns);
1478 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1482 while (!sliceTracker.worklist.empty()) {
1483 auto candidateSlice = sliceTracker.worklist.front();
1484 sliceTracker.worklist.pop_front();
1486 auto [fusableProducer, destinationInitArg] =
1489 if (!fusableProducer)
1492 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1493 options.fusionControlFn(candidateSlice, fusableProducer,
1494 destinationInitArg.has_value());
1495 if (!controlFnResult)
1498 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1503 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1511 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1515 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1516 FailureOr<SmallVector<Operation *>> newSlices =
1518 worklistItem.candidateSlice,
1519 fusedResult.value(), loops);
1520 if (failed(newSlices)) {
1522 fusableProducerOp,
"failed to replacement value for this "
1523 "operation from within the tiled loop");
1525 worklistCandidates.append(newSlices.value());
1526 for (
auto [index, result] :
1528 origValToResultNumber[result] = loops.front()->getNumResults() -
1534 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1535 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1536 tiledAndFusedOps.insert(tiledAndFusedOp);
1539 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1545 for (
auto [origVal, resultNumber] : origValToResultNumber) {
1546 replacements[origVal] = loops.front()->getResult(resultNumber);
1559 static LogicalResult
1561 Value result = candidateSliceOp.getResult();
1563 if (!llvm::hasSingleElement(uses)) {
1564 LLVM_DEBUG(llvm::dbgs() <<
"Too many uses of the candidate slice op\n");
1567 OpOperand &operandUse = (*uses.begin());
1569 if (!isa<scf::YieldOp>(userOp)) {
1570 LLVM_DEBUG(llvm::dbgs()
1571 <<
"Expected scf.yield to be the only user, but got -> "
1576 LLVM_DEBUG(llvm::dbgs() <<
"Expected tensor.insert_slice and scf.yield to "
1577 "be in the same block\n");
1587 Block *containingOpBlock) {
1592 Operation *consumerOp = opOperand.getOwner();
1593 if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
1600 if (!isa<TilingInterface>(consumerOp) ||
1601 !isa<DestinationStyleOpInterface>(consumerOp))
1603 if (containingOpBlock != consumerOp->
getBlock())
1605 operand = &opOperand;
1633 auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1636 auto isForOpYieldResultOfInnerLoop =
1637 [](scf::ForOp outerLoop) -> LogicalResult {
1638 Block *body = outerLoop.getBody();
1642 auto innerForOp = dyn_cast<scf::ForOp>(body->
front());
1646 return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1649 while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1650 nestLoops.push_back(outerLoop);
1651 outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1654 return {nestLoops.rbegin(), nestLoops.rend()};
1661 static FailureOr<OpOperand *>
1665 Value sliceResult = candidateSliceOp.getResult();
1671 auto forOp = dyn_cast<scf::ForOp>(containingOp);
1675 Value resultingValue = topLevelForOp->getResult(resultNumber);
1682 static FailureOr<OpOperand *>
1685 Value sliceDest = candidateSliceOp.getDest();
1686 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1693 auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1696 Value resultingValue =
1697 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
1718 if (userOp == consumerOp)
1720 if (parentBlock != userOp->getBlock() ||
1730 if (
auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1732 }
else if (
auto parallelInsertSlice =
1733 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1742 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1745 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1749 bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1753 FailureOr<OpOperand *> maybeConsumerOpOperand =
1755 if (failed(maybeConsumerOpOperand)) {
1757 "could not fetch consumer to fuse");
1759 OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1762 unsigned resultNumber = 0;
1763 if (
auto producerResult = dyn_cast<OpResult>(consumerOpOperand->
get())) {
1764 resultNumber = producerResult.getResultNumber();
1767 consumerOp,
"consumer op's operand doesn't seem to be an OpResult");
1774 LoopLikeOpInterface innerMostLoop =
1777 if (isInsertSliceOp) {
1778 nestedLoops = llvm::map_to_vector(
1780 cast<scf::ForOp>(innerMostLoop.getOperation())),
1781 [](scf::ForOp forOp) {
1782 return cast<LoopLikeOpInterface>(forOp.getOperation());
1785 nestedLoops = {innerMostLoop};
1788 LoopLikeOpInterface outerMostLoop = nestedLoops.front();
1793 "containing loop op should either yield just one value or "
1794 "have the consumer op as its first user");
1800 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1803 "consumer op is not DPS operation");
1805 llvm::map_to_vector(dstOp.getDpsInits(), [](
Value v) { return v; });
1806 if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
1809 "consumer op taking the result of scf.for as init is not supported");
1813 Location loc = outerMostLoop->getLoc();
1823 tensor::InsertSliceOp clonedInsertSliceOp;
1825 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1826 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
1828 clonedInsertSliceOp = rewriter.
create<tensor::InsertSliceOp>(
1829 loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1830 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1833 clonedInsertSliceOp =
1834 cast<tensor::InsertSliceOp>(rewriter.
clone(*candidateSliceOp));
1838 auto clonedConsumerOp = cast<TilingInterface>(rewriter.
clone(*consumerOp));
1842 OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
1844 operandToReplace.
set(clonedInsertSliceOp.getResult());
1850 cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
1851 FailureOr<TilingResult> tileAndFuseResult =
1853 rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
1854 if (failed(tileAndFuseResult)) {
1857 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
1859 clonedInsertSliceOp.getSource());
1880 candidateSliceOp,
"containingOp's result yield with stride");
1885 if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
1886 rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1887 iterDomainSizes))) {
1890 "can't get iter domain position from input position");
1896 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
1898 totalNumResultsOfConsumer);
1900 totalNumResultsOfConsumer);
1901 for (
auto [idx, v] :
llvm::enumerate(tiledConsumerOp->getResults())) {
1902 if (failed(tiledConsumerOp.getResultTilePosition(
1903 rewriter, idx, iterDomainOffsets, iterDomainSizes,
1904 resultOffsets[idx], resultSizes[idx]))) {
1907 "can't get result domain position from iter domain position");
1913 if (
auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
1914 tiledConsumerOp.getOperation())) {
1916 for (
const auto &&[index, newRegionArg] :
1918 auto destSlice = rewriter.
create<tensor::ExtractSliceOp>(
1919 loc, newRegionArg, resultOffsets[index], resultSizes[index],
1924 auto dstNumber = index;
1926 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
1935 for (
const auto &&[index, result] :
1937 tiledResult.push_back(result);
1938 tiledOffset.emplace_back(resultOffsets[index]);
1939 tiledSizes.emplace_back(resultSizes[index]);
1945 newYieldValuesFn))) {
1947 "unable to add new inits to nest loop");
1952 for (
auto &&[oldResult, newResult] : llvm::zip(
1954 nestedLoops.front()->getResults().take_back(newInits.size()))) {
1959 rewriter.
eraseOp(clonedConsumerOp);
1963 &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
1964 tileAndFuseResult->tiledOps};
1971 FailureOr<SmallVector<scf::ForOp>>
1973 TilingInterface op) {
1977 op,
"unable to lower to loops operations with return values");
1984 for (
auto loopRange : domain) {
1991 auto loop = rewriter.
create<scf::ForOp>(op.
getLoc(), offsetVal, sizeVal,
1993 loops.push_back(loop);
1994 ivs.push_back(loop.getInductionVar());
1997 if (failed(op.generateScalarImplementation(rewriter, op.
getLoc(), ivs))) {
static llvm::ManagedStatic< PassManagerOptions > options
static LogicalResult checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp)
This utility currently checks whether the loop either yields just one value or has the consumer op as its first user.
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * (numThreads - 1) is less than iterationSize.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp)
Fetch the untiled consumer of a scf.for's result which is yielded by a tensor.insert_slice.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specified in options.
static SmallVector< scf::ForOp > getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop)
Find the perfectly nested loops outside of the given loop (inclusive), sorted from outer to inner.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
static FailureOr< OpOperand * > getConsumerFromUses(Value val, Block *containingOpBlock)
Fetches the OpOperand of the only user (and use) of the value val which implements TilingInterface an...
Base type for affine expression.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents a saved insertion point.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
user_range getUsers()
Returns a range of all users.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterf...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp whose attribute has value 0.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Container for result values of tiling.
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
SmallVector< Operation * > tiledOps
Transformation information returned after reduction tiling.
SmallVector< Value > replacements
The replacements to use for the results of the tiled operation.
SmallVector< Value > initialValues
Initial values used for reduction.
SmallVector< Operation * > parallelTiledOps
The partial reduction tiled op generated.
SmallVector< LoopLikeOpInterface > loops
The loop operations that iterate over the tiles.
SmallVector< Operation * > mergeOps
The final reduction operation merging all the partial reductions.
Control function to check if a slice needs to be fused or not. The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.