MLIR 23.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the tiling using TilingInterface.
10//
11//===----------------------------------------------------------------------===//
12
14
22#include "mlir/IR/Dominance.h"
28#include "llvm/ADT/ScopeExit.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/Support/Debug.h"
31#include <optional>
32
33#define DEBUG_TYPE "tile-using-interface"
34
35using namespace mlir;
36
37scf::SCFTilingOptions &
38scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
39 assert(!tileSizeComputationFunction && "tile sizes already set");
40 auto tileSizes = llvm::to_vector(ts);
41 tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
42 return tileSizes;
43 };
44 return *this;
45}
46
47scf::SCFTilingOptions &
48scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
49 assert(!numThreadsComputationFunction && "num tiles already set");
50 auto numThreads = llvm::to_vector(nt);
51 numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
52 return numThreads;
53 };
54 return *this;
55}
56
57/// Helper method to adjust the interchange vector to match the iteration
58/// domain.
///
/// The returned vector always has exactly `iterationDomainSize` entries:
/// - if `interchangeVector` is shorter, it is extended with the consecutive
///   indices `[interchangeVector.size(), iterationDomainSize)`;
/// - if it is longer, it is truncated to its first `iterationDomainSize`
///   entries.
61 size_t iterationDomainSize) {
62 SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
 // Too short: append the missing trailing indices so that each unspecified
 // position maps to itself.
63 if (filledVector.size() < iterationDomainSize) {
64 auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
65 filledVector.append(range.begin(), range.end());
66 }
 // Too long: drop the surplus entries.
67 if (filledVector.size() > iterationDomainSize)
68 filledVector.resize(iterationDomainSize);
69 return filledVector;
70}
71
72//===----------------------------------------------------------------------===//
73// tileUsingSCF implementation.
74//===----------------------------------------------------------------------===//
75
76/// Verify the tile size options are set in a consistent manner.
77static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc,
78 const scf::SCFTilingOptions &options) {
79 // Specifying number of threads is only supported on `scf.forall` op.
80 if (options.numThreadsComputationFunction &&
81 options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
82 return rewriter.notifyMatchFailure(
83 loc, "number of threads can only by specified when loop type is "
84 "set to use `scf.forall`");
85 }
86
87 // If specified, check that the interchange vector is a permutation.
88 if (!options.interchangeVector.empty()) {
89 if (!isPermutationVector(options.interchangeVector)) {
90 return rewriter.notifyMatchFailure(
91 loc, "invalid interchange vector, not a permutation of the entire "
92 "iteration space");
93 }
94 }
95 return success();
96}
97
98/// Method to instantiate the tile sizes and/or number of threads specified
99/// by the user.
///
/// Returns `{tileSizes, numThreads}`.
/// - `tileSizes` always has `iterationDomain.size()` entries, padded with
///   zero.
/// - `numThreads` has `iterationDomain.size()` entries (padded with zero)
///   when `options.numThreadsComputationFunction` is set, and is empty
///   otherwise.
/// When only the number of threads is given, the tile sizes are derived from
/// the iteration domain and the thread counts.
100static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
101getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
102 ArrayRef<Range> iterationDomain,
103 const scf::SCFTilingOptions &options) {
104 OpFoldResult zero = rewriter.getIndexAttr(0);
105 SmallVector<OpFoldResult> tileSizes, numThreads;
106 size_t numLoops = iterationDomain.size();
107
108 // Check whether the number of threads to use is specified.
109 if (options.numThreadsComputationFunction) {
110 numThreads = options.numThreadsComputationFunction(rewriter, op);
111 numThreads.resize(numLoops, zero);
112
113 // If the tile sizes are also specified, use them as given.
114 if (options.tileSizeComputationFunction) {
115 tileSizes = options.tileSizeComputationFunction(rewriter, op);
116 tileSizes.resize(numLoops, zero);
117 return {tileSizes, numThreads};
118 }
119
120 // Compute the tile sizes from the iteration domain and number
121 // of threads as follows
122 // - niters = ceilDiv(ub - lb, step)
123 // - tileSize = ceilDiv(niters, numThreads)
124 AffineExpr s0, s1, s2;
125 bindSymbols(rewriter.getContext(), s0, s1, s2);
126 // TODO: The step here is assumed to be 1.
127 AffineExpr numItersExpr = (s1 - s0);
128 AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
129 tileSizes.resize(numLoops, zero);
130 for (auto [index, range, nt] :
131 llvm::enumerate(iterationDomain, numThreads)) {
 // A zero thread count marks an untiled dimension; its tile size stays zero.
132 if (isZeroInteger(nt))
133 continue;
134
 // Operands bind as s0 = range.offset, s1 = range.size, s2 = nt.
136 rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
137 }
138 tileSizes.resize(numLoops, zero);
139 return {tileSizes, numThreads};
140 }
141
142 // Enforce the convention that "tiling by zero"
143 // skips tiling a particular dimension. This convention is significantly
144 // simpler to handle instead of adjusting affine maps to account for missing
145 // dimensions.
146 assert(options.tileSizeComputationFunction &&
147 "expected tile sizes to be specified");
148 tileSizes = options.tileSizeComputationFunction(rewriter, op);
149 tileSizes.resize(numLoops, zero);
150
 // `numThreads` is intentionally left empty on this path.
151 return {tileSizes, numThreads};
152}
153
154/// Checks if any of the tiled loops are not parallel.
155static LogicalResult checkTileSizes(TilingInterface op,
156 scf::SCFTilingOptions::LoopType loopType,
157 ReductionTilingStrategy reductionStrategy,
158 ArrayRef<OpFoldResult> givenTileSizes,
159 ArrayRef<OpFoldResult> numThreads) {
160 auto iterators = op.getLoopIteratorTypes();
161 assert(iterators.size() == givenTileSizes.size() &&
162 "expected as many tile size values as number of loops");
163 assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
164 "when specified, expected number of threads to use for each loop");
165
166 bool isParallelTiling = false;
167 for (auto [index, iterator, givenTileSize] :
168 llvm::enumerate(iterators, givenTileSizes)) {
169 if (!isConstantIntValue(givenTileSize, 0)) {
170 isParallelTiling |= iterator == utils::IteratorType::parallel;
171 }
172
173 if (loopType == scf::SCFTilingOptions::LoopType::ForallOp &&
174 reductionStrategy == ReductionTilingStrategy::FullReduction) {
175 // If num threads is specified, check that it is greater than one only for
176 // parallel dimensions.
177 if (!numThreads.empty()) {
178 if (std::optional<int64_t> constNumThreads =
179 getConstantIntValue(numThreads[index])) {
180 if (constNumThreads.value() > 1 &&
181 iterator != utils::IteratorType::parallel) {
182 op.emitWarning() << "tiling is not thread safe at axis #" << index;
183 }
184 }
185 continue;
186 }
187
188 if (std::optional<int64_t> constTileSize =
189 getConstantIntValue(givenTileSize)) {
190 if (constTileSize.value() > 0 &&
191 iterator != utils::IteratorType::parallel) {
192 op.emitWarning() << "tiling is not thread safe at axis #" << index;
193 }
194 }
195 }
196 }
197
198 if (reductionStrategy != ReductionTilingStrategy::FullReduction) {
199 if (isParallelTiling) {
200 return op->emitOpError("tiling parallel dimensions is not supported with "
201 "partial reduction tiling strategies");
202 }
203 }
204 return success();
205}
206
207/// Get the reduction dims that are tiled. This accounts for reduction dims
208/// that are specified as tiled, but the tile size is 0.
///
/// Returns the entries of `options.reductionDims`, in their original order,
/// whose tile size in `givenTileSizes` is not the constant zero.
211 const scf::SCFTilingOptions &options) {
212 SetVector<unsigned> reductionDims;
213 for (auto dim : options.reductionDims) {
 // A constant zero tile size means this dimension is not tiled.
214 if (isConstantIntValue(givenTileSizes[dim], 0))
215 continue;
216 reductionDims.insert(dim);
217 }
218 return reductionDims;
219}
220
221/// Check if `stride` evenly divides the trip count `size - offset`.
222static bool tileDividesIterationDomain(Range loopRange) {
223 std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
224 if (!offsetAsInt)
225 return false;
226 std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
227 if (!sizeAsInt)
228 return false;
229 std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
230 if (!strideAsInt)
231 return false;
232 return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
233}
234
235/// Returns the bounded tile size given the current `offset`, `loopRange` and
236/// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
///
/// `givenTileSize` is returned unchanged when it is the constant 1. The
/// remaining early return is guarded by a check on the range
/// `{loopRange.offset, loopRange.size, givenTileSize}`.
238 Range loopRange, OpFoldResult offset,
239 OpFoldResult givenTileSize) {
240 std::optional<int64_t> ts = getConstantIntValue(givenTileSize);
 // A constant tile size of 1 is used as-is, without a bounding `min`.
241 if (ts && ts.value() == 1)
242 return givenTileSize;
243
 // NOTE(review): presumably this is the `tileDividesIterationDomain` check
 // with the tile size acting as the stride - confirm against the full source.
245 Range{loopRange.offset, loopRange.size, givenTileSize}))
246 return givenTileSize;
247
248 // The tile size to use (to avoid out of bounds access) is minimum of
249 // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
250 // loop.
251 AffineExpr s0, s1, d0;
252 bindDims(b.getContext(), d0);
253 bindSymbols(b.getContext(), s0, s1);
 // With the operands below: d0 = offset, s0 = size, s1 = givenTileSize, so
 // the map computes `min(size - offset, givenTileSize)`.
254 AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
255 Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
257 b, loc, minMap, SmallVector<OpFoldResult>{offset, size, givenTileSize});
258}
259
260/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
261/// less than `iterationSize`.
///
/// The answer is only computed when all three values are constant integers;
/// if any of them is not a known constant the function returns false.
263 OpFoldResult numThreads,
264 OpFoldResult iterationSize) {
265 std::optional<int64_t> tileSizeConst = getConstantIntValue(givenTileSize);
266 std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
267 std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
268 if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
269 return false;
270 return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
271}
272
273/// Compute the `OpFoldResult`s that represents the multi-dimensional
274/// `offset`s and `size`s of the tile of the iteration space that the
275/// innermost loop body of the generated tiled loops corresponds to.
///
/// Returns `{offsets, sizes}` with one entry per element of
/// `iterationDomain`. Dimensions with a zero tile size keep the offset and
/// size of their loop range; every other dimension consumes the next value of
/// `ivs` as its offset and gets a bounded tile size.
276static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
278 ArrayRef<Range> iterationDomain,
279 ArrayRef<OpFoldResult> givenTileSizes) {
280 SmallVector<OpFoldResult> offsets, sizes;
 // Index of the next induction variable; only tiled dimensions have one.
281 int materializedLoopNum = 0;
282 for (auto [givenTileSize, loopRange] :
283 llvm::zip_equal(givenTileSizes, iterationDomain)) {
284
285 // Non-tiled cases, set the offset and size to the
286 // `loopRange.offset/size`.
287 if (isZeroInteger(givenTileSize)) {
288 offsets.push_back(loopRange.offset);
289 sizes.push_back(loopRange.size);
290 continue;
291 }
292
 // Tiled case: the induction variable is the tile offset, and the size is
 // bounded by `getBoundedTileSize`.
293 Value iv = ivs[materializedLoopNum++];
294 OpFoldResult offset = getAsOpFoldResult(iv);
295 offsets.push_back(offset);
296 OpFoldResult size =
297 getBoundedTileSize(rewriter, loc, loopRange, offset, givenTileSize);
298 sizes.push_back(size);
299 }
300 return {offsets, sizes};
301}
302
303/// Function to return the bounds of the loops to be generated.
///
/// Returns `{lbs, ubs, steps}` with one entry per dimension whose tile size
/// is not zero: the lower bound is the range offset, the upper bound is the
/// range size, and the step is the tile size. Dimensions with a zero tile
/// size contribute no entry.
304static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
307 ArrayRef<OpFoldResult> givenTileSizes) {
308 SmallVector<OpFoldResult> lbs, ubs, steps;
309 for (auto [loopRange, givenTileSize] :
310 llvm::zip_equal(loopRanges, givenTileSizes)) {
311 // No loop if the tile size is 0.
312 if (isZeroInteger(givenTileSize))
313 continue;
314 lbs.push_back(loopRange.offset);
 // NOTE(review): `loopRange.size` is used directly as the upper bound (not
 // `offset + size`) - confirm this matches the intended `Range` convention.
315 ubs.push_back(loopRange.size);
316 steps.push_back(givenTileSize);
317 }
318 return {lbs, ubs, steps};
319}
320
321/// Typedef for function that allows returning additional yielded values during
322/// `yieldTiledValuesAndReplace`.
323/// - `ivs` induction variable for the loop.
324/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
325/// - `tiledValues` the tiled values to return. Must be of same size as
326/// `newBbArgs`, each element of this array is inserted into the corresponding
327/// element in `newBbArgs`.
328/// - `resultOffsets` is of the same size as `tiledValues` and represents
329/// the offsets to use when inserting corresponding element from `tiledValues`
330/// into the element from `newBbArgs`.
331/// - `resultSizes` is of the same size as `tiledValues` and represents
332/// the size of the corresponding element from `tiledValues` inserted into
333/// the element from `newBbArgs`.
334/// In case the method needs to return `failure()` the method is expected
335/// to clean up any inserted operations.
336using YieldTiledValuesFn = std::function<LogicalResult(
337 RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
338 SmallVector<Value> &tiledValues,
341
342/// Typedef for function that implements the body of a tiled loop.
343/// - `ivs` induction variable for the loop.
344/// - `tileOffsets` represents offsets for the tiled iteration space.
345/// - `tileSizes` represents the sizes for the tiled iteration space.
346/// - `outerDestinationTensors` tensor that holds the result. Is same size
347/// as the destination operands of the original operations.
348/// - `tiledResults` results of the tiled computation, corresponds to
349/// tiles of the original operation computed by the loop body.
350/// Should be same size as the `outerDestinationTensors`.
351/// - `resultOffsets` is of the same size as `tiledResults` and represents
352/// the offset to use when writing the corresponding element from
353/// `tiledResults` into `outerDestinationTensors`.
354/// - `resultSizes` is of the same size as `tiledResults` and represents
355/// the size to use when writing the corresponding element from
356/// `tiledResults` into `outerDestinationTensors`.
357/// In case the method needs to return `failure()` the method is expected
358/// to clean up any inserted operations.
359using GenerateTiledBodyFn = std::function<LogicalResult(
360 RewriterBase &rewriter, Location Loc, ValueRange ivs,
361 ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes,
362 ValueRange outerDestinationTensors, SmallVector<Value> &tiledResults,
365
366/// Clones the operation and updates the destination if the operation
367/// implements the `DestinationStyleOpInterface`.
///
/// The clone is created through `rewriter`. When `newDestArgs` is empty, or
/// the clone does not implement `DestinationStyleOpInterface`, the clone is
/// returned without touching its operands.
369 Operation *op,
370 ValueRange newDestArgs) {
371 Operation *clonedOp = rewriter.clone(*op);
 // Nothing to rewire when no replacement destinations were supplied.
372 if (newDestArgs.empty())
373 return clonedOp;
 // Replace the init operands of the clone with the new destinations.
374 if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
375 destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
376 return clonedOp;
377}
378
379/// Generate the tile-loop nest using `scf.for` operation.
380/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
381/// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops.
382/// - `outerDestinationTensors` are the init values to use for the outer most
383/// loop.
384/// - `tiledBodyFn` is called to generate the loop body of the inner
385/// most
386/// loop.
387/// Returns the generated `scf.for` loops on success.
///
/// One `scf.for` is created per dimension with a non-zero tile size, each
/// nested in the body of the previous one and threading the destination
/// tensors through its iter_args. The tiled results are written back with
/// `tensor.insert_slice` and yielded from the innermost loop; each outer loop
/// yields the results of the loop nested directly inside it.
388static FailureOr<SmallVector<LoopLikeOpInterface>> generateLoopNestUsingForOp(
389 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
390 ArrayRef<OpFoldResult> givenTileSizes, ValueRange outerDestinationTensors,
391 GenerateTiledBodyFn tiledBodyFn) {
392 assert(!loopRanges.empty() && "unexpected empty loop ranges");
393 assert(loopRanges.size() == givenTileSizes.size() &&
394 "expected as many tile sizes as loop ranges");
 // Restores the rewriter's insertion point when this function returns.
395 OpBuilder::InsertionGuard guard(rewriter);
396
397 SmallVector<OpFoldResult> lbs, ubs, steps;
398 std::tie(lbs, ubs, steps) =
399 getLoopBounds(rewriter, loc, loopRanges, givenTileSizes);
400 SmallVector<Value> lbVals =
401 getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
402 SmallVector<Value> ubVals =
403 getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
404 SmallVector<Value> stepVals =
405 getValueOrCreateConstantIndexOp(rewriter, loc, steps);
406
 // Build the nest outermost-first. Each loop is created with an empty body
 // builder; its iter_args become the destinations of the next inner loop.
409 ValueRange innerDestinationTensors(outerDestinationTensors);
410 for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
411 auto loop =
412 scf::ForOp::create(rewriter, loc, lb, ub, step, innerDestinationTensors,
413 [](OpBuilder &bodyBuilder, Location bodyLoc,
414 Value iv, ValueRange /*iterArgs*/) {});
415 loops.push_back(loop);
416 ivs.push_back(loop.getInductionVar());
417 rewriter.setInsertionPointToEnd(loop.getBody());
418 innerDestinationTensors = loop.getRegionIterArgs();
419 }
 // NOTE(review): this returns `success()` from a function whose return type
 // is `FailureOr<...>`. `FailureOr` is normally built from a failure or from
 // a value, so this path looks suspicious - confirm it is unreachable (all
 // tile sizes zero) or return `loops` instead.
420 if (loops.empty())
421 return success();
422
423 // Compute the `offsets` and `sizes` to use for tiling.
424 SmallVector<OpFoldResult> offsets, sizes;
425 std::tie(offsets, sizes) =
426 getTileOffsetAndSizes(rewriter, loc, ivs, loopRanges, givenTileSizes);
427
428 SmallVector<Value> tiledResults;
429 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
430 if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
431 innerDestinationTensors, tiledResults, resultOffsets,
432 resultSizes))) {
433 return rewriter.notifyMatchFailure(
434 loc, "failed to generate inner tile loop body");
435 }
 // NOTE(review): `loops` is not modified after the emptiness check above, so
 // this second check appears to be unreachable.
436 if (loops.empty())
437 return loops;
438
439 assert(tiledResults.size() == innerDestinationTensors.size() &&
440 "Number of results of body should be equal to number of iter args");
441
442 // Yield all the results of the tiled operation from the innermost loop.
443 SmallVector<Value> yieldedValues;
444 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
445 llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
446 resultSizes)) {
 // Results are inserted with unit strides.
447 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
448 rewriter.getIndexAttr(1));
449 auto insertSlice = tensor::InsertSliceOp::create(
450 rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize,
451 resultStride);
452 yieldedValues.push_back(insertSlice);
453 }
454 scf::YieldOp::create(rewriter, loc, yieldedValues);
455
456 // Add the scf.yield operations for all the outer loops.
457 for (auto [outerLoop, innerLoop] :
458 llvm::zip_equal(MutableArrayRef(loops).drop_back(),
459 MutableArrayRef(loops).drop_front())) {
460 rewriter.setInsertionPointToEnd(
461 cast<scf::ForOp>(outerLoop.getOperation()).getBody());
462 scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
463 }
464 return loops;
465}
466
467/// Compute the `OpFoldResult`s that represents the multi-dimensional
468/// `offset`s and `size`s of the tile of the iteration space that the
469/// innermost loop body of the generated tiled loops corresponds to
470/// when tiling using `forall` op. This is handled separately due to
471/// the special case handling needed for when the tiling is done by
472/// specifying number of threads.
///
/// When `numThreads` is empty this simply forwards to
/// `getTileOffsetAndSizes`. Otherwise a dimension is treated as untiled when
/// its thread count is zero, and every other dimension consumes the next
/// value of `ivs`.
473static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
475 ValueRange ivs,
476 ArrayRef<Range> iterationDomain,
477 ArrayRef<OpFoldResult> givenTileSizes,
478 ArrayRef<OpFoldResult> numThreads) {
479 if (numThreads.empty()) {
480 return getTileOffsetAndSizes(rewriter, loc, ivs, iterationDomain,
481 givenTileSizes);
482 }
483
484 SmallVector<OpFoldResult> offsets, sizes;
 // Index of the next induction variable; only tiled dimensions have one.
485 int materializedLoopNum = 0;
486
487 AffineExpr d0, d1, s0, s1;
488 AffineExpr offsetExpr, residualTileSizeExpr;
489 bindDims(rewriter.getContext(), d0, d1);
490 bindSymbols(rewriter.getContext(), s0, s1);
 // With the operands used below:
 // - offsetExpr: d0 = range offset, d1 = iv, s0 = tile size.
 // - residualTileSizeExpr: d0 = range offset, d1 = number of threads,
 //   s0 = tile size, s1 = range size.
491 offsetExpr = d0 + d1 * s0;
492 residualTileSizeExpr = s1 - (d0 + d1 * s0);
493
494 for (auto [index, nt, givenTileSize, loopRange] :
495 llvm::enumerate(numThreads, givenTileSizes, iterationDomain)) {
496
497 // Non-tiled cases, set the offset and size to the
498 // `loopRange.offset/size`.
499 if (isZeroInteger(nt)) {
500 offsets.push_back(loopRange.offset);
501 sizes.push_back(loopRange.size);
502 continue;
503 }
504
505 Value iv = ivs[materializedLoopNum++];
507 rewriter, loc, offsetExpr,
508 ArrayRef<OpFoldResult>{loopRange.offset, iv, givenTileSize});
510 rewriter, loc, residualTileSizeExpr,
511 {loopRange.offset, nt, givenTileSize, loopRange.size});
512
 // If the threads do not cover the range exactly, bound the tile size by
 // what remains of the range after this thread's offset.
513 OpFoldResult size = givenTileSize;
514 if (!isZeroInteger(residualTileSize)) {
515 OpFoldResult sizeMinusOffsetPerThread =
516 affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
517 {offset, loopRange.size});
519 rewriter, loc,
521 {sizeMinusOffsetPerThread, givenTileSize});
522 }
523
524 // Consider the case where the original loop was `[0, 10)`.
525 // If number of threads are `7`, the tile size would be computed as
526 // `ceilDiv(10, 7) = 2`. For the last thread (thread_id = 6)
527 // - `offset = 0 + 6 * 2 = 12`
528 // - `tileSize = min(2, 10 - 12) = -2`
529 // To avoid negative tile sizes, we need to do a further
530 // `nonNegativeTileSize = affine.max(0, tileSize)`.
531 // This `max` can be avoided if
532 // `offset + tileSize * (numThreads - 1) < (ub - lb)`
533 if (!canOmitTileOffsetInBoundsCheck(givenTileSize, nt, loopRange.size)) {
534 AffineMap maxMap =
537 rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
538 }
539
540 offsets.push_back(offset);
541 sizes.push_back(size);
542 }
543 return {offsets, sizes};
544}
545
546/// Generate the tile-loop nest using `scf.forall` operation.
547/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
548/// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops.
/// - `numThreads` is the number of threads to use per loop. When non-empty,
/// the `scf.forall` is created from the non-zero thread counts instead of
/// from loop bounds.
549/// - `outerDestinationTensors` are the init values to use for the loop.
550/// - `mappingVector` is the mapping attributes to use for loop construction.
551/// Can be empty.
552/// - `tiledBodyFn` is called to generate the loop body of the inner
553/// most
554/// loop.
555/// Returns the generated `scf.forall` loop on success.
556static FailureOr<SmallVector<LoopLikeOpInterface>>
558 ArrayRef<Range> loopRanges,
559 ArrayRef<OpFoldResult> givenTileSizes,
560 ArrayRef<OpFoldResult> numThreads,
561 ArrayRef<Attribute> mappingVector,
562 ValueRange outerDestinationTensors,
563 GenerateTiledBodyFn tiledBodyFn) {
564 assert(!loopRanges.empty() && "unexpected empty loop ranges");
565 assert(loopRanges.size() == givenTileSizes.size() &&
566 "expected as many tile sizes as loop ranges");
 // Restores the rewriter's insertion point when this function returns.
567 OpBuilder::InsertionGuard guard(rewriter);
568
 // Only attach a mapping attribute when a mapping was requested.
569 std::optional<ArrayAttr> mappingAttr;
570 if (!mappingVector.empty())
571 mappingAttr = rewriter.getArrayAttr(mappingVector);
572
573 scf::ForallOp forallOp;
574 bool useNumThreads = !numThreads.empty();
575
577 if (useNumThreads) {
578 // Prune the zero numthreads.
579 SmallVector<OpFoldResult> nonZeroNumThreads;
580 for (auto nt : numThreads) {
581 if (isZeroInteger(nt))
582 continue;
583 nonZeroNumThreads.push_back(nt);
584 }
585 forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads,
586 outerDestinationTensors, mappingAttr);
587 } else {
 // Without thread counts, iterate over the tiled loop bounds directly.
588 SmallVector<OpFoldResult> lbs, ubs, steps;
589 std::tie(lbs, ubs, steps) =
590 getLoopBounds(rewriter, loc, loopRanges, givenTileSizes);
591 forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
592 outerDestinationTensors, mappingAttr);
593 }
594 loops.push_back(forallOp);
595
 // The tiled body is generated right before the terminator of the forall.
596 rewriter.setInsertionPoint(forallOp.getTerminator());
597 ValueRange innerDestinationTensors = forallOp.getRegionOutArgs();
598 SmallVector<Value> ivs = forallOp.getInductionVars();
599
600 // Compute the `offsets` and `sizes` to use for tiling.
601 SmallVector<OpFoldResult> offsets, sizes;
602 std::tie(offsets, sizes) = getTileOffsetAndSizesWithForAllOp(
603 rewriter, loc, ivs, loopRanges, givenTileSizes, numThreads);
604
605 SmallVector<Value> tiledResults;
606 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
607 if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
608 innerDestinationTensors, tiledResults, resultOffsets,
609 resultSizes)))
610 return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
611
 // Write each tiled result back into its destination inside the terminator
 // region, using unit strides.
612 rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
613 for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
614 llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
615 resultSizes)) {
616 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
617 rewriter.getIndexAttr(1));
618
619 tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue,
620 destinationTensor, resultOffset,
621 resultSize, resultStride);
622 }
623 return loops;
624}
625
626/// Generate the tile-loop nest using custom loop operation.
627/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
628/// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops.
629/// - `outerDestinationTensors` are the init values to use for the outer most loop.
630/// - `generateLoopHeaderFn` is called first to create the loops; the tile
631/// offsets, tile sizes and destination tensors it reports are passed on.
/// - `generateLoopTerminatorFn` is called last with the tiled results and
/// their offsets and sizes.
632/// - `tiledBodyFn` is called to generate the loop body of the inner
633/// most
634/// loop.
635/// Returns the loops reported by `generateLoopHeaderFn` on success.
636static FailureOr<SmallVector<LoopLikeOpInterface>>
638 RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
639 ArrayRef<OpFoldResult> givenTileSizes, ValueRange outerDestinationTensors,
640 const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn,
641 const scf::SCFTilingOptions::GenerateLoopTerminatorFn
642 &generateLoopTerminatorFn,
643 GenerateTiledBodyFn tiledBodyFn) {
644 assert(!loopRanges.empty() && "unexpected empty loop ranges");
645 assert(loopRanges.size() == givenTileSizes.size() &&
646 "expected as many tile sizes as loop ranges");
647 assert(generateLoopHeaderFn && generateLoopTerminatorFn &&
648 "expected loop header/terminator generation function");
 // Restores the rewriter's insertion point when this function returns.
649 OpBuilder::InsertionGuard guard(rewriter);
650
 // Step 1: let the user callback create the loop header(s).
651 FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> loopHeaderInfo =
652 generateLoopHeaderFn(rewriter, loc, loopRanges, givenTileSizes,
653 outerDestinationTensors);
654 if (failed(loopHeaderInfo)) {
655 return failure();
656 }
657
 // Step 2: generate the tiled body using the header's tile information.
659 SmallVector<Value> tiledResults;
660 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
661 if (failed(tiledBodyFn(rewriter, loc, ivs, loopHeaderInfo->tileOffset,
662 loopHeaderInfo->tileSizes,
663 loopHeaderInfo->destinationTensors, tiledResults,
664 resultOffsets, resultSizes))) {
665 return failure();
666 }
667
 // Step 3: let the user callback terminate the loops with the results.
668 if (failed(generateLoopTerminatorFn(rewriter, loc, loopHeaderInfo->loops,
669 tiledResults, resultOffsets, resultSizes,
670 loopHeaderInfo->destinationTensors))) {
671 return failure();
672 }
673
674 return loopHeaderInfo->loops;
675}
676
677/// Generate the tile-loop nest using the loop construct specified in `options`.
678/// - `options`: Tiling options specified.
679/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
680/// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops.
/// - `numThreads` is the number of threads per loop; it is only forwarded
/// when the loop type is `scf.forall`.
681/// - `destinationTensors` are the init values to use for the outer most
682/// loop.
683/// - `tiledBodyFn` is called to generate the loop body of the inner
684/// most
685/// loop.
686/// Returns the generated loops on success.
687static FailureOr<SmallVector<LoopLikeOpInterface>> generateLoopNest(
688 RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
689 ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> givenTileSizes,
690 ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
691 GenerateTiledBodyFn tiledBodyFn) {
692 // If the tile sizes are all zero, no loops are generated. Just call the
693 // callback function to handle untiled case.
694 if (llvm::all_of(givenTileSizes, isZeroInteger)) {
695 SmallVector<Value> tiledResults;
696 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
 // The single "tile" spans each loop range in full; there are no ivs.
697 auto tileOffsets =
698 llvm::map_to_vector(loopRanges, [](Range r) { return r.offset; });
699 auto tileSizes =
700 llvm::map_to_vector(loopRanges, [](Range r) { return r.size; });
701 if (failed(tiledBodyFn(rewriter, loc, ValueRange{}, tileOffsets, tileSizes,
702 destinationTensors, tiledResults, resultOffsets,
703 resultSizes))) {
704 return failure();
705 }
707 }
 // Otherwise dispatch on the requested loop construct.
708 if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
709 return generateLoopNestUsingForOp(rewriter, loc, loopRanges, givenTileSizes,
710 destinationTensors, tiledBodyFn);
711 }
712 if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
714 rewriter, loc, loopRanges, givenTileSizes, numThreads,
715 options.mappingVector, destinationTensors, tiledBodyFn);
716 }
717 if (options.loopType == scf::SCFTilingOptions::LoopType::CustomOp) {
719 rewriter, loc, loopRanges, givenTileSizes, destinationTensors,
720 options.generateLoopHeaderFn, options.generateLoopTerminatorFn,
721 tiledBodyFn);
722 }
723 return rewriter.notifyMatchFailure(loc, "unhandled loop type");
724}
725
/// Create the tensors used as the initial destinations of the tiled loops.
///
/// - For `ReductionTilingStrategy::FullReduction` these are the destinations
///   obtained from `tensor::getOrCreateDestinations`.
/// - For every other strategy `op` must implement
///   `PartialReductionOpInterface`; a per-loop size is computed and handed to
///   `generateInitialTensorForPartialReduction` together with
///   `reductionDims`.
/// Returns failure if the destinations cannot be created, and an op error if
/// `op` does not implement the required interface.
726static FailureOr<SmallVector<Value>> createInitialTensorsForTiling(
727 RewriterBase &rewriter, TilingInterface op,
728 ReductionTilingStrategy reductionStrategy, ArrayRef<Range> iterationDomain,
729 ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> givenTileSizes,
730 const SetVector<unsigned> &reductionDims) {
731 SmallVector<Value> initTensors;
732 Location loc = op->getLoc();
733 if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
734 if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
735 return failure();
736 return initTensors;
737 }
738
739 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
 // NOTE(review): the message names `PartialReductionOuterReduction`, but this
 // path is taken for every strategy other than `FullReduction`.
740 if (!redOp) {
741 return op->emitOpError(
742 "PartialReductionOuterReduction tiling strategy is only supported for "
743 "operations implementing PartialReductionOpInterface");
744 }
745 SmallVector<OpFoldResult> sizes(iterationDomain.size());
746 AffineExpr s0, s1, s2;
747 bindSymbols(rewriter.getContext(), s0, s1, s2);
 // With the operands used below, `sizeExpr` is
 // `ceilDiv(size - offset, stride)` of a loop range, and `divExpr` is the
 // ceiling division of its two operands.
748 AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
749 AffineExpr divExpr = s0.ceilDiv(s1);
750 for (auto [index, domain, tileSize] :
751 llvm::enumerate(iterationDomain, givenTileSizes)) {
 // When thread counts are given they decide the size of every dimension.
752 if (!numThreads.empty()) {
753 // Untiled case.
754 if (isConstantIntValue(numThreads[index], 0)) {
756 rewriter, op.getLoc(), sizeExpr,
757 {domain.size, domain.offset, domain.stride});
758 continue;
759 }
760 sizes[index] = numThreads[index];
761 continue;
762 }
763
764 // Non reduction dimensions/non-tiled dimensions.
765 if (!reductionDims.contains(index) || isConstantIntValue(tileSize, 0)) {
767 rewriter, op.getLoc(), sizeExpr,
768 {domain.size, domain.offset, domain.stride});
769 continue;
770 }
771
 // Tiled reduction dimension: the size depends on the reduction strategy.
772 if (reductionStrategy ==
774 sizes[index] = tileSize;
775 continue;
776 }
777
778 assert(reductionStrategy ==
781 rewriter, op.getLoc(), sizeExpr,
782 {domain.size, domain.offset, domain.stride});
784 rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
785 }
786 return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
787 reductionDims);
788}
789
790/// For the case of `ReductionTilingStrategy::PartialReductionOuterParallel`
791/// the `PartialReductionOpInterface` methods need the index of the parallel
792/// split reduction being executed.
///
/// Returns one entry per element of `reductionDims`, each initialized to the
/// constant zero. Inside the strategy check below, the entries are replaced
/// by consuming `ivs` in order: the induction variable itself when
/// `numThreads` is non-empty, and `floorDiv(iv, tileSize)` of the reduction
/// dimension otherwise.
795 ReductionTilingStrategy reductionStrategy, ValueRange ivs,
796 ArrayRef<OpFoldResult> numThreads,
797 ArrayRef<OpFoldResult> givenTileSizes,
798 const SetVector<unsigned> &reductionDims) {
799 SmallVector<OpFoldResult> splitReductionIvs;
800 splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
801 AffineExpr s0, s1;
802 bindSymbols(rewriter.getContext(), s0, s1);
803 AffineExpr divExpr = s0.floorDiv(s1);
 // Index of the next induction variable to consume.
804 int ivIndex = 0;
805 if (reductionStrategy ==
807 for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
 // With explicit thread counts the induction variable is used directly.
808 if (!numThreads.empty()) {
809 splitReductionIvs[index] = ivs[ivIndex++];
810 continue;
811 }
 // Otherwise divide the induction variable by the tile size.
812 splitReductionIvs[index] = affine::makeComposedFoldedAffineApply(
813 rewriter, loc, divExpr,
814 ArrayRef<OpFoldResult>{ivs[ivIndex++], givenTileSizes[reductionDim]});
815 }
816 }
817 return splitReductionIvs;
818}
819
/// Generate the tiled implementation of `op` for the tile described by
/// `offsets` and `sizes`.
///
/// - For `ReductionTilingStrategy::FullReduction` this forwards to
///   `TilingInterface::getTiledImplementation`.
/// - For every other strategy `op` must implement
///   `PartialReductionOpInterface`; the split-reduction induction variables
///   are computed and `tileToPartialReduction` is called with
///   `regionIterArg` as the values being accumulated into.
/// Reports a match failure if the required interface is not implemented.
820static FailureOr<TilingResult>
821getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
822 ReductionTilingStrategy reductionStrategy,
823 ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
825 ArrayRef<OpFoldResult> numThreads,
826 ArrayRef<OpFoldResult> givenTileSizes,
827 const SetVector<unsigned> &reductionDims) {
828 if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
829 return op.getTiledImplementation(rewriter, offsets, sizes);
830 }
831
832 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
833 if (!redOp) {
834 return rewriter.notifyMatchFailure(
835 op, "PartialReductionOuterReduction tiling strategy is only "
836 "supported for operations "
837 "implementing PartialReductionOpInterface");
838 }
839
840 SmallVector<OpFoldResult> splitReductionIvs =
841 getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
842 numThreads, givenTileSizes, reductionDims);
843 return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
844 regionIterArg, offsets, sizes,
845 reductionDims, splitReductionIvs);
846}
847
848static LogicalResult getResultTilePosition(
849 RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy,
850 int64_t index, Value tiledResult, TilingInterface op,
852 ValueRange ivs, ArrayRef<OpFoldResult> numThreads,
853 ArrayRef<OpFoldResult> givenTileSizes,
854 const SetVector<unsigned> &reductionDims,
855 SmallVector<OpFoldResult> &resultOffset,
856 SmallVector<OpFoldResult> &resultSize) {
857
858 if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
859 return op.getResultTilePosition(rewriter, index, offsets, sizes,
860 resultOffset, resultSize);
861 }
862 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
863 if (!redOp) {
864 return rewriter.notifyMatchFailure(
865 op, "PartialReductionOuterReduction tiling strategy is only supported"
866 "for operations implementing PartialReductionOpInterface");
867 }
868 SmallVector<OpFoldResult> splitReductionIvs =
869 getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
870 numThreads, givenTileSizes, reductionDims);
871 return redOp.getPartialResultTilePosition(
872 rewriter, index, reductionStrategy, offsets, sizes, reductionDims,
873 splitReductionIvs, resultOffset, resultSize);
874}
875
876static FailureOr<MergeResult>
877mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
878 ReductionTilingStrategy reductionStrategy,
879 const SetVector<unsigned> &reductionDims,
880 ValueRange partialResults) {
881 assert(reductionStrategy != ReductionTilingStrategy::FullReduction &&
882 "expected merge to be called for only partial reduction cases");
883
884 auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
885 if (!redOp) {
886 return rewriter.notifyMatchFailure(
887 op, "PartialReductionOuterReduction tiling strategy is only "
888 "supported for operations "
889 "implementing PartialReductionOpInterface");
890 }
891 return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
892 reductionDims);
893}
894
895/// Append the specified additional `newInitOperands` operands to the
896/// loops existing `init` operands (or similar), and replace `loopOp` with
897/// the new loop that has the additional init operands. The loop body of
898/// this loop is moved over to the new loop. `yieldTiledValuesFn`
899/// is called to get the new tiled values returned, and the offset
900/// and sizes at which the tiled value is inserted into the
901/// new region iter_args that correspond to the newly added init operands.
902template <typename LoopType>
903FailureOr<LoopLikeOpInterface>
905 ValueRange newInitOperands,
906 YieldTiledValuesFn yieldTiledValuesFn) {
907 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
908}
909
910/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
911template <>
912FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
913 scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
914 YieldTiledValuesFn yieldTiledValuesFn) {
915 OpBuilder::InsertionGuard g(rewriter);
916 Location loc = loopOp.getLoc();
917 rewriter.setInsertionPoint(loopOp);
918
919 auto inits = llvm::to_vector(loopOp.getInitArgs());
920 inits.append(newInitOperands.begin(), newInitOperands.end());
921 auto newLoop = scf::ForOp::create(
922 rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
923 loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
924 loopOp.getUnsignedCmp());
925
926 // Move the loop body to the new op.
927 Block *loopBody = loopOp.getBody();
928 Block *newLoopBody = newLoop.getBody();
929 rewriter.mergeBlocks(
930 loopBody, newLoopBody,
931 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
932
933 auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
934 rewriter.setInsertionPoint(yieldOp);
935
936 SmallVector<Value> tiledValues;
937 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
938 ValueRange newRegionIterArgs =
939 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
940 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
941 newRegionIterArgs, tiledValues, resultOffsets,
942 resultSizes))) {
943 rewriter.eraseOp(newLoop);
944 return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
945 }
946
947 SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
948 for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
949 llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
950 resultSizes)) {
951 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
952 rewriter.getIndexAttr(1));
953 Value insert = tensor::InsertSliceOp::create(
954 rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset,
955 resultSize, resultStride);
956 newYieldValues.push_back(insert);
957 }
958
959 rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
960 rewriter.replaceOp(loopOp,
961 newLoop->getResults().take_front(loopOp.getNumResults()));
962 return cast<LoopLikeOpInterface>(newLoop.getOperation());
963}
964
965/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
966template <>
967FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
968 scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
969 YieldTiledValuesFn yieldTiledValuesFn) {
970 OpBuilder::InsertionGuard g(rewriter);
971 Location loc = loopOp.getLoc();
972 rewriter.setInsertionPoint(loopOp);
973 auto inits = llvm::to_vector(loopOp.getOutputs());
974 inits.append(newInitOperands.begin(), newInitOperands.end());
975 auto newLoop = scf::ForallOp::create(
976 rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
977 loopOp.getMixedStep(), inits, loopOp.getMapping(),
978 [](OpBuilder &, Location, ValueRange) {});
979
980 // Move the region of the current block to the newly created op.
981 Block *loopBody = loopOp.getBody();
982 Block *newLoopBody = newLoop.getBody();
983 rewriter.mergeBlocks(
984 loopBody, newLoopBody,
985 newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
986
987 auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
988 rewriter.setInsertionPoint(terminator);
989 SmallVector<Value> tiledValues;
990 SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
991 ValueRange regionIterArgs =
992 newLoop.getRegionIterArgs().take_back(newInitOperands.size());
993 if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
994 regionIterArgs, tiledValues, resultOffsets,
995 resultSizes))) {
996 rewriter.eraseOp(newLoop);
997 return rewriter.notifyMatchFailure(loopOp,
998 "failed to get yielded tiled values");
999 }
1000
1001 // Update the terminator.
1002 rewriter.setInsertionPointToEnd(terminator.getBody());
1003
1004 for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
1005 tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
1006 SmallVector<OpFoldResult> resultStride(resultOffset.size(),
1007 rewriter.getIndexAttr(1));
1008 tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(),
1009 tiledValue, iterArg, resultOffset,
1010 resultSize, resultStride);
1011 }
1012
1013 rewriter.replaceOp(loopOp,
1014 newLoop->getResults().take_front(loopOp.getNumResults()));
1015 return cast<LoopLikeOpInterface>(newLoop.getOperation());
1016}
1017
1018/// Implementation of `yieldTiledValuesAndReplaceLoop` for
1019/// `LoopLikeOpInterface`, that just dispatches to the implementation for each
1020/// supported loop type.
1021FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
1022 LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
1023 ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
1025 loopLikeOp.getOperation())
1026 .Case<scf::ForOp, scf::ForallOp>(
1027 [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1029 loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
1030 })
1031 .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1032 return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
1033 });
1034}
1035
1036/// Method to add new init values to a loop nest. Updates `loops` in-place
1037/// with new loops that use the `newInitValues`. The outer-loops are updated
1038/// to yield the new result values of the inner loop. For the innermost loop,
1039/// the call back `getNewYields` is invoked to get the additional values to
1040/// yield form the innermost loop.
1041static LogicalResult addInitOperandsToLoopNest(
1043 ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
1044 if (loops.empty())
1045 return success();
1046 OpBuilder::InsertionGuard g(rewriter);
1047 rewriter.setInsertionPoint(loops.front());
1048
1050 for (auto &loop : loops.drop_back()) {
1051 rewriter.setInsertionPoint(loop);
1052
1053 // if loops.size() > 1 we assume that scf.for is used for the loops.
1054 auto forLoop = cast<scf::ForOp>(loop.getOperation());
1055
1056 // Create a new loop with the new init values for this loop.
1057 SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
1058 newInits.append(newInitValues.begin(), newInitValues.end());
1059 auto newLoop = scf::ForOp::create(
1060 rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
1061 forLoop.getUpperBound(), forLoop.getStep(), newInits,
1062 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {},
1063 forLoop.getUnsignedCmp());
1064
1065 // Merge the body of the new loop with the body of the old loops.
1066 SmallVector<Value> sourceBlockArgs;
1067 sourceBlockArgs.push_back(newLoop.getInductionVar());
1068 auto newRegionIterArgs = newLoop.getRegionIterArgs();
1069 sourceBlockArgs.append(
1070 newRegionIterArgs.begin(),
1071 std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
1072 rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
1073 rewriter.replaceOp(
1074 forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
1075 loop = newLoop;
1076 ivs.push_back(newLoop.getInductionVar());
1077 newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
1078 }
1079
1080 // Update the loop body of the innermost loop to get new yield values.
1081 LoopLikeOpInterface innerMostLoop = loops.back();
1082 FailureOr<LoopLikeOpInterface> newInnerMostLoop =
1083 yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
1084 getNewTiledYieldsFn);
1085
1086 if (failed(newInnerMostLoop))
1087 return innerMostLoop.emitOpError("failed to return additional yields");
1088 loops.back() = newInnerMostLoop.value();
1089
1090 // Make all other loops except the innermost loops yield the values returned
1091 // by the inner loop.
1092 for (auto [outerLoop, innerLoop] :
1093 llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1094 // Again assume that all the outer loops are scf.for operations.
1095 auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
1096 auto outerLoopYield =
1097 cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
1098 SmallVector<Value> newYields =
1099 llvm::to_vector(outerLoopYield.getOperands());
1100 ValueRange additionalYields =
1101 innerLoop->getResults().take_back(newInitValues.size());
1102 newYields.append(additionalYields.begin(), additionalYields.end());
1103 rewriter.setInsertionPoint(outerLoopYield);
1104 rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
1105 }
1106 return success();
1107}
1108
1109/// Implementation of tiling transformation of `op` that implements the
1110/// `TilingInterface` using `scf.for` to iterate over the tiles.
1111FailureOr<scf::SCFTilingResult>
1112mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
1113 const scf::SCFTilingOptions &options) {
1114 if (failed(verifyOptions(rewriter, op.getLoc(), options))) {
1115 return failure();
1116 }
1117
1118 OpBuilder::InsertionGuard guard(rewriter);
1119 rewriter.setInsertionPointAfter(op);
1120
1121 // 1. Get the range of the loops that are represented by the operation.
1122 SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
1123
1124 // 2. Materialize the tile sizes and/or number of threads;
1125 SmallVector<OpFoldResult> givenTileSizes, numThreads;
1126 std::tie(givenTileSizes, numThreads) =
1127 getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
1128
1129 // Check if it is safe to tile. This is hold over from previous iterations
1130 // of tile to for-all. Consider dropping it.
1131 if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy,
1132 givenTileSizes, numThreads))) {
1133 return failure();
1134 }
1135
1136 // Get the reduction dims
1137 SetVector<unsigned> reductionDims =
1138 getSanitizedReductionDims(givenTileSizes, options);
1139
1140 // 3. If there is an interchange specified, permute the iteration domain and
1141 // the tile sizes.
1142 SmallVector<int64_t> interchangeVector;
1143 if (!options.interchangeVector.empty()) {
1144 interchangeVector = fillInterchangeVector(options.interchangeVector,
1145 iterationDomain.size());
1146 assert(isPermutationVector(interchangeVector) &&
1147 "expected interchange vector to be a permutation");
1148
1149 applyPermutationToVector(iterationDomain, interchangeVector);
1150 applyPermutationToVector(givenTileSizes, interchangeVector);
1151 if (!numThreads.empty())
1152 applyPermutationToVector(numThreads, interchangeVector);
1153 }
1154
1155 FailureOr<TilingResult> tilingResult;
1156 // 4. Define the lambda function used later to generate the body of the
1157 // innermost tiled loop.
1158 GenerateTiledBodyFn innerYieldTiledValuesFn =
1159 [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
1160 ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes,
1161 ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
1164 -> LogicalResult {
1165 // 4b. If interchange was provided, apply inverse of the interchange
1166 // to get back the offsets/sizes in the order to be specified.
1167 SmallVector<OpFoldResult> tileOffsetsVec = llvm::to_vector(tileOffsets);
1168 SmallVector<OpFoldResult> tileSizesVec = llvm::to_vector(tileSizes);
1169 if (!interchangeVector.empty()) {
1170 auto inversePermutation = invertPermutationVector(interchangeVector);
1173 }
1174
1175 // 5. Generate the tiled implementation within the inner most loop.
1176
1177 // 5a. Clone the operation within the loop body.
1178 auto clonedOp = cast<TilingInterface>(
1179 cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
1180
1181 // 5b. Early return cloned op if tiling is not happening. We can not
1182 // return the original op because it could lead to `rewriter.replaceOp(op,
1183 // op->getResults())` and users would get crash.
1184 if (llvm::all_of(givenTileSizes, isZeroInteger)) {
1185 tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1186 tilingResult =
1187 TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
1188 /*generatedSlices=*/{}};
1189 return success();
1190 }
1191
1192 // 5c. Tile the cloned operation.
1193 tilingResult =
1194 getTiledImplementation(rewriter, clonedOp, options.reductionStrategy,
1195 regionIterArgs, tileOffsetsVec, tileSizesVec,
1196 ivs, numThreads, givenTileSizes, reductionDims);
1197 if (failed(tilingResult)) {
1198 rewriter.eraseOp(clonedOp);
1199 return op.emitOpError("faild to tile operation");
1200 }
1201
1202 // 5d. Delete the cloned operation.
1203 rewriter.eraseOp(clonedOp);
1204
1205 // 5e. Compute the offsets at which the result values are to be inserted
1206 // back into its destinations.
1207 for (auto [index, tiledValue] :
1208 llvm::enumerate(tilingResult->tiledValues)) {
1209 tiledResults.push_back(tiledValue);
1210 SmallVector<OpFoldResult> resultOffset, resultSize;
1212 rewriter, options.reductionStrategy, index, tiledValue, op,
1213 tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes,
1214 reductionDims, resultOffset, resultSize))) {
1215 for (auto op : tilingResult->tiledOps) {
1216 rewriter.eraseOp(op);
1217 }
1218 return rewriter.notifyMatchFailure(
1219 op, "failed to get slice of result produced");
1220 }
1221 resultOffsets.emplace_back(std::move(resultOffset));
1222 resultSizes.emplace_back(std::move(resultSize));
1223 }
1224
1225 return success();
1226 };
1227
1228 // 6. Find the destination tensors to use for the operation.
1229 FailureOr<SmallVector<Value>> maybeInits = createInitialTensorsForTiling(
1230 rewriter, op, options.reductionStrategy, iterationDomain, numThreads,
1231 givenTileSizes, reductionDims);
1232 if (failed(maybeInits)) {
1233 return rewriter.notifyMatchFailure(
1234 op, "unable to create initial tensors for tiling");
1235 }
1236 SmallVector<Value> &initTensors = maybeInits.value();
1237
1238 // 7. Generate the tiled loops nest using the callback defined above.
1240 {
1241 FailureOr<SmallVector<LoopLikeOpInterface>> loopsOr = generateLoopNest(
1242 rewriter, op.getLoc(), options, iterationDomain, givenTileSizes,
1243 numThreads, initTensors, innerYieldTiledValuesFn);
1244 if (failed(loopsOr))
1245 return op.emitOpError("failed to generate tiling loops");
1246 assert(succeeded(tilingResult) &&
1247 "expected tiling result to be computed after loop generation");
1248 std::swap(loops, loopsOr.value());
1249 }
1250
1251 if (loops.empty()) {
1252 // If loops are empty, the tiled op is used as the replacement for the
1253 // untiled op.
1254 return scf::SCFTilingResult{tilingResult->tiledOps,
1255 initTensors,
1256 loops,
1257 tilingResult->tiledValues,
1258 tilingResult->generatedSlices,
1259 {}};
1260 }
1261
1262 auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1263 [](OpResult r) -> Value { return r; });
1264
1265 // For the full reduction case, there is nothing more to do.
1266 if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
1267 return scf::SCFTilingResult{
1268 tilingResult->tiledOps, initTensors, loops, loopResults,
1269 tilingResult->generatedSlices, {}};
1270 }
1271
1272 // The results of the loop needs to be merged.
1273 FailureOr<MergeResult> mergeResult = mergeTilingResults(
1274 rewriter, op, options.reductionStrategy, reductionDims, loopResults);
1275 if (failed(mergeResult)) {
1276 return rewriter.notifyMatchFailure(
1277 op, "Failed to merge partial results from tiling");
1278 }
1279 return scf::SCFTilingResult{tilingResult->tiledOps,
1280 initTensors,
1281 loops,
1282 mergeResult->replacements,
1283 tilingResult->generatedSlices,
1284 mergeResult->mergeOps};
1285}
1286
1287FailureOr<scf::SCFTilingResult>
1288mlir::scf::tileReductionUsingScf(RewriterBase &b,
1289 PartialReductionOpInterface op,
1290 ArrayRef<OpFoldResult> tileSize) {
1291 scf::SCFTilingOptions options;
1292 options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
1293 options.setReductionTilingStrategy(
1295 options.setTileSizes(tileSize);
1296 SmallVector<unsigned> reductionDims;
1297 for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
1298 if (iteratorType == utils::IteratorType::reduction)
1299 reductionDims.push_back(index);
1300 options.setReductionDims(reductionDims);
1301 return tileUsingSCF(b, op, options);
1302}
1303
1304//===----------------------------------------------------------------------===//
1305// tileConsumerAndFuseProducersUsingSCF implementation.
1306//===----------------------------------------------------------------------===//
1307
1308/// Return the untiled producer whose slice is used in a tiled consumer. The
1309/// method traverses the tile loop nest (`loops`) if needed, and returns the
1310/// `iter_args` of the outer most that is encountered. Traversing the
1311/// iter_args indicates that this is a destination operand of the consumer. If
1312/// there was no loop traversal needed, the second value of the returned tuple
1313/// is empty.
1314static std::tuple<OpResult, std::optional<OpOperand *>>
1317 std::optional<OpOperand *> destinationIterArg;
1318 assert(!loops.empty() && "expected non empty loops container");
1319 auto loopIt = loops.rbegin();
1320 while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
1321 auto iterArg = cast<BlockArgument>(source->get());
1322 auto loop = *loopIt;
1323 if (iterArg.getOwner()->getParentOp() != loop)
1324 break;
1325 source = loop.getTiedLoopInit(iterArg);
1326 loopIt++;
1327 }
1328 if (loopIt == loops.rend())
1329 destinationIterArg = source;
1330
1331 auto result = dyn_cast<OpResult>(source->get());
1332 if (result) {
1333 Operation *producer = result.getOwner();
1334 Operation *innermostLoop = loops.back();
1335 // If the producer is already inside the innermost loop (where the slice
1336 // is), it has already been fused. Skip it to avoid infinite loops.
1337 if (innermostLoop->isProperAncestor(producer))
1338 return {OpResult(), std::nullopt};
1339 }
1340
1341 return {result, destinationIterArg};
1342}
1343
1344/// Implementation of fusing producer of a single slice by computing the
1345/// slice of the producer in-place.
1346std::optional<scf::SCFFuseProducerOfSliceResult>
1347mlir::scf::tileAndFuseProducerOfSlice(
1348 RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1350 // 1. Get the producer of the source (potentially walking through
1351 // `iter_args` of nested `scf.for`)
1352 auto [fusableProducer, destinationInitArg] =
1353 getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1354 loops);
1355 if (!fusableProducer)
1356 return std::nullopt;
1357 unsigned resultNumber = fusableProducer.getResultNumber();
1358
1359 OpBuilder::InsertionGuard g(rewriter);
1360 rewriter.setInsertionPoint(candidateSliceOp);
1361
1362 // 2. Clone the fused producer
1363 // 2a. Compute the destination operands to use for the cloned operation.
1364 SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1365 Operation *fusableProducerOp = fusableProducer.getOwner();
1366 if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1368 rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1369 origDestinationTensors)))
1370 return std::nullopt;
1371
1372 clonedOpDestinationTensors = origDestinationTensors;
1373 if (destinationInitArg &&
1374 isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1375 // 2b. If the producer is also destination style, then to maintain the
1376 // destination passing style, update the destination of the producer to be
1377 // the source of the slice.
1378 clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1379 }
1380 // 2c. Clone the fused producer.
1381 Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1382 rewriter, fusableProducerOp, clonedOpDestinationTensors);
1383 // 2d. Update the source of the candidateSlice to be the cloned producer.
1384 // Easier to just clone the slice with different source since
1385 // replacements and DCE of cloned ops becomes easier
1386 SmallVector<Value> candidateSliceOpOperands =
1387 llvm::to_vector(candidateSliceOp->getOperands());
1388 candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1389 tensor::ExtractSliceOp clonedCandidateSliceOp =
1390 mlir::clone(rewriter, candidateSliceOp,
1391 candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1392
1393 // 3. Generate the tiled implementation of the producer of the source
1394 FailureOr<TilingResult> tileAndFuseResult =
1396 rewriter, clonedCandidateSliceOp,
1397 clonedProducerOp->getResult(resultNumber));
1398 if (failed(tileAndFuseResult))
1399 return std::nullopt;
1400 // Note: Do not delete the candidateSliceOp, since its passed in from the
1401 // caller.
1402 rewriter.replaceAllUsesWith(candidateSliceOp,
1403 tileAndFuseResult->tiledValues[0]);
1404 rewriter.eraseOp(clonedCandidateSliceOp);
1405 rewriter.eraseOp(clonedProducerOp);
1406
1407 // 3. If the slice is for a destination operand, for example,
1408 //
1409 // ```mlir
1410 // %0 = linalg.init
1411 // %1 = linalg.fill .. outs(%0 : )
1412 // %2 = scf.for .. iter_args(%arg0 = %1) {
1413 // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1414 // %4 = tensor.extract_slice %arg1 [..]
1415 // .. = linalg.matmul .. outs(%4 : )
1416 // }
1417 // }
1418 // ```
1419 //
1420 // the IR is currently
1421 //
1422 // ```
1423 // %0 = linalg.init
1424 // %1 = linalg.fill
1425 // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1426 // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1427 // %4 = tensor.extract_slice %arg1[..]
1428 // %5 = linalg.fill .. outs(%4 : )
1429 // .. = linalg.matmul .. outs(%5 : )
1430 // }
1431 // }
1432 // ```
1433 //
1434 // The untiled `linalg.fill` is still used as the `init_value` since it
1435 // was originally a destination operand of the untiled `linalg.matmul`.
1436 // When fusing an operand that is a destination operand, the iter_arg of
1437 // the outer most loop should be changed to use the destination of the
1438 // fused operation. With this the IR will be.
1439 //
1440 // ```
1441 // %0 = linalg.init
1442 // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1443 // %2 = scf.for .. iter_args(%arg1 = %arg0) {
1444 // %3 = tensor.extract_slice %arg1[..]
1445 // %4 = linalg.fill .. outs(%3 : )
1446 // .. = linalg.matmul .. outs(%4 : )
1447 // }
1448 // }
1449 // ```
1450 if (destinationInitArg &&
1451 isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1452 loops.front()
1453 ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1454 .set(origDestinationTensors[resultNumber]);
1455 }
1456 return scf::SCFFuseProducerOfSliceResult{
1457 fusableProducer, tileAndFuseResult->tiledValues[0],
1458 tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1459}
1460
1461/// Reconstruct the fused producer from within the tiled-and-fused code.
1462FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1463 RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1464 scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1466 ArrayRef<unsigned> yieldResultNumber) {
1467 if (loops.empty())
1468 return success();
1469
1470 Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1471 *tiledOwner = fusedProducerInfo.tiledOps[0];
1472
1473 Location loc = originalOwner->getLoc();
1474 // a. collect all init Value to be appended
1475 SmallVector<unsigned> initNumberList =
1476 yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1477 0, originalOwner->getNumResults()))
1478 : llvm::to_vector(yieldResultNumber);
1479 SmallVector<Value> initValueList;
1480 for (const auto &resultNumber : initNumberList) {
1481 FailureOr<Value> initValue = tensor::getOrCreateDestination(
1482 rewriter, loc, originalOwner->getResult(resultNumber));
1483 if (succeeded(initValue)) {
1484 initValueList.push_back(initValue.value());
1485 } else {
1486 return failure();
1487 }
1488 }
1489
1490 SmallVector<Operation *> generatedSlices;
1491 YieldTiledValuesFn newYieldValuesFn =
1492 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1493 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1495 SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1496 OpBuilder::InsertionGuard g(innerRewriter);
1497
1498 // get sliceOp tile information
1499 SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1500 sliceSizes = sliceOp.getMixedSizes();
1501
1502 // expect all strides of sliceOp being 1
1503 if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
1504 return failure();
1505
1506 unsigned sliceResultNumber =
1507 fusedProducerInfo.origProducer.getResultNumber();
1508
1509 auto tilableOp = cast<TilingInterface>(originalOwner);
1510 // b. get iterDomain Offset and Sizes based on sliceOp tile
1511 SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1512 // Set insertion point before any operations that might create new SSA
1513 // values used in offset/size computations. This ensures all values created
1514 // by getIterationDomainTileFromResultTile and getResultTilePosition
1515 // dominate the extract_slice operations created later.
1516 if (auto tiledDestStyleOp =
1517 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1518 rewriter.setInsertionPoint(tiledDestStyleOp);
1519 }
1520 // skip tensor.pack/unpack/pad, which expects single opResult
1521 if (tilableOp->getNumResults() > 1 &&
1522 failed(tilableOp.getIterationDomainTileFromResultTile(
1523 rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1524 iterDomainOffset, iterDomainSizes))) {
1525 // In theory, it is unnecessary to raise an error here. Actually
1526 // although it fails to reconstruct the result tensor, it should not
1527 // broke current fusion anyway. The reason why we must return failure
1528 // currently is that the callback function `newYieldValuesFn` will be
1529 // called after new init operand(s) has already been appended. It will
1530 // take more refactoring to make sure the init operands are added
1531 // consistently in the future. For more details, please refer to:
1532 // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1533 return failure();
1534 }
1535
1536 // c. calculate offsets and sizes info of all OpResults respectively based
1537 // on iteration Domain Tile
1538 SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1539 for (const auto &resultNumber : initNumberList) {
1540 if (resultNumber == sliceResultNumber) {
1541 offsetList.push_back(sliceOffset);
1542 sizesList.push_back(sliceSizes);
1543 } else {
1544 assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1545 // infer result tile according to the iteration domain tile
1546 SmallVector<OpFoldResult> offset, sizes;
1547 if (failed(tilableOp.getResultTilePosition(
1548 rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1549 offset, sizes))) {
1550 return failure();
1551 }
1552 offsetList.push_back(offset);
1553 sizesList.push_back(sizes);
1554 }
1555 }
1556
1557 // d. create `extract_slice` for `iter_args` for DPS operation if
1558 // necessary
1559 if (auto tiledDestStyleOp =
1560 dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1561 for (const auto &&[index, newRegionArg] :
1562 llvm::enumerate(newRegionIterArgs)) {
1563 auto destSlice = tensor::ExtractSliceOp::create(
1564 rewriter, loc, newRegionArg, offsetList[index], sizesList[index],
1565 SmallVector<OpFoldResult>(offsetList[index].size(),
1566 rewriter.getIndexAttr(1)));
1567 generatedSlices.push_back(destSlice);
1568 unsigned resultNumber = initNumberList[index];
1569 rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1570 tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1571 });
1572 }
1573 }
1574
1575 // e. prepare tiled offset and sizes for later `insert_slice` creation by
1576 // caller
1577 Block *block = rewriter.getInsertionPoint()->getBlock();
1578 rewriter.setInsertionPoint(block->getTerminator());
1579 for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1580 tiledResult.push_back(tiledOwner->getResult(resultNumber));
1581 tiledOffset.emplace_back(offsetList[index]);
1582 tiledSizes.emplace_back(sizesList[index]);
1583 }
1584 return success();
1585 };
1586
1587 if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1588 newYieldValuesFn))) {
1589 return failure();
1590 }
1591 return generatedSlices;
1592}
1593
1594namespace {
1595
1596//===----------------------------------------------------------------------===//
1597// SliceTrackingListener
1598//===----------------------------------------------------------------------===//
1599
1600/// This class is a listener for tracking the insertion and removal of
1601/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1602/// fusion algorithm to apply cleanup patterns in between fusion steps.
1603class SliceTrackingListener : public RewriterBase::Listener {
1604public:
1605 explicit SliceTrackingListener(
1606 std::optional<FrozenRewritePatternSet> patterns);
1607 SliceTrackingListener() = default;
1608
1609 /// Adds the given list of operations to the worklist, and if present,
1610 /// applies the list of `patterns` to the newly added operations. This only
1611 /// processes the given operations and any newly inserted ones by the
1612 /// pattern set.
1613 LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1614
1615 /// Add to the new operation worklist if it is an extract_slice.
1616 void notifyOperationInserted(Operation *op,
1617 OpBuilder::InsertPoint previous) override;
1618
1619 /// Shared helper for operation removal from the worklist.
1620 void removeOp(Operation *op);
1621
1622 /// Remove the operation from the worklist.
1623 void notifyOperationErased(Operation *op) override;
1624
1625 /// Remove the operation from the worklist.
1626 void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1627
1628 /// The worklist for this transformation keeps track of the slices to visit
1629 /// next for fusion.
1630 std::deque<tensor::ExtractSliceOp> worklist;
1631
1632private:
1633 /// Optional pattern set to apply when adding new operations to the
1634 /// worklist.
1635 std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1636};
1637
1638SliceTrackingListener::SliceTrackingListener(
1639 std::optional<FrozenRewritePatternSet> p) {
1640 patterns = std::move(p);
1641}
1642
1643LogicalResult
1644SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1645 for (Operation *op : ops) {
1646 if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1647 worklist.push_back(slice);
1648 }
1649
1650 if (!patterns)
1651 return success();
1652
1654 ops, patterns.value(),
1655 GreedyRewriteConfig().setListener(this).setStrictness(
1656 GreedyRewriteStrictness::ExistingAndNewOps));
1657}
1658
1659void SliceTrackingListener::notifyOperationInserted(
1660 Operation *op, OpBuilder::InsertPoint previous) {
1661 auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1662 if (!slice)
1663 return;
1664 worklist.push_back(slice);
1665}
1666
1667// Scan the worklist for the given op and remove it if present. The
1668// expectation is for the worklist to be small and for removal to be
1669// relatively rare.
1670void SliceTrackingListener::removeOp(Operation *op) {
1671 if (!isa<tensor::ExtractSliceOp>(op))
1672 return;
1673 auto iter = worklist.begin();
1674 while (iter != worklist.end()) {
1675 if (*iter == op)
1676 break;
1677 iter++;
1678 }
1679 if (iter == worklist.end())
1680 return;
1681
1682 worklist.erase(iter);
1683}
1684
1685void SliceTrackingListener::notifyOperationErased(Operation *op) {
1686 removeOp(op);
1687}
1688
1689void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1691 removeOp(op);
1692}
1693
1694//===----------------------------------------------------------------------===//
1695// ReplacementListener
1696//===----------------------------------------------------------------------===//
1697
1698/// Listener that tracks updates replacements for values which can be mutated.
1699/// This listener runs on top of the existing listener for the rewriter,
1700/// to make sure external users can still run listeners.
1701class ReplacementListener : public RewriterBase::ForwardingListener {
1702public:
1703 ReplacementListener(DenseMap<Value, Value> &replacements,
1704 OpBuilder::Listener *listener)
1705 : ForwardingListener(listener), replacements(replacements) {}
1706
1707 void updateReplacementValues(ValueRange origValues,
1708 ValueRange replaceValues) {
1709 // This can probably be written better, but just iterates over the map
1710 // and the new replacements for now.
1711 for (auto &[key, val] : replacements) {
1712 for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1713 if (val == orig) {
1714 val = replace;
1715 }
1716 }
1717 }
1718 }
1719
1720 void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1721 ForwardingListener::notifyOperationReplaced(op, newOp);
1722 updateReplacementValues(op->getResults(), newOp->getResults());
1723 }
1724
1725 void notifyOperationReplaced(Operation *op, ValueRange values) override {
1726 ForwardingListener::notifyOperationReplaced(op, values);
1727 updateReplacementValues(op->getResults(), values);
1728 }
1729
1730private:
1731 DenseMap<Value, Value> &replacements;
1732};
1733
1734} // namespace
1735
1736/// Implementation of tile consumer and fuse producer greedily.
1737FailureOr<scf::SCFTileAndFuseResult>
1738mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1739 RewriterBase &rewriter, TilingInterface consumer,
1740 const scf::SCFTileAndFuseOptions &options) {
1741 // This transformation is only valid for ops that return values (i.e. not
1742 // valid to use with operations that have memref operands).
1743 if (!consumer->getNumResults()) {
1744 return rewriter.notifyMatchFailure(
1745 consumer, "invalid pattern for op with no results");
1746 }
1747
1748 // 1. First tile the consumer.
1749 SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1750
1751 FailureOr<scf::SCFTilingResult> tilingResult =
1752 tileUsingSCF(rewriter, consumer, options.tilingOptions);
1753
1754 if (failed(tilingResult))
1755 return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1756 tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1757
1758 DenseMap<Value, Value> replacements;
1759 for (auto [origVal, replacement] :
1760 llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1761 replacements[origVal] = replacement;
1762 }
1763
1764 // If there are no loops generated, fusion is immaterial.
1765 auto &loops = tilingResult->loops;
1766 if (loops.empty()) {
1767 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1768 replacements};
1769 }
1770
1771 // Since the loop gets potentially replaced during fusion, we need to track
1772 // the mutation of replacement values. To do this, we attach a listener to
1773 // update the replacements as they happen.
1774 OpBuilder::Listener *previousListener = rewriter.getListener();
1775 llvm::scope_exit resetListener(
1776 [&]() { rewriter.setListener(previousListener); });
1777 ReplacementListener replaceListener(replacements, previousListener);
1778 rewriter.setListener(&replaceListener);
1779
1780 // 2. Typically, the operands of the tiled operation are slices of the
1781 // operands of the untiled operation. These are expressed in IR using
1782 // `tensor.extract_slice` operations with source being the operands of
1783 // the untiled operation. Create a worklist of these
1784 // `tensor.extract_slice` operations. If the producers of the source of
1785 // the `tensor.extract_slice` can be tiled such that the tiled value is
1786 // generated in-place, that effectively tiles + fuses the operations.
1787 struct WorklistItem {
1788 tensor::ExtractSliceOp candidateSlice;
1789 SCFTileAndFuseOptions::ControlFnResult controlFnResult;
1790 };
1791
1792 SliceTrackingListener sliceTracker =
1793 SliceTrackingListener(options.cleanupPatterns);
1794
1795 if (failed(
1796 sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1797 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1798 }
1799 OpBuilder::InsertionGuard g(rewriter);
1800 while (!sliceTracker.worklist.empty()) {
1801 auto candidateSlice = sliceTracker.worklist.front();
1802 sliceTracker.worklist.pop_front();
1803
1804 auto [fusableProducer, destinationInitArg] =
1805 getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1806 loops);
1807 if (!fusableProducer)
1808 continue;
1809
1810 std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1811 options.fusionControlFn(candidateSlice, fusableProducer,
1812 destinationInitArg.has_value());
1813 if (!controlFnResult)
1814 continue;
1815
1816 WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1817
1818 // The operands of the fused producer might themselved be slices of
1819 // values produced by operations that implement the `TilingInterface`.
1820 // Add these operations to the worklist.
1821 std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1822 tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1823 loops);
1824 if (!fusedResult)
1825 continue;
1826
1827 SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1828
1829 if (worklistItem.controlFnResult.yieldProducerReplacement) {
1830 // Reconstruct and yield all opResult of fusableProducerOp by default.
1831 // The caller can specific which one to yield by designating optional
1832 // argument named `yieldResultNumber` of
1833 // `yieldReplacementForFusedProducer`.
1834 Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1835 FailureOr<SmallVector<Operation *>> newSlices =
1836 yieldReplacementForFusedProducer(rewriter,
1837 worklistItem.candidateSlice,
1838 fusedResult.value(), loops);
1839 if (failed(newSlices)) {
1840 return rewriter.notifyMatchFailure(
1841 fusableProducerOp, "failed to replacement value for this "
1842 "operation from within the tiled loop");
1843 }
1844 worklistCandidates.append(newSlices.value());
1845 for (auto [index, result] :
1846 llvm::enumerate(fusableProducerOp->getResults())) {
1847 replacements[result] = loops.front()->getResult(
1848 loops.front()->getNumResults() -
1849 fusableProducerOp->getNumResults() + index);
1850 }
1851 }
1852 if (Operation *tiledAndFusedOp =
1853 fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1854 fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1855 tiledAndFusedOps.insert(tiledAndFusedOp);
1856 }
1857
1858 if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1859 return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1860 }
1861 }
1862
1863 return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1864 replacements};
1865}
1866
1867//===----------------------------------------------------------------------===//
1868// tileAndFuseConsumerUsingSCF implementation.
1869//===----------------------------------------------------------------------===//
1870
1871/// A utility function that checks whether the only use of the result of a
1872/// tensor.insert_slice op is in a scf.yield op.
1873static LogicalResult
1874checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1875 Value result = candidateSliceOp.getResult();
1876 Value::use_range uses = result.getUses();
1877 if (!llvm::hasSingleElement(uses)) {
1878 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1879 return failure();
1880 }
1881 OpOperand &operandUse = (*uses.begin());
1882 Operation *userOp = operandUse.getOwner();
1883 if (!isa<scf::YieldOp>(userOp)) {
1884 LLVM_DEBUG(llvm::dbgs()
1885 << "Expected scf.yield to be the only user, but got -> "
1886 << (*userOp));
1887 return failure();
1888 }
1889 if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1890 LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1891 "be in the same block\n");
1892 return failure();
1893 }
1894 return success();
1895}
1896
1897/// An utility to get the first user of the given loopOp. If any of user stay
1898/// in different block of loopOp, return failure.
1899static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1900 if (!isa<LoopLikeOpInterface>(loopOp))
1901 return failure();
1902 Operation *firstUserOfLoop = nullptr;
1903 for (Operation *userOp : loopOp->getUsers()) {
1904 // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1905 // block with any other types of operation. Thus, just redirecting to its
1906 // parent `InParallelOp`. E.g.
1907 //
1908 // ```
1909 // %1 = scf.for {
1910 // ...
1911 // }
1912 // %2 = consumerOp ins(%1, ...)
1913 // scf.forall.in_parallel {
1914 // tensor.parallel_insert_slice %1
1915 // }
1916 // ```
1917 // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
1918 // same block with `consumerOp`.
1919 if (isa<tensor::ParallelInsertSliceOp>(userOp))
1920 userOp = userOp->getParentOfType<scf::InParallelOp>();
1921
1922 if (loopOp->getBlock() != userOp->getBlock())
1923 return failure();
1924
1925 if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1926 firstUserOfLoop = userOp;
1927 }
1928 return firstUserOfLoop;
1929}
1930
1931/// This utility currently checks whether the first userOp of loop is NOT
1932/// before the last defineOp of consumer operand. Because that we need to move
1933/// the whole loop structure right before the `firstUserOfLoop`. This utility
1934/// thus helps ensuring that no invalid IR is formed, i.e. no backward slice
1935/// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
1936///
1937/// ```
1938/// %0 = scf.for() {
1939/// ...
1940/// }
1941/// ...
1942/// %1 = firstUserOfLoop(%0)
1943/// ...
1944/// %2 = lastDefOfConsumerOperand
1945/// ...
1946/// %3 = consumerOp(%2)
1947/// ```
1948///
1949/// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
1950/// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
1951/// a.k.a. use-def chain violation:
1952///
1953/// ```
1954/// %0:2 = scf.for() {
1955/// // use before define error
1956/// %3 = tiledConsumerOp(%2)
1957/// }
1958/// %1 = firstUserOfLoop(%0)
1959/// ...
1960/// %2 = lastDefOfConsumerOperand
1961/// ```
1962///
1963/// @param loopOp: loop operation
1964/// @param consumerOp: consumer operation
1965/// @param reorderOperations: the flag controls whether to reorder the
1966/// backward slice w.r.t. the defineOp of `consumerOp` operands.
1967/// @return: computed backward slice of consumerOp, but excluding those
1968/// already dominates `firstUserOfLoop`.
1969static FailureOr<llvm::SetVector<Operation *>>
1971 bool reorderOperations) {
1972 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1973 if (failed(firstUserOfLoop))
1974 return failure();
1975
1977 DominanceInfo dominanceInfo;
1978 options.inclusive = true;
1979 options.omitBlockArguments = true;
1980 bool includeLoopOp = false;
1981 options.filter = [&](Operation *op) {
1982 if (op == loopOp) {
1983 includeLoopOp = true;
1984 return false;
1985 }
1986 // Cut off the slice to not include any operation that already dominates
1987 // firstUserOfLoop.
1988 return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1989 };
1991 for (auto operand : consumerOp->getOperands()) {
1992 LogicalResult result = getBackwardSlice(operand, &slice, options);
1993 assert(result.succeeded() && "expected a backward slice");
1994 (void)result;
1995 }
1996
1997 if (!slice.empty()) {
1998 // If consumerOp has one producer, which is also the user of loopOp.
1999 // E.g.
2000 // ```
2001 // %0 = %loopOp
2002 // %1 = consumerOp1 ins(%0)
2003 // %2 = consumerOp2 ins(%0, %1)
2004 // ```
2005 // We can not fuse consumerOp2 into loopOp due to UD chain, unless
2006 // consumerOp1 has already been fused into loopOp before.
2007 if (includeLoopOp || !reorderOperations)
2008 return failure();
2009 }
2010
2011 return slice;
2012}
2013
2014/// Fetches the OpOperand of the first valid user (and use) of the value `val`
2015/// which implements `TilingInterface` and `DestinationStyleOpInterface`.
2016/// Returns failure otherwise.
2017static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
2018 Operation *loopOp,
2019 unsigned resultNumber) {
2020 if (!isa<LoopLikeOpInterface>(loopOp))
2021 return failure();
2022 Value val = loopOp->getResult(resultNumber);
2023 Block *loopBlock = loopOp->getBlock();
2024 for (OpOperand &opOperand : val.getUses()) {
2025 Operation *consumerOp = opOperand.getOwner();
2026 // Step 1. Check if the user is tilable.
2027 if (!isa<TilingInterface>(consumerOp) ||
2028 !isa<DestinationStyleOpInterface>(consumerOp)) {
2029 // TODO: We have to init result of consumer before scf.for, use
2030 // DestinationStyleOpInterface to get result shape from init for now.
2031 // Add support for other op such as op has InferTypeOpInterface.
2032 continue;
2033 }
2034 // Step 2. Check if user stay in the same block.
2035 if (loopBlock != consumerOp->getBlock())
2036 continue;
2037 // Step 3. Check if user has succeeding user. Otherwise, it usually
2038 // represents already tiled.
2039 if (consumerOp->use_empty())
2040 continue;
2041 // Step 4. Check assumption for loop with `reorderOperations` enabled.
2042 FailureOr<llvm::SetVector<Operation *>> slice =
2043 checkAssumptionForLoop(loopOp, consumerOp, true);
2044 if (failed(slice))
2045 continue;
2046 // Step 5. If backward sice is not empty, move them before
2047 // firstUserOfLoop.
2048 if (!slice->empty()) {
2049 mlir::topologicalSort(*slice);
2050 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
2051 assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
2052 for (auto op : *slice) {
2053 rewriter.moveOpBefore(op, *firstUserOfLoop);
2054 }
2055 }
2056 return &opOperand;
2057 }
2058 return failure();
2059}
2060
2061/// Fetch the untiled consumer of the outermost scf.for's result which is
2062/// yielded by a tensor.insert_slice from the innermost scf.for. This function
2063/// makes the following assumptions :
2064/// 1. tensor.insert_slice has scf.yield as its only user.
2065/// 2. scf.for's corresponding result has only one use.
2066/// 3. The `loops` passed in are perfectly nested `scf.for` operations.
2067static FailureOr<OpOperand *>
2069 tensor::InsertSliceOp candidateSliceOp,
2071 assert(!loops.empty() && "unexpected loops to be empty");
2072 // 1. Expect slice to be part of the body of the inner most loop.
2073 Operation *containingOp = candidateSliceOp->getParentOp();
2074 if (containingOp != loops.back()) {
2075 return rewriter.notifyMatchFailure(
2076 candidateSliceOp,
2077 "expected slice to be within body of inner-most loop");
2078 }
2079
2080 // 2. Check that the loop is perfectly nested.
2081 if (!isPerfectlyNestedForLoops(loops)) {
2082 return rewriter.notifyMatchFailure(
2083 candidateSliceOp, "expected passed loops to be perfectly nested.");
2084 }
2085
2086 if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
2087 return failure();
2088 Value sliceResult = candidateSliceOp.getResult();
2089
2090 // 3. Fetch the corresponding output.
2091 OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
2092 unsigned resultNumber = yieldOpOperand.getOperandNumber();
2093
2094 scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
2095
2096 return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
2097}
2098
2099/// Fetch the first untiled consumer of a scf.forall's result which is yielded
2100/// by a tensor.parallel_insert_slice.
2101static FailureOr<OpOperand *>
2103 tensor::ParallelInsertSliceOp candidateSliceOp,
2105 assert(!loops.empty() && "unexpected loops to be empty");
2106 // 1. Check that the surrounding loop is a single scf.forall loop.
2107 if (loops.size() != 1) {
2108 return rewriter.notifyMatchFailure(
2109 candidateSliceOp, "expected single surrounding scf.forall");
2110 }
2111 auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
2112 if (!forallOp) {
2113 return rewriter.notifyMatchFailure(
2114 candidateSliceOp, "expected single surrounding scf.forall");
2115 }
2116
2117 // 2. Fetch the corresponding output
2118 Value sliceDest = candidateSliceOp.getDest();
2119 auto iterArg = dyn_cast<BlockArgument>(sliceDest);
2120 if (!iterArg)
2121 return failure();
2122 if (iterArg.getOwner()->getParentOp() != forallOp)
2123 return failure();
2124
2125 unsigned resultNumber =
2126 forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
2127 .getResultNumber();
2128
2129 return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
2130}
2131
2132/// A utility to fetch an untiled consumer of
2133/// tensor.insert_slice/tensor.parallel_insert_slice.
2134static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
2135 RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
2137 assert(!loops.empty() && "unexpected empty loops");
2138 assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
2139 SmallVector<OpOperand *> fusedOperands;
2140 for (auto sliceOp : sliceOps) {
2141 FailureOr<OpOperand *> fusedOperand =
2143 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2144 [&](auto op) {
2145 return getUntiledConsumerFromSlice(rewriter, op, loops);
2146 })
2147 .Default([&](Operation *op) {
2148 return rewriter.notifyMatchFailure(op, "unhandled slice type");
2149 });
2150 if (failed(fusedOperand)) {
2151 return failure();
2152 }
2153 if (!fusedOperands.empty() &&
2154 fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
2155 return rewriter.notifyMatchFailure(
2156 fusedOperand.value()->getOwner(),
2157 "all candidate slices must be to the same consumer");
2158 }
2159 fusedOperands.push_back(fusedOperand.value());
2160 }
2161 return fusedOperands;
2162}
2163
2164template <typename InsertSliceOpTy>
2165static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
2166 InsertSliceOpTy sliceOp);
2167
2168template <>
2169tensor::InsertSliceOp
2171 tensor::InsertSliceOp insertSliceOp) {
2172 return cast<tensor::InsertSliceOp>(
2173 rewriter.clone(*insertSliceOp.getOperation()));
2174}
2175
2176template <>
2178 RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
2179 return tensor::InsertSliceOp::create(
2180 rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(),
2181 insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
2182 insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2183}
2184
2185static SmallVector<tensor::InsertSliceOp>
2187 ArrayRef<Operation *> candidateSlices) {
2188 assert(!candidateSlices.empty() &&
2189 "unexpected empty list of slices to clone");
2191 for (auto sliceOp : candidateSlices) {
2193 .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2194 [&](auto op) {
2195 auto clonedOp = cloneAsInsertSlice(rewriter, op);
2196 clonedSlices.push_back(clonedOp);
2197 })
2198 // Assert here assuming this has already been checked.
2199 .DefaultUnreachable(
2200 "unexpected slice type while cloning as insert slice");
2201 }
2202 return clonedSlices;
2203}
2204
2205static FailureOr<scf::SCFFuseConsumerOfSliceResult>
2207 ArrayRef<OpOperand *> consumerOpOperands,
2208 ArrayRef<Operation *> candidateSlices,
2210 assert(!loops.empty() && "expected loops to be not empty");
2211
2212 // 1. Check assumption for loop with `reorderOperations` disabled.
2213 if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
2214 return rewriter.notifyMatchFailure(
2215 loops.front(), "the first user of loop should not dominate any define "
2216 "of consumer operand(s)");
2217 }
2218
2219 LoopLikeOpInterface outerMostLoop = loops.front();
2220 LoopLikeOpInterface innerMostLoop = loops.back();
2221
2222 OpBuilder::InsertionGuard g(rewriter);
2223 // 2. Check consumer is not using scf loop's output as init.
2224 auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2225 if (!dstOp)
2226 return rewriter.notifyMatchFailure(consumerOp,
2227 "consumer op is not DPS operation");
2228 if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
2229 return dstOp.isDpsInit(opOperand);
2230 })) {
2231 return rewriter.notifyMatchFailure(
2232 consumerOp,
2233 "consumer op taking the result of scf.for as init is not supported");
2234 }
2235 SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
2236
2237 // 3. Move the whole loop structure right before firstUserOfLoop, the
2238 // dominance should be already ensured by `checkAssumptionForLoop`.
2239 FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
2240 if (failed(firstUserOfLoop)) {
2241 return rewriter.notifyMatchFailure(
2242 outerMostLoop, "could not find the first user of outer most loop");
2243 }
2244 rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2245
2246 // 4. Set insertion point before terminator op of the loop and create a new
2247 // tensor.insert_slice. In the scf.for case this is a clone of the
2248 // candidateSliceOp whereas in the scf.forall case this is created from the
2249 // operands of tensor.parallel_insert_slice.
2250 if (auto sliceOp =
2251 dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
2252 auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2253 rewriter.setInsertionPoint(newForallOp.getTerminator());
2254 } else {
2255 rewriter.setInsertionPoint(candidateSlices.front());
2256 }
2257 // 5.a. Clone all the candidate slices as equivalent insert slice ops.
2258 SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
2259 cloneAsInsertSlices(rewriter, candidateSlices);
2260
2261 // 5.b. Clone consumer op.
2262 auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2263 SmallVector<unsigned> operandNumbers =
2264 llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
2265 return opOperand->getOperandNumber();
2266 });
2267 SmallVector<OpOperand *> clonedOpFusedOperandsList =
2268 llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2269 return &clonedConsumerOp->getOpOperand(operandNum);
2270 });
2271
2272 // 5.c. Replace all uses of the loop result with the result of the cloned
2273 // tensor.insert_slice.
2274 rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
2275 for (auto [operandToReplace, clonedSliceOp] :
2276 llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
2277 operandToReplace->set(clonedSliceOp.getResult());
2278 }
2279 });
2280
2281 // 6. Perform tiling of the cloned consumer and replace the operand at
2282 // `operandNumber` with the source of the cloned tensor.insert_slice op.
2283 FailureOr<TilingResult> tileAndFuseResult =
2284 tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
2285 clonedOpFusedOperandsList);
2286 if (failed(tileAndFuseResult)) {
2287 return failure();
2288 }
2289
2290 auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2291 for (auto [operandNum, clonedSliceOp] :
2292 llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
2293 rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
2294 clonedSliceOp.getSource());
2295 }
2296
2297 // 7. Reconstruct [nested] loop with new inits.
2298 YieldTiledValuesFn newYieldValuesFn =
2299 [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2300 ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2302 SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2303 OpBuilder::InsertionGuard g(innerRewriter);
2304 // 8. Set inner insertPoint right before tiled consumer op.
2305 innerRewriter.setInsertionPoint(tiledConsumerOp);
2306
2307 SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
2308 for (auto candidateSliceOp : clonedInsertSlices) {
2309 SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
2310 SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
2311 SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
2312
2313 // 9. Check all insert stride is 1.
2314 if (!llvm::all_of(strides, isOneInteger)) {
2315 return rewriter.notifyMatchFailure(
2316 candidateSliceOp, "containingOp's result yield with stride");
2317 }
2318
2319 allOffsets.emplace_back(std::move(offsets));
2320 allSizes.emplace_back(std::move(sizes));
2321 }
2322
2323 // 10. Try to get iter domain position from input position. Use
2324 // clonedConsumerOp instead of tiledConsumerOp, because the iteration
2325 // domain may require index computation based on the result size. The
2326 // sizes and offsets should be the same either way, but using
2327 // tiledConsumerOp could lead to some chained unnecessary extra index
2328 // computation.
2329 SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2330 if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
2331 rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
2332 iterDomainSizes))) {
2333 return rewriter.notifyMatchFailure(
2334 clonedConsumerOp,
2335 "can't get iter domain position from input position");
2336 }
2337
2338 // 11. Try to fetch the offset and size for all results of the cloned
2339 // consumer. This would then be used to form the corresponding
2340 // tensor.insert_slice/parallel_insert_slice later.
2341 unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2343 totalNumResultsOfConsumer);
2345 totalNumResultsOfConsumer);
2346 for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2347 if (failed(tiledConsumerOp.getResultTilePosition(
2348 rewriter, idx, iterDomainOffsets, iterDomainSizes,
2349 resultOffsets[idx], resultSizes[idx]))) {
2350 return rewriter.notifyMatchFailure(
2351 tiledConsumerOp,
2352 "can't get result domain position from iter domain position");
2353 }
2354 }
2355
2356 // 12. Create `extract_slice` for `iter_args` for DPS operation if
2357 // necessary.
2358 if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2359 tiledConsumerOp.getOperation())) {
2360 rewriter.setInsertionPoint(tiledDestStyleOp);
2361 for (const auto &&[index, newRegionArg] :
2362 llvm::enumerate(newRegionIterArgs)) {
2363 auto destSlice = tensor::ExtractSliceOp::create(
2364 rewriter, loc, newRegionArg, resultOffsets[index],
2365 resultSizes[index],
2366 SmallVector<OpFoldResult>(resultOffsets[index].size(),
2367 rewriter.getIndexAttr(1)));
2368 // Make a copy of index to avoid a capturing structured binding, which
2369 // is a C++20 extension.
2370 auto dstNumber = index;
2371 rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2372 tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2373 });
2374 }
2375 }
2376
2377 // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2378 // caller.
2379 Block *block = rewriter.getInsertionPoint()->getBlock();
2380 rewriter.setInsertionPoint(block->getTerminator());
2381 for (const auto &&[index, result] :
2382 llvm::enumerate(tiledConsumerOp->getResults())) {
2383 tiledResult.push_back(result);
2384 tiledOffset.emplace_back(resultOffsets[index]);
2385 tiledSizes.emplace_back(resultSizes[index]);
2386 }
2387 return success();
2388 };
2389 // 14. Add new inits to [nested] loops.
2390 if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
2391 newYieldValuesFn))) {
2392 return rewriter.notifyMatchFailure(tiledConsumerOp,
2393 "unable to add new inits to nest loop");
2394 }
2395
2396 // 15. Replace the result of scf loop and consumer op with new loop's
2397 // results.
2398
2399 for (auto &&[oldResult, newResult] :
2400 llvm::zip(consumerOp->getResults(),
2401 loops.front()->getResults().take_back(newInits.size()))) {
2402 rewriter.replaceAllUsesWith(oldResult, newResult);
2403 }
2404
2405 // 16. Need to erase the old scf loop and the cloned consumer op.
2406 rewriter.eraseOp(clonedConsumerOp);
2407
2408 SmallVector<OpOperand *> tiledAndFusedOpOperands =
2409 llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2410 return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
2411 });
2412 auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
2413 return scf::SCFFuseConsumerOfSliceResult{
2414 std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
2415 std::move(tileAndFuseResult->tiledOps)};
2416}
2417
2418/// Implementation of fusing consumer of a single slice by computing the
2419/// slice of the consumer in-place for scf loop.
2420FailureOr<scf::SCFFuseConsumerOfSliceResult>
2421mlir::scf::tileAndFuseConsumerOfSlices(
2422 RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2423 MutableArrayRef<LoopLikeOpInterface> loops) {
2424 if (candidateSlices.empty()) {
2425 return rewriter.notifyMatchFailure(
2426 rewriter.getUnknownLoc(),
2427 "no candidate slices provided for consumer fusion");
2428 }
2429 // Return if `loops` is empty, return an error for now. Caller is expected
2430 // to handle this case.
2431 if (loops.empty()) {
2432 return rewriter.notifyMatchFailure(
2433 candidateSlices.front(),
2434 "cannot call tile and fuse consumer with an empty loop nest");
2435 }
2436
2437 if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2438 llvm::all_of(candidateSlices,
2439 llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2440 return rewriter.notifyMatchFailure(
2441 candidateSlices.front(),
2442 "candidates slices need to be all `tensor.extract_slice`s or "
2443 "`tensor.parallel_insert_slice`s");
2444 }
2445
2446 // Get the consumer of scf.for for the result yielded by
2447 // tensor.insert_slice/parallel_insert_slice.
2448 FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
2449 getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
2450 if (failed(maybeConsumerOpOperands)) {
2451 return rewriter.notifyMatchFailure(candidateSlices.front(),
2452 "could not fetch consumer to fuse");
2453 }
2454 Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();
2455
2456 return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp,
2457 maybeConsumerOpOperands.value(),
2458 candidateSlices, loops);
2459}
2460
2461/// For a given `result` of a `forallOp` return the
2462/// `tensor.parallel_insert_slice` op (or combining op) that is used to
2463/// construct this result.
2464static std::optional<Operation *>
2466 if (result.getOwner() != forallOp)
2467 return std::nullopt;
2468 BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
2469 SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
2470 // If the number of combining ops is not 1, then this is unexpected. Return
2471 // nullopt.
2472 if (combiningOps.size() != 1)
2473 return std::nullopt;
2474 return combiningOps[0];
2475}
2476
2477/// For a given result of the loop nest that is a tiled loop nest, return the
2478/// insert slice-like op that is used for consumer fusion
2479static std::optional<Operation *>
2482 assert(!loops.empty() && "Expected loops to be not empty");
2483 LoopLikeOpInterface outerMostLoop = loops.front();
2484 if (auto forallOp = dyn_cast<scf::ForallOp>(outerMostLoop.getOperation())) {
2485 assert(loops.size() == 1 &&
2486 "expected only a single loop when tiling using scf.forall");
2487 return getProducingParallelInsertSlice(forallOp, result);
2488 }
2489 // Assume that the loop nest is a nested `scf.for` that is created through
2490 // tiling and retrieve the `tensor.insert_slice` operation used to construct
2491 // the result.
2492 while (loops.size() != 1) {
2493 LoopLikeOpInterface loop = loops.front();
2494 if (result.getOwner() != loop)
2495 return std::nullopt;
2496 auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
2497 if (!forOp)
2498 return std::nullopt;
2499 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2500 auto innerForResult =
2501 dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
2502 if (!innerForResult)
2503 return std::nullopt;
2504 result = innerForResult;
2505 loops = loops.drop_front();
2506 }
2507 LoopLikeOpInterface loop = loops.front();
2508 if (result.getOwner() != loop)
2509 return std::nullopt;
2510 auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
2511 if (!forOp)
2512 return std::nullopt;
2513 auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
2514 auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
2515 .getDefiningOp<tensor::InsertSliceOp>();
2516 if (!insertSliceOp)
2517 return std::nullopt;
2518 return insertSliceOp;
2519}
2520
2521FailureOr<scf::SCFFuseConsumerOfSliceResult>
2522mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
2523 MutableArrayRef<LoopLikeOpInterface> loops) {
2524 if (!isa<TilingInterface>(consumer)) {
2525 return rewriter.notifyMatchFailure(
2526 consumer, "unhandled consumer that does not implement TilingInterface");
2527 }
2528
2529 // Return if `loops` is empty, return an error for now. Caller is expected
2530 // to handle this case.
2531 if (loops.empty()) {
2532 return rewriter.notifyMatchFailure(
2533 consumer, "cannot call tile and fuse consumer with an empty loop nest");
2534 }
2535
2536 LoopLikeOpInterface outermostLoop = loops.front();
2537
2538 // Collect the operands of the consumer that come from the outermost loop of
2539 // the loop nest.
2540 SmallVector<OpOperand *> consumerFusableOperands;
2541 for (OpOperand &opOperand : consumer->getOpOperands()) {
2542 if (opOperand.get().getDefiningOp() == outermostLoop) {
2543 consumerFusableOperands.push_back(&opOperand);
2544 }
2545 }
2546
2547 // Nothing to fuse. Just return an empty set.
2548 if (consumerFusableOperands.empty()) {
2549 return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
2550 SmallVector<OpOperand *>{},
2551 SmallVector<Operation *>{}};
2552 }
2553
2554 // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
2555 // for fusion.
2556 SmallVector<Operation *> candidateSlices;
2557 candidateSlices.reserve(consumerFusableOperands.size());
2558 for (OpOperand *opOperand : consumerFusableOperands) {
2559 std::optional<Operation *> slice =
2560 getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
2561 if (!slice) {
2562 return rewriter.notifyMatchFailure(
2563 consumer,
2564 "couldnt find producing insert-slice like operation for operand");
2565 }
2566 candidateSlices.push_back(slice.value());
2567 }
2569 rewriter, consumer, consumerFusableOperands, candidateSlices, loops);
2570}
2571
2572//===----------------------------------------------------------------------===//
2573// lowerToLoopsUsingSCFForOp implementation.
2574//===----------------------------------------------------------------------===//
2575
2576FailureOr<SmallVector<scf::ForOp>>
2577mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
2578 TilingInterface op) {
2579 // TODO: Handle cases where the op has results if needed.
2580 if (op->getNumResults() > 0) {
2581 return rewriter.notifyMatchFailure(
2582 op, "unable to lower to loops operations with return values");
2583 }
2584
2585 SmallVector<Range> domain = op.getIterationDomain(rewriter);
2586 SmallVector<Value> ivs;
2587 SmallVector<scf::ForOp> loops;
2588 Location loc = op.getLoc();
2589 for (auto loopRange : domain) {
2590 Value offsetVal =
2591 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2592 Value sizeVal =
2593 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2594 Value strideVal =
2595 getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2596 auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal,
2597 strideVal, ValueRange{});
2598 loops.push_back(loop);
2599 ivs.push_back(loop.getInductionVar());
2600 rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2601 }
2602 if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2603 return failure();
2604 }
2605 return loops;
2606}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, const SetVector< unsigned > &reductionDims, ValueRange partialResults)
static std::optional< Operation * > getProducingInsertSliceLikeOp(OpResult result, ArrayRef< LoopLikeOpInterface > loops)
For a given result of the loop nest that is a tiled loop nest, return the insert slice-like op that i...
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop< scf::ForallOp >(scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Implementation of yieldTiledValuesAndReplaceLoop for scf.forall
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
std::function< LogicalResult( RewriterBase &rewriter, Location Loc, ValueRange ivs, ArrayRef< OpFoldResult > tileOffsets, ArrayRef< OpFoldResult > tileSizes, ValueRange outerDestinationTensors, SmallVector< Value > &tiledResults, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> GenerateTiledBodyFn
Typedef for function that implements the body of a tiled loop.
static LogicalResult checkTileSizes(TilingInterface op, scf::SCFTilingOptions::LoopType loopType, ReductionTilingStrategy reductionStrategy, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, ArrayRef< OpOperand * > consumerOpOperands, ArrayRef< Operation * > candidateSlices, MutableArrayRef< LoopLikeOpInterface > loops)
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingCustomOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn, const scf::SCFTilingOptions::GenerateLoopTerminatorFn &generateLoopTerminatorFn, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using custom loop operation.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using the loop construct specified in options.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.for operation.
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static SmallVector< tensor::InsertSliceOp > cloneAsInsertSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices)
static FailureOr< SmallVector< OpOperand * > > getUntiledConsumerOperandsFromSlices(RewriterBase &rewriter, ArrayRef< Operation * > sliceOps, MutableArrayRef< LoopLikeOpInterface > loops)
A utility to fetch an untiled consumer of tensor.insert_slice/tensor.parallel_insert_slice.
tensor::InsertSliceOp cloneAsInsertSlice< tensor::ParallelInsertSliceOp >(RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp)
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
std::function< LogicalResult( RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
Typedef for function that allows returning additional yielded values during yieldTiledValuesAndReplac...
tensor::InsertSliceOp cloneAsInsertSlice< tensor::InsertSliceOp >(RewriterBase &rewriter, tensor::InsertSliceOp insertSliceOp)
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes)
Function to return the bounds of the loops to be generated.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizesWithForAllOp(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult givenTileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop< scf::ForOp >(scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Implementation of yieldTiledValuesAndReplaceLoop for scf.for.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static SetVector< unsigned > getSanitizedReductionDims(ArrayRef< OpFoldResult > givenTileSizes, const scf::SCFTilingOptions &options)
Get the reduction dims that are tiled.
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.forall operation.
static std::optional< Operation * > getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result)
For a given result of a forallOp return the tensor.parallel_insert_slice op (or combining op) that is...
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, InsertSliceOpTy sliceOp)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static SmallVector< OpFoldResult > getSplitReductionIvs(RewriterBase &rewriter, Location loc, ReductionTilingStrategy reductionStrategy, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
For the case of ReductionTilingStrategy::PartialReductionOuterParallel the PartialReductionOpInterfac...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
unsigned getNumArguments()
Definition Block.h:138
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
Location getUnknownLoc()
Definition Builders.cpp:25
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
A class for computing basic dominance information.
Definition Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition Builders.h:447
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition Builders.h:318
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
result_range getResults()
Definition Operation.h:415
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
iterator_range< use_iterator > use_range
Definition Value.h:182
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:578
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
FailureOr< TilingResult > replaceInsertSlicesWithTiledConsumer(OpBuilder &builder, ArrayRef< tensor::InsertSliceOp > sliceOps, ArrayRef< OpOperand * > consumerOperands)
Method to swap tensor.insert_slices with their consumers when the consumer implements the TilingInter...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:77
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Include the generated interface declarations.
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
Definition Utils.cpp:1525
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Container for result values of tiling.