27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
31 #define DEBUG_TYPE "tile-using-interface"
38 auto tileSizes = llvm::to_vector(ts);
49 size_t iterationDomainSize) {
51 if (filledVector.size() < iterationDomainSize) {
52 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
53 filledVector.append(range.begin(), range.end());
55 if (filledVector.size() > iterationDomainSize)
56 filledVector.resize(iterationDomainSize);
75 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
84 if (ts && ts.value() == 1)
130 if (newDestArgs.empty())
132 if (
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
133 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
151 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
152 assert(loopRanges.size() == tileSizes.size() &&
153 "expected as many tile sizes as loop ranges");
157 for (
auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
167 rewriter.
create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
170 loops.push_back(loop);
171 ivs.push_back(loop.getInductionVar());
173 destinationTensors = loop.getRegionIterArgs();
178 if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
179 tiledResults, resultOffsets, resultSizes))) {
181 loc,
"failed to generate inner tile loop body");
186 assert(tiledResults.size() == destinationTensors.size() &&
187 "Number of results of body should be equal to number of iter args");
191 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
192 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
196 auto insertSlice = rewriter.
create<tensor::InsertSliceOp>(
197 loc, tiledValue, destinationTensor, resultOffset, resultSize,
199 yieldedValues.push_back(insertSlice);
201 rewriter.
create<scf::YieldOp>(loc, yieldedValues);
204 for (
auto [outerLoop, innerLoop] :
208 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
209 rewriter.
create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
231 assert(!loopRanges.empty() &&
"unexpected empty loop ranges");
232 assert(loopRanges.size() == tileSizes.size() &&
233 "expected as many tile sizes as loop ranges");
236 sizes(loopRanges.size());
238 for (
auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
241 lbs.push_back(loopRange.offset);
242 ubs.push_back(loopRange.size);
243 steps.push_back(tileSize);
245 assert(!lbs.empty() &&
"Expected at least one loop range");
247 std::optional<ArrayAttr> mappingAttr;
248 if (!mappingVector.empty())
251 auto forallOp = rewriter.
create<scf::ForallOp>(
252 loc, lbs, ubs, steps, destinationTensors, mappingAttr);
253 loops.push_back(forallOp);
256 destinationTensors = forallOp.getRegionOutArgs();
260 if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
261 destinationTensors, tiledResults, resultOffsets,
266 for (
auto [tiledValue, destinationTensor, resultOffset, resultSize] :
267 llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
272 rewriter.
create<tensor::ParallelInsertSliceOp>(
273 loc, tiledValue, destinationTensor, resultOffset, resultSize,
301 return tiledBodyFn(rewriter, loc,
ValueRange{}, destinationTensors,
302 tiledResults, resultOffsets, resultSizes);
306 destinationTensors, tiledBodyFn, loops);
310 rewriter, loc, loopRanges, tileSizes,
options.mappingVector,
311 destinationTensors, tiledBodyFn, loops);
323 template <
typename LoopType>
324 FailureOr<LoopLikeOpInterface>
333 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
340 auto inits = llvm::to_vector(loopOp.getInitArgs());
341 inits.append(newInitOperands.begin(), newInitOperands.end());
342 auto newLoop = rewriter.
create<scf::ForOp>(
343 loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
347 Block *loopBody = loopOp.getBody();
348 Block *newLoopBody = newLoop.getBody();
350 loopBody, newLoopBody,
351 newLoopBody->
getArguments().take_front(loopBody->getNumArguments()));
353 auto yieldOp = cast<scf::YieldOp>(newLoopBody->
getTerminator());
359 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
360 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
361 newRegionIterArgs, tiledValues, resultOffsets,
368 for (
auto [tiledValue, regionIterArg, resultOffset, resultSize] :
369 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
373 Value insert = rewriter.
create<tensor::InsertSliceOp>(
374 yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
376 newYieldValues.push_back(insert);
381 newLoop->getResults().take_front(loopOp.getNumResults()));
382 return cast<LoopLikeOpInterface>(newLoop.getOperation());
387 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
393 auto inits = llvm::to_vector(loopOp.getOutputs());
394 inits.append(newInitOperands.begin(), newInitOperands.end());
395 auto newLoop = rewriter.
create<scf::ForallOp>(
396 loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
397 loopOp.getMixedStep(), inits, loopOp.getMapping(),
401 Block *loopBody = loopOp.getBody();
402 Block *newLoopBody = newLoop.getBody();
404 loopBody, newLoopBody,
405 newLoopBody->
getArguments().take_front(loopBody->getNumArguments()));
407 auto terminator = cast<scf::InParallelOp>(newLoopBody->
getTerminator());
412 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
413 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
414 regionIterArgs, tiledValues, resultOffsets,
418 "failed to get yielded tiled values");
424 for (
auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
425 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
428 rewriter.
create<tensor::ParallelInsertSliceOp>(
429 terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
434 newLoop->getResults().take_front(loopOp.getNumResults()));
435 return cast<LoopLikeOpInterface>(newLoop.getOperation());
445 loopLikeOp.getOperation())
446 .Case<scf::ForOp, scf::ForallOp>(
447 [&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
449 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
451 .Default([&](
auto loopOp) -> FailureOr<LoopLikeOpInterface> {
471 for (
auto &loop : loops.drop_back()) {
475 auto forLoop = cast<scf::ForOp>(loop.getOperation());
479 newInits.append(newInitValues.begin(), newInitValues.end());
480 auto newLoop = rewriter.
create<scf::ForOp>(
481 forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
482 forLoop.getStep(), newInits,
487 sourceBlockArgs.push_back(newLoop.getInductionVar());
488 auto newRegionIterArgs = newLoop.getRegionIterArgs();
489 sourceBlockArgs.append(
490 newRegionIterArgs.begin(),
491 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
492 rewriter.
mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
494 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
496 ivs.push_back(newLoop.getInductionVar());
497 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
501 LoopLikeOpInterface innerMostLoop = loops.back();
502 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
504 getNewTiledYieldsFn);
506 if (failed(newInnerMostLoop))
507 return innerMostLoop.emitOpError(
"failed to return additional yields");
508 loops.back() = newInnerMostLoop.value();
512 for (
auto [outerLoop, innerLoop] :
513 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
515 auto outerForLoop = cast<scf::ForOp>(outerLoop);
516 auto outerLoopYield =
517 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
519 llvm::to_vector(outerLoopYield.getOperands());
521 innerLoop->getResults().take_back(newInitValues.size());
522 newYields.append(additionalYields.begin(), additionalYields.end());
531 FailureOr<scf::SCFTilingResult>
537 if (!
options.tileSizeComputationFunction) {
539 op,
"missing tile size computation function");
544 size_t numLoops = iterationDomain.size();
551 options.tileSizeComputationFunction(rewriter, op);
552 if (tileSizes.size() < iterationDomain.size()) {
554 tileSizes.append(numLoops - tileSizes.size(), zero);
560 if (!
options.interchangeVector.empty()) {
562 iterationDomain.size());
564 if (!interchangeVector.empty()) {
567 op,
"invalid intechange vector, not a permutation of the entire "
575 FailureOr<TilingResult> tilingResult;
587 int materializedLoopNum = 0;
588 for (
auto [tileSize, loopRange] :
589 llvm::zip_equal(tileSizes, iterationDomain)) {
591 offsets.push_back(loopRange.offset);
592 sizes.push_back(loopRange.size);
595 Value iv = ivs[materializedLoopNum++];
596 offsets.push_back(iv);
604 if (!interchangeVector.empty()) {
613 auto clonedOp = cast<TilingInterface>(
620 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
627 tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
628 if (failed(tilingResult)) {
638 for (
auto [index, tiledValue] :
640 tiledResults.push_back(tiledValue);
642 if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
643 resultOffset, resultSize))) {
644 for (
auto op : tilingResult->tiledOps) {
648 op,
"failed to get slice of result produced");
650 resultOffsets.emplace_back(std::move(resultOffset));
651 resultSizes.emplace_back(std::move(resultSize));
660 destinationTensors))) {
662 "unable to create destination tensors");
668 tileSizes, destinationTensors,
669 innerYieldTiledValuesFn, loops)))
670 return op.
emitOpError(
"failed to generate tiling loops");
671 assert(succeeded(tilingResult) &&
672 "expected tiling result to be computed after loop generation");
678 tilingResult->tiledValues};
682 loops.front()->getResults(), [](
OpResult r) ->
Value { return r; });
686 FailureOr<scf::SCFReductionTilingResult>
688 PartialReductionOpInterface op,
693 auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
695 auto tileSizesVector = llvm::to_vector(tileSizes);
696 if (tileSizesVector.size() < iterationDomain.size()) {
698 tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
702 tilingInterfaceOp.getLoopIteratorTypes();
705 for (
auto [idx, iteratorType] :
707 if (iteratorType == utils::IteratorType::reduction)
708 reductionDims.push_back(idx);
712 FailureOr<SmallVector<Value>> maybeInitTensors =
713 op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
715 if (failed(maybeInitTensors)) {
722 auto innerYieldTiledValuesFn =
730 int materializedLoopNum = 0;
731 for (
auto [tileSize, loopRange] :
732 llvm::zip_equal(tileSizesVector, iterationDomain)) {
734 offsets.push_back(loopRange.offset);
735 sizes.push_back(loopRange.size);
738 Value iv = ivs[materializedLoopNum++];
739 offsets.push_back(iv);
747 auto clonedOp = cast<PartialReductionOpInterface>(
751 FailureOr<TilingResult> partialTilingResult =
752 clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
753 sizes, reductionDims);
754 if (failed(partialTilingResult)) {
757 std::swap(parallelTiledOps, partialTilingResult->tiledOps);
758 std::swap(tiledResult, partialTilingResult->tiledValues);
766 for (
auto result : tiledResult) {
768 resultOffsets.emplace_back(std::move(outOffsets));
771 for (
size_t i = 0; i < offsets.size(); i++) {
774 resultSizes.emplace_back(std::move(outSizes));
784 initTensors, innerYieldTiledValuesFn, loops)))
788 loops.front()->getResults(), [](
OpResult r) ->
Value { return r; });
792 FailureOr<MergeResult> mergeResult =
793 op.mergeReductions(b, loc, replacements, reductionDims);
794 if (failed(mergeResult)) {
797 b.
replaceOp(op, mergeResult->replacements);
801 std::swap(reductionTilingResult.
mergeOps, mergeResult->mergeOps);
803 std::swap(reductionTilingResult.
loops, loops);
804 std::swap(reductionTilingResult.
replacements, mergeResult->replacements);
806 return reductionTilingResult;
818 static std::tuple<OpResult, std::optional<OpOperand *>>
821 std::optional<OpOperand *> destinationIterArg;
822 auto loopIt = loops.rbegin();
823 while (
auto iterArg = dyn_cast<BlockArgument>(source->
get())) {
825 if (iterArg.getOwner()->getParentOp() != loop)
827 source = loop.getTiedLoopInit(iterArg);
830 if (loopIt == loops.rend())
831 destinationIterArg = source;
832 return {dyn_cast<OpResult>(source->
get()), destinationIterArg};
837 std::optional<scf::SCFFuseProducerOfSliceResult>
839 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
843 auto [fusableProducer, destinationInitArg] =
846 if (!fusableProducer)
848 unsigned resultNumber = fusableProducer.getResultNumber();
856 Operation *fusableProducerOp = fusableProducer.getOwner();
857 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
859 rewriter, fusableProducerOp->
getLoc(), fusableProducerOp,
860 origDestinationTensors)))
863 clonedOpDestinationTensors = origDestinationTensors;
864 if (destinationInitArg &&
865 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
869 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
873 rewriter, fusableProducerOp, clonedOpDestinationTensors);
878 llvm::to_vector(candidateSliceOp->getOperands());
879 candidateSliceOpOperands[0] = clonedProducerOp->
getResult(resultNumber);
880 tensor::ExtractSliceOp clonedCandidateSliceOp =
882 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
885 FailureOr<TilingResult> tileAndFuseResult =
887 rewriter, clonedCandidateSliceOp,
888 clonedProducerOp->
getResult(resultNumber));
889 if (failed(tileAndFuseResult))
894 tileAndFuseResult->tiledValues[0]);
895 rewriter.
eraseOp(clonedCandidateSliceOp);
896 rewriter.
eraseOp(clonedProducerOp);
941 if (destinationInitArg &&
942 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
944 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
945 .set(origDestinationTensors[resultNumber]);
948 tileAndFuseResult->tiledValues[0],
962 *tiledOwner = fusedProducerInfo.
tiledOps[0];
967 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
969 : llvm::to_vector(yieldResultNumber);
971 for (
const auto &resultNumber : initNumberList) {
973 rewriter, loc, originalOwner->
getResult(resultNumber));
974 if (succeeded(initValue)) {
975 initValueList.push_back(initValue.value());
990 sliceSizes = sliceOp.getMixedSizes();
993 if (llvm::any_of(sliceOp.getMixedStrides(), [](
OpFoldResult ofr) {
994 return !isConstantIntValue(ofr, 1);
998 unsigned sliceResultNumber =
1001 auto tilableOp = cast<TilingInterface>(originalOwner);
1005 if (tilableOp->getNumResults() > 1 &&
1006 failed(tilableOp.getIterationDomainTileFromResultTile(
1007 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1008 iterDomainOffset, iterDomainSizes))) {
1023 for (
const auto &resultNumber : initNumberList) {
1024 if (resultNumber == sliceResultNumber) {
1025 offsetList.push_back(sliceOffset);
1026 sizesList.push_back(sliceSizes);
1028 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1031 if (failed(tilableOp.getResultTilePosition(
1032 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1036 offsetList.push_back(offset);
1037 sizesList.push_back(sizes);
1042 if (
auto tiledDestStyleOp =
1043 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1045 for (
const auto &&[index, newRegionArg] :
1047 auto destSlice = rewriter.
create<tensor::ExtractSliceOp>(
1048 loc, newRegionArg, offsetList[index], sizesList[index],
1051 unsigned resultNumber = initNumberList[index];
1053 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1062 for (
const auto &&[index, resultNumber] :
llvm::enumerate(initNumberList)) {
1063 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1064 tiledOffset.emplace_back(offsetList[index]);
1065 tiledSizes.emplace_back(sizesList[index]);
1075 FailureOr<scf::SCFTileAndFuseResult>
1081 if (!consumer->getNumResults()) {
1083 consumer,
"invalid pattern for op with no results");
1088 llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1090 FailureOr<scf::SCFTilingResult> tilingResult =
1093 if (failed(tilingResult))
1095 for (
auto *tiledOp : tilingResult->tiledOps)
1096 tiledAndFusedOps.insert(tiledOp);
1099 auto &loops = tilingResult->loops;
1100 if (loops.empty()) {
1102 for (
auto [origVal, replacement] :
1103 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1104 replacements[origVal] = replacement;
1114 for (
auto [index, result] :
llvm::enumerate(consumer->getResults())) {
1115 origValToResultNumber[result] = index;
1125 auto addCandidateSlices = [](
Operation *fusedOp,
1126 std::deque<tensor::ExtractSliceOp> &candidates) {
1128 if (
auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
1129 candidates.push_back(sliceOp);
1132 std::deque<tensor::ExtractSliceOp> candidates;
1133 addCandidateSlices(tiledAndFusedOps.back(), candidates);
1135 while (!candidates.empty()) {
1137 tensor::ExtractSliceOp candidateSliceOp = candidates.front();
1138 candidates.pop_front();
1141 auto [fusableProducer, destinationInitArg] =
1144 if (!fusableProducer)
1147 auto [fuseSlice, yieldReplacement] =
options.fusionControlFn(
1148 candidateSliceOp, fusableProducer, destinationInitArg.has_value());
1155 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1160 if (yieldReplacement) {
1164 Operation *fusableProducerOp = fusableProducer.getOwner();
1166 rewriter, candidateSliceOp, fusedResult.value(), loops))) {
1168 fusableProducerOp,
"failed to replacement value for this "
1169 "operation from within the tiled loop");
1171 for (
auto [index, result] :
1173 origValToResultNumber[result] = loops.front()->getNumResults() -
1180 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1181 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1182 tiledAndFusedOps.insert(tiledAndFusedOp);
1183 addCandidateSlices(tiledAndFusedOp, candidates);
1188 for (
auto [origVal, resultNumber] : origValToResultNumber) {
1189 replacements[origVal] = loops.front()->getResult(resultNumber);
1202 static LogicalResult
1204 Value result = candidateSliceOp.getResult();
1206 if (!llvm::hasSingleElement(uses)) {
1207 LLVM_DEBUG(llvm::dbgs() <<
"Too many uses of the candidate slice op\n");
1210 OpOperand &operandUse = (*uses.begin());
1212 if (!isa<scf::YieldOp>(userOp)) {
1213 LLVM_DEBUG(llvm::dbgs()
1214 <<
"Expected scf.yield to be the only user, but got -> "
1219 LLVM_DEBUG(llvm::dbgs() <<
"Expected tensor.insert_slice and scf.yield to "
1220 "be in the same block\n");
1230 Block *containingOpBlock) {
1232 if (!llvm::hasSingleElement(val.
getUses()))
1240 if (!isa<TilingInterface>(consumerOp) ||
1241 !isa<DestinationStyleOpInterface>(consumerOp))
1243 if (containingOpBlock != consumerOp->
getBlock())
1252 static FailureOr<OpOperand *>
1256 Value sliceResult = candidateSliceOp.getResult();
1262 auto forOp = dyn_cast<scf::ForOp>(containingOp);
1265 Value resultingValue = forOp->getResult(resultNumber);
1272 static FailureOr<OpOperand *>
1275 Value sliceDest = candidateSliceOp.getDest();
1276 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1283 auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1286 Value resultingValue =
1287 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
1308 if (userOp == consumerOp)
1310 if (parentBlock != userOp->getBlock() ||
1320 if (
auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1322 }
else if (
auto parallelInsertSlice =
1323 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1338 scf::YieldOp oldTerminatorOp =
1339 cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
1340 unsigned totalOldResults = oldTerminatorOp->getNumResults();
1341 unsigned totalTiledResults = tilingResult.
tiledOps[0]->getNumResults();
1343 newYieldOperands.reserve(totalOldResults + totalTiledResults);
1344 for (
auto oldResult : oldTerminatorOp.getResults()) {
1345 newYieldOperands.push_back(oldResult);
1349 for (
auto [tiledResult, bbArg, resultOffset, resultSize] :
1350 llvm::zip_equal(tilingResult.
tiledOps[0]->getResults(), bbArgs,
1351 resultOffsets, resultSizes)) {
1354 Value newInsertSliceOp = rewriter.
create<tensor::InsertSliceOp>(
1355 loc, tiledResult, bbArg, resultOffset, resultSize, strides);
1356 newYieldOperands.push_back(newInsertSliceOp);
1358 rewriter.
create<scf::YieldOp>(loc, newYieldOperands);
1359 rewriter.
eraseOp(oldTerminatorOp);
1370 scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
1373 (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
1374 for (
auto [tiledResult, bbArg, resultOffset, resultSize] :
1375 llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
1378 rewriter.
create<tensor::ParallelInsertSliceOp>(
1379 firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
1385 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1388 if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1392 bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1396 FailureOr<OpOperand *> maybeConsumerOpOperand =
1398 if (failed(maybeConsumerOpOperand)) {
1400 "could not fetch consumer to fuse");
1402 OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1405 unsigned resultNumber = 0;
1406 if (
auto producerResult = dyn_cast<OpResult>(consumerOpOperand->
get())) {
1407 resultNumber = producerResult.getResultNumber();
1410 consumerOp,
"consumer op's operand doesn't seem to be an OpResult");
1415 Block *oldLoopBody =
nullptr;
1416 unsigned initSize = 0;
1418 if (isInsertSliceOp) {
1421 llvm::append_range(newOuts, forOp.getInits());
1422 oldLoopBody = forOp.getBody();
1423 initSize = forOp.getInits().size();
1426 oldLoopOp = forallOp;
1427 llvm::append_range(newOuts, forallOp.getOutputs());
1428 oldLoopBody = forallOp.getBody();
1429 initSize = forallOp.getOutputs().size();
1430 rank = forallOp.getRank();
1435 oldLoopOp,
"containing loop op should either yield just one value or "
1436 "have the consumer op as its first user");
1442 auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
1444 llvm::map_to_vector(dstOp.getDpsInits(), [](
Value v) { return v; });
1445 if (llvm::is_contained(dpsInits, oldLoopOp->
getResult(resultNumber))) {
1448 "consumer op taking the result of scf.for as init is not supported");
1450 newOuts.append(dpsInits);
1457 Block *newLoopBody =
nullptr;
1458 if (isInsertSliceOp) {
1459 auto forOp = cast<scf::ForOp>(oldLoopOp);
1460 auto newForOp = rewriter.
create<scf::ForOp>(loc, forOp.getLowerBound(),
1461 forOp.getUpperBound(),
1462 forOp.getStep(), newOuts);
1463 newLoopOp = newForOp;
1464 newLoopBody = newForOp.getBody();
1466 auto forallOp = cast<scf::ForallOp>(oldLoopOp);
1467 auto newForallOp = rewriter.
create<scf::ForallOp>(
1468 loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1469 forallOp.getMixedStep(), newOuts, forallOp.getMapping());
1470 newLoopOp = newForallOp;
1471 rewriter.
eraseOp(newForallOp.getTerminator());
1472 newLoopBody = newForallOp.getBody();
1478 newLoopBody->
getArguments().take_front(oldNumArguments));
1484 tensor::InsertSliceOp clonedInsertSliceOp;
1486 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1487 auto newForallOp = cast<scf::ForallOp>(newLoopOp);
1489 clonedInsertSliceOp = rewriter.
create<tensor::InsertSliceOp>(
1490 loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1491 sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1494 clonedInsertSliceOp =
1495 cast<tensor::InsertSliceOp>(rewriter.
clone(*candidateSliceOp));
1499 auto newForOpBlockArgsForConsumerDest =
1500 newLoopBody->
getArguments().drop_front(oldNumArguments);
1502 rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
1506 OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
1508 operandToReplace.
set(clonedInsertSliceOp.getResult());
1514 cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
1515 FailureOr<TilingResult> tileAndFuseResult =
1517 rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
1518 if (failed(tileAndFuseResult)) {
1522 tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
1523 clonedInsertSliceOp.getSource());
1536 candidateSliceOp,
"containingOp's result yield with stride");
1541 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
1542 rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1543 iterDomainSizes))) {
1545 clonedConsumerOp,
"can't get iter domain position from input position");
1551 unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
1553 totalNumResultsOfConsumer);
1555 for (
auto [idx, v] :
llvm::enumerate(clonedConsumerOp->getResults())) {
1556 if (failed(clonedConsumerOp.getResultTilePosition(
1557 rewriter, idx, iterDomainOffsets, iterDomainSizes,
1558 resultOffsets[idx], resultSizes[idx]))) {
1561 "can't get result domain position from iter domain position");
1567 if (isInsertSliceOp) {
1568 auto newForOp = cast<scf::ForOp>(newLoopOp);
1570 rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
1571 newForOp.getBody()->getArguments().drop_front(1 + initSize));
1573 auto newForallOp = cast<scf::ForallOp>(newLoopOp);
1575 rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
1576 arrayRefOffsets, arrayRefSizes,
1577 newForallOp.getBody()->getArguments().drop_front(rank + initSize));
1581 for (
auto &&[oldResult, newResult] :
1586 for (
auto &&[oldResult, newResult] :
1588 newLoopOp->
getResults().drop_front(initSize))) {
1594 rewriter.
eraseOp(clonedConsumerOp);
1598 &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
1599 tileAndFuseResult->tiledOps};
1606 FailureOr<SmallVector<scf::ForOp>>
1608 TilingInterface op) {
1612 op,
"unable to lower to loops operations with return values");
1619 for (
auto loopRange : domain) {
1626 auto loop = rewriter.
create<scf::ForOp>(op.
getLoc(), offsetVal, sizeVal,
1628 loops.push_back(loop);
1629 ivs.push_back(loop.getInductionVar());
1632 if (failed(op.generateScalarImplementation(rewriter, op.
getLoc(), ivs))) {
static llvm::ManagedStatic< PassManagerOptions > options
static LogicalResult checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp)
This utility currently checks whether the loop either :-.
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
static void fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp, TilingResult &tilingResult, ArrayRef< SmallVector< OpFoldResult >> &resultOffsets, ArrayRef< SmallVector< OpFoldResult >> &resultSizes, ArrayRef< BlockArgument > bbArgs)
After fusing consumer into scf.for we want to modify the scf.yield operation to reflect the same by r...
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specified in options.
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp)
Fetch the untiled consumer of a scf.for's result which is yielded by a tensor.insert_slice.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loop's existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static bool tileDividesIterationDomain(Range loopRange)
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< OpOperand * > getConsumerFromUses(Value val, Block *containingOpBlock)
Fetches the OpOperand of the only user (and use) of the value val which implements TilingInterface an...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, Value iv, OpFoldResult tileSize)
Returns the bounded tile size given the current iv, loopRange and tileSize, i.e., min(tileSize,...
static void fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp, SmallVector< Value > tiledResults, ArrayRef< SmallVector< OpFoldResult >> &resultOffsets, ArrayRef< SmallVector< OpFoldResult >> &resultSizes, ArrayRef< BlockArgument > bbArgs)
After fusing consumer into scf.forall we want to yield each of the resulting values by the tiled cons...
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
Operation * getOwner() const
Returns the operation that owns this result.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInterface.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterface.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with an attribute of value 0.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Container for result values of tiling.
SmallVector< Operation * > tiledOps
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
SmallVector< Operation * > tiledOps
Transformation information returned after reduction tiling.
SmallVector< Value > replacements
The replacements to use for the results of the tiled operation.
SmallVector< Value > initialValues
Initial values used for reduction.
SmallVector< Operation * > parallelTiledOps
The partial reduction tiled op generated.
SmallVector< LoopLikeOpInterface > loops
The loop operations that iterate over the tiles.
SmallVector< Operation * > mergeOps
The final reduction operation merging all the partial reductions.
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > ts)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.