MLIR  22.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/Dominance.h"
23 #include "mlir/IR/PatternMatch.h"
28 #include "llvm/ADT/ScopeExit.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/Debug.h"
31 #include <optional>
32 
33 #define DEBUG_TYPE "tile-using-interface"
34 
35 using namespace mlir;
36 
39  assert(!tileSizeComputationFunction && "tile sizes already set");
40  auto tileSizes = llvm::to_vector(ts);
41  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
42  return tileSizes;
43  };
44  return *this;
45 }
46 
49  assert(!numThreadsComputationFunction && "num tiles already set");
50  auto numThreads = llvm::to_vector(nt);
51  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
52  return numThreads;
53  };
54  return *this;
55 }
56 
57 /// Helper method to adjust the interchange vector to match the iteration
58 /// domain.
61  size_t iterationDomainSize) {
62  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
63  if (filledVector.size() < iterationDomainSize) {
64  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
65  filledVector.append(range.begin(), range.end());
66  }
67  if (filledVector.size() > iterationDomainSize)
68  filledVector.resize(iterationDomainSize);
69  return filledVector;
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // tileUsingSCF implementation.
74 //===----------------------------------------------------------------------===//
75 
76 /// Verify the tile size options are set in a consistent manner.
77 static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc,
79  // Specifying number of threads is only supported on `scf.forall` op.
80  if (options.numThreadsComputationFunction &&
82  return rewriter.notifyMatchFailure(
83  loc, "number of threads can only by specified when loop type is "
84  "set to use `scf.forall`");
85  }
86 
87  // If specified, check that the interchange vector is a permutation.
88  if (!options.interchangeVector.empty()) {
89  if (!isPermutationVector(options.interchangeVector)) {
90  return rewriter.notifyMatchFailure(
91  loc, "invalid interchange vector, not a permutation of the entire "
92  "iteration space");
93  }
94  }
95  return success();
96 }
97 
98 /// Method to instantiate the tile sizes and/or number of threads specified
99 /// by the user.
100 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
101 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
102  ArrayRef<Range> iterationDomain,
104  OpFoldResult zero = rewriter.getIndexAttr(0);
105  SmallVector<OpFoldResult> tileSizes, numThreads;
106  size_t numLoops = iterationDomain.size();
107 
108  // Check whether the number of threads to use is specified.
109  if (options.numThreadsComputationFunction) {
110  numThreads = options.numThreadsComputationFunction(rewriter, op);
111  numThreads.resize(numLoops, zero);
112 
113  // If the tile sizes are also specified, use those.
114  if (options.tileSizeComputationFunction) {
115  tileSizes = options.tileSizeComputationFunction(rewriter, op);
116  tileSizes.resize(numLoops, zero);
117  return {tileSizes, numThreads};
118  }
119 
120  // Compute the tile sizes from the iteration domain and number
121  // of tiles as follows
122  // - niters = ceilDiv(ub - lb, step)
123  // - tileSize = ceilDiv(niters, numThreads)
124  AffineExpr s0, s1, s2;
125  bindSymbols(rewriter.getContext(), s0, s1, s2);
126  // TODO: The step here is assumed to be 1.
127  AffineExpr numItersExpr = (s1 - s0);
128  AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
129  tileSizes.resize(numLoops, zero);
130  for (auto [index, range, nt] :
131  llvm::enumerate(iterationDomain, numThreads)) {
132  if (isZeroInteger(nt))
133  continue;
134 
135  tileSizes[index] = affine::makeComposedFoldedAffineApply(
136  rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
137  }
138  tileSizes.resize(numLoops, zero);
139  return {tileSizes, numThreads};
140  }
141 
142  // Enforce the convention that "tiling by zero"
143  // skips tiling a particular dimension. This convention is significantly
144  // simpler to handle instead of adjusting affine maps to account for missing
145  // dimensions.
146  assert(options.tileSizeComputationFunction &&
147  "expected tile sizes to be specified");
148  tileSizes = options.tileSizeComputationFunction(rewriter, op);
149  tileSizes.resize(numLoops, zero);
150 
151  return {tileSizes, numThreads};
152 }
153 
154 /// Checks if any of the tiled loops are not parallel.
155 static LogicalResult checkTileSizes(TilingInterface op,
157  ReductionTilingStrategy reductionStrategy,
158  ArrayRef<OpFoldResult> givenTileSizes,
159  ArrayRef<OpFoldResult> numThreads) {
160  auto iterators = op.getLoopIteratorTypes();
161  assert(iterators.size() == givenTileSizes.size() &&
162  "expected as many tile size values as number of loops");
163  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
164  "when specified, expected number of threads to use for each loop");
165 
166  bool isParallelTiling = false;
167  for (auto [index, iterator, givenTileSize] :
168  llvm::enumerate(iterators, givenTileSizes)) {
169  if (!isConstantIntValue(givenTileSize, 0)) {
170  isParallelTiling |= iterator == utils::IteratorType::parallel;
171  }
172 
174  reductionStrategy == ReductionTilingStrategy::FullReduction) {
175  // If num threads is specified, check that it is greater than one only for
176  // parallel dimensions.
177  if (!numThreads.empty()) {
178  if (std::optional<int64_t> constNumThreads =
179  getConstantIntValue(numThreads[index])) {
180  if (constNumThreads.value() > 1 &&
181  iterator != utils::IteratorType::parallel) {
182  op.emitWarning() << "tiling is not thread safe at axis #" << index;
183  }
184  }
185  continue;
186  }
187 
188  if (std::optional<int64_t> constTileSize =
189  getConstantIntValue(givenTileSize)) {
190  if (constTileSize.value() > 0 &&
191  iterator != utils::IteratorType::parallel) {
192  op.emitWarning() << "tiling is not thread safe at axis #" << index;
193  }
194  }
195  }
196  }
197 
198  if (reductionStrategy != ReductionTilingStrategy::FullReduction) {
199  if (isParallelTiling) {
200  return op->emitOpError("tiling parallel dimensions is not supported with "
201  "partial reduction tiling strategies");
202  }
203  }
204  return success();
205 }
206 
207 /// Get the reduction dims that are tiled. This accounts for reduction dims
208 /// that are specified as tiled, but the tile size is 0.
209 static SetVector<unsigned>
212  SetVector<unsigned> reductionDims;
213  for (auto dim : options.reductionDims) {
214  if (isConstantIntValue(givenTileSizes[dim], 0))
215  continue;
216  reductionDims.insert(dim);
217  }
218  return reductionDims;
219 }
220 
221 /// Check if `stride` evenly divides the trip count `size - offset`.
222 static bool tileDividesIterationDomain(Range loopRange) {
223  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
224  if (!offsetAsInt)
225  return false;
226  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
227  if (!sizeAsInt)
228  return false;
229  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
230  if (!strideAsInt)
231  return false;
232  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
233 }
234 
235 /// Returns the bounded tile size given the current `offset`, `loopRange` and
236 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
238  Range loopRange, OpFoldResult offset,
239  OpFoldResult givenTileSize) {
240  std::optional<int64_t> ts = getConstantIntValue(givenTileSize);
241  if (ts && ts.value() == 1)
242  return givenTileSize;
243 
245  Range{loopRange.offset, loopRange.size, givenTileSize}))
246  return givenTileSize;
247 
248  // The tile size to use (to avoid out of bounds access) is minimum of
249  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
250  // loop.
251  AffineExpr s0, s1, d0;
252  bindDims(b.getContext(), d0);
253  bindSymbols(b.getContext(), s0, s1);
254  AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
255  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
257  b, loc, minMap, SmallVector<OpFoldResult>{offset, size, givenTileSize});
258 }
259 
260 /// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
261 /// less than `iterationSize`.
263  OpFoldResult numThreads,
264  OpFoldResult iterationSize) {
265  std::optional<int64_t> tileSizeConst = getConstantIntValue(givenTileSize);
266  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
267  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
268  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
269  return false;
270  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
271 }
272 
273 /// Compute the `OpFoldResult`s that represents the multi-dimensional
274 /// `offset`s and `size`s of the tile of the iteration space that the
275 /// innermost loop body of the generated tiled loops corresponds to.
276 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
278  ArrayRef<Range> iterationDomain,
279  ArrayRef<OpFoldResult> givenTileSizes) {
280  SmallVector<OpFoldResult> offsets, sizes;
281  int materializedLoopNum = 0;
282  for (auto [givenTileSize, loopRange] :
283  llvm::zip_equal(givenTileSizes, iterationDomain)) {
284 
285  // Non-tiled cases, set the offset and size to the
286  // `loopRange.offset/size`.
287  if (isZeroInteger(givenTileSize)) {
288  offsets.push_back(loopRange.offset);
289  sizes.push_back(loopRange.size);
290  continue;
291  }
292 
293  Value iv = ivs[materializedLoopNum++];
294  OpFoldResult offset = getAsOpFoldResult(iv);
295  offsets.push_back(offset);
296  OpFoldResult size =
297  getBoundedTileSize(rewriter, loc, loopRange, offset, givenTileSize);
298  sizes.push_back(size);
299  }
300  return {offsets, sizes};
301 }
302 
303 /// Function to return the bounds of the loops to be generated.
304 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
307  ArrayRef<OpFoldResult> givenTileSizes) {
308  SmallVector<OpFoldResult> lbs, ubs, steps;
309  for (auto [loopRange, givenTileSize] :
310  llvm::zip_equal(loopRanges, givenTileSizes)) {
311  // No loop if the tile size is 0.
312  if (isZeroInteger(givenTileSize))
313  continue;
314  lbs.push_back(loopRange.offset);
315  ubs.push_back(loopRange.size);
316  steps.push_back(givenTileSize);
317  }
318  return {lbs, ubs, steps};
319 }
320 
321 /// Typedef for function that allows returning additional yielded values during
322 /// `yieldTiledValuesAndReplace`.
323 /// - `ivs` induction variable for the loop.
324 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
325 /// - `tiledValues` the tiled values to return. Must be of same size as
326 /// `newBbArgs`, each element of this array is inserted into the corresponding
327 /// element in `newBbArgs`.
328 /// - `resultOffsets` is of the same size as `tiledValues` and represents
329 /// the offsets to use when inserting corresponding element from `tiledValues`
330 /// into the element from `newBbArgs`.
331 /// - `resultSizes` is of the same size as `tiledValues` and represents
332 /// the size of the corresponding element from `tiledValues` inserted into
333 /// the element from `newBbArgs`.
334 /// In case the method needs to return `failure()` the method is expected
335 /// to clean up any inserted operations.
336 using YieldTiledValuesFn = std::function<LogicalResult(
337  RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
338  SmallVector<Value> &tiledValues,
339  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
340  SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
341 
342 /// Typedef for function that implements the body of a tiled loop.
343 /// - `ivs` induction variable for the loop.
344 /// - `tileOffsets` represents offsets for the tiled iteration space.
345 /// - `tileSizes` represents the sizes for the tiled iteration space.
346 /// - `outerDestinationTensors` tensor that holds the result. Is same size
347 /// as the destination operands of the original operations.
348 /// - `tiledResults` results of the tiled computation, corresponds to
349 /// tiles of the original operation computed by the loop body.
350 /// Should be same size as the `destinationTensors`
351 /// - `resultOffsets` is of the same size as `tiledResults` and represents
352 /// the offset to use when writing the corresponding element from
353 /// `tiledResults` into `destinationTensors`.
354 /// - `resultSizes` is of the same size as `tiledResults` and represents
355 /// the size to use when writing the corresponding element from
356 /// `tiledResults` into `destinationTensors`.
357 /// In case the method needs to return `failure()` the method is expected
358 /// to clean up any inserted operations.
359 using GenerateTiledBodyFn = std::function<LogicalResult(
360  RewriterBase &rewriter, Location Loc, ValueRange ivs,
361  ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes,
362  ValueRange outerDestinationTensors, SmallVector<Value> &tiledResults,
363  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
364  SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
365 
366 /// Clones the operation and updates the destination if the operation
367 /// implements the `DestinationStyleOpInterface`.
369  Operation *op,
370  ValueRange newDestArgs) {
371  Operation *clonedOp = rewriter.clone(*op);
372  if (newDestArgs.empty())
373  return clonedOp;
374  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
375  destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
376  return clonedOp;
377 }
378 
379 /// Generate the tile-loop nest using `scf.for` operation.
380 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
381 /// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops.
382 /// - `outerDestinationTensors` are the init values to use for the outer most
383 /// loop.
384 /// - `tiledBodyFn` is called to generate the loop body of the innermost
385 /// loop.
387 /// Returns the generated `scf.for` loops on success.
388 static FailureOr<SmallVector<LoopLikeOpInterface>> generateLoopNestUsingForOp(
389  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
390  ArrayRef<OpFoldResult> givenTileSizes, ValueRange outerDestinationTensors,
391  GenerateTiledBodyFn tiledBodyFn) {
392  assert(!loopRanges.empty() && "unexpected empty loop ranges");
393  assert(loopRanges.size() == givenTileSizes.size() &&
394  "expected as many tile sizes as loop ranges");
395  OpBuilder::InsertionGuard guard(rewriter);
396 
397  SmallVector<OpFoldResult> lbs, ubs, steps;
398  std::tie(lbs, ubs, steps) =
399  getLoopBounds(rewriter, loc, loopRanges, givenTileSizes);
400  SmallVector<Value> lbVals =
401  getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
402  SmallVector<Value> ubVals =
403  getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
404  SmallVector<Value> stepVals =
405  getValueOrCreateConstantIndexOp(rewriter, loc, steps);
406 
407  SmallVector<Value> ivs;
409  ValueRange innerDestinationTensors(outerDestinationTensors);
410  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
411  auto loop =
412  scf::ForOp::create(rewriter, loc, lb, ub, step, innerDestinationTensors,
413  [](OpBuilder &bodyBuilder, Location bodyLoc,
414  Value iv, ValueRange /*iterArgs*/) {});
415  loops.push_back(loop);
416  ivs.push_back(loop.getInductionVar());
417  rewriter.setInsertionPointToEnd(loop.getBody());
418  innerDestinationTensors = loop.getRegionIterArgs();
419  }
420  if (loops.empty())
421  return success();
422 
423  // Compute the `offsets` and `sizes` to use for tiling.
424  SmallVector<OpFoldResult> offsets, sizes;
425  std::tie(offsets, sizes) =
426  getTileOffsetAndSizes(rewriter, loc, ivs, loopRanges, givenTileSizes);
427 
428  SmallVector<Value> tiledResults;
429  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
430  if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
431  innerDestinationTensors, tiledResults, resultOffsets,
432  resultSizes))) {
433  return rewriter.notifyMatchFailure(
434  loc, "failed to generate inner tile loop body");
435  }
436  if (loops.empty())
437  return loops;
438 
439  assert(tiledResults.size() == innerDestinationTensors.size() &&
440  "Number of results of body should be equal to number of iter args");
441 
442  // 6. Yield all the results of the tiled operation.
443  SmallVector<Value> yieldedValues;
444  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
445  llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
446  resultSizes)) {
447  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
448  rewriter.getIndexAttr(1));
449  auto insertSlice = tensor::InsertSliceOp::create(
450  rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize,
451  resultStride);
452  yieldedValues.push_back(insertSlice);
453  }
454  scf::YieldOp::create(rewriter, loc, yieldedValues);
455 
456  // Add the scf.yield operations for all the outer loops.
457  for (auto [outerLoop, innerLoop] :
458  llvm::zip_equal(MutableArrayRef(loops).drop_back(),
459  MutableArrayRef(loops).drop_front())) {
460  rewriter.setInsertionPointToEnd(
461  cast<scf::ForOp>(outerLoop.getOperation()).getBody());
462  scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
463  }
464  return loops;
465 }
466 
467 /// Compute the `OpFoldResult`s that represents the multi-dimensional
468 /// `offset`s and `size`s of the tile of the iteration space that the
469 /// innermost loop body of the generated tiled loops corresponds to
470 /// when tiling using `forall` op. This is handled separately due to
471 /// the special case handling needed for when the tiling is done by
472 /// specifying number of threads.
473 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
475  ValueRange ivs,
476  ArrayRef<Range> iterationDomain,
477  ArrayRef<OpFoldResult> givenTileSizes,
478  ArrayRef<OpFoldResult> numThreads) {
479  if (numThreads.empty()) {
480  return getTileOffsetAndSizes(rewriter, loc, ivs, iterationDomain,
481  givenTileSizes);
482  }
483 
484  SmallVector<OpFoldResult> offsets, sizes;
485  int materializedLoopNum = 0;
486 
487  AffineExpr d0, d1, s0, s1;
488  AffineExpr offsetExpr, residualTileSizeExpr;
489  bindDims(rewriter.getContext(), d0, d1);
490  bindSymbols(rewriter.getContext(), s0, s1);
491  offsetExpr = d0 + d1 * s0;
492  residualTileSizeExpr = s1 - (d0 + d1 * s0);
493 
494  for (auto [index, nt, givenTileSize, loopRange] :
495  llvm::enumerate(numThreads, givenTileSizes, iterationDomain)) {
496 
497  // Non-tiled cases, set the offset and size to the
498  // `loopRange.offset/size`.
499  if (isZeroInteger(nt)) {
500  offsets.push_back(loopRange.offset);
501  sizes.push_back(loopRange.size);
502  continue;
503  }
504 
505  Value iv = ivs[materializedLoopNum++];
507  rewriter, loc, offsetExpr,
508  ArrayRef<OpFoldResult>{loopRange.offset, iv, givenTileSize});
510  rewriter, loc, residualTileSizeExpr,
511  {loopRange.offset, nt, givenTileSize, loopRange.size});
512 
513  OpFoldResult size = givenTileSize;
514  if (!isZeroInteger(residualTileSize)) {
515  OpFoldResult sizeMinusOffsetPerThread =
516  affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
517  {offset, loopRange.size});
519  rewriter, loc,
521  {sizeMinusOffsetPerThread, givenTileSize});
522  }
523 
524  // Consider the case where the original loop was `[0, 100)`.
525  // If the number of threads is `16`, the tile size would be computed as
526  // `ceilDiv(100, 16) = 7`. For the last thread (thread_id = 15)
527  // - `offset = 0 + 15 * 7 = 105`
528  // - `tileSize = min(7, 100 - 105) = -5`
529  // To avoid negative tile sizes, we need to do a further
530  // `nonNegativeTileSize = affine.max(0, tileSize)`.
531  // This `max` can be avoided if
532  // `offset + tileSize * (numThreads - 1) < (ub - lb)`
533  if (!canOmitTileOffsetInBoundsCheck(givenTileSize, nt, loopRange.size)) {
534  AffineMap maxMap =
537  rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
538  }
539 
540  offsets.push_back(offset);
541  sizes.push_back(size);
542  }
543  return {offsets, sizes};
544 }
545 
546 /// Generate the tile-loop nest using `scf.forall` operation.
547 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
548 /// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops.
549 /// - `numThreads`, if non-empty, is the number of threads to use per loop.
550 /// Zero entries are pruned before creating the loop.
551 /// - `outerDestinationTensors` are the init values to use for the loop.
552 /// - `mappingVector` is the mapping attributes to use for loop construction.
553 /// Can be empty.
554 /// - `tiledBodyFn` is called to generate the loop body of the innermost loop.
555 /// Returns the generated `scf.forall` loop on success.
556 static FailureOr<SmallVector<LoopLikeOpInterface>>
558  ArrayRef<Range> loopRanges,
559  ArrayRef<OpFoldResult> givenTileSizes,
560  ArrayRef<OpFoldResult> numThreads,
561  ArrayRef<Attribute> mappingVector,
562  ValueRange outerDestinationTensors,
563  GenerateTiledBodyFn tiledBodyFn) {
564  assert(!loopRanges.empty() && "unexpected empty loop ranges");
565  assert(loopRanges.size() == givenTileSizes.size() &&
566  "expected as many tile sizes as loop ranges");
567  OpBuilder::InsertionGuard guard(rewriter);
568 
569  std::optional<ArrayAttr> mappingAttr;
570  if (!mappingVector.empty())
571  mappingAttr = rewriter.getArrayAttr(mappingVector);
572 
573  scf::ForallOp forallOp;
574  bool useNumThreads = !numThreads.empty();
575 
577  if (useNumThreads) {
578  // Prune the zero numthreads.
579  SmallVector<OpFoldResult> nonZeroNumThreads;
580  for (auto nt : numThreads) {
581  if (isZeroInteger(nt))
582  continue;
583  nonZeroNumThreads.push_back(nt);
584  }
585  forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads,
586  outerDestinationTensors, mappingAttr);
587  } else {
588  SmallVector<OpFoldResult> lbs, ubs, steps;
589  std::tie(lbs, ubs, steps) =
590  getLoopBounds(rewriter, loc, loopRanges, givenTileSizes);
591  forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
592  outerDestinationTensors, mappingAttr);
593  }
594  loops.push_back(forallOp);
595 
596  rewriter.setInsertionPoint(forallOp.getTerminator());
597  ValueRange innerDestinationTensors = forallOp.getRegionOutArgs();
598  SmallVector<Value> ivs = forallOp.getInductionVars();
599 
600  // Compute the `offsets` and `sizes` to use for tiling.
601  SmallVector<OpFoldResult> offsets, sizes;
602  std::tie(offsets, sizes) = getTileOffsetAndSizesWithForAllOp(
603  rewriter, loc, ivs, loopRanges, givenTileSizes, numThreads);
604 
605  SmallVector<Value> tiledResults;
606  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
607  if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
608  innerDestinationTensors, tiledResults, resultOffsets,
609  resultSizes)))
610  return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
611 
612  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
613  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
614  llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
615  resultSizes)) {
616  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
617  rewriter.getIndexAttr(1));
618 
619  tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue,
620  destinationTensor, resultOffset,
621  resultSize, resultStride);
622  }
623  return loops;
624 }
625 
626 /// Generate the tile-loop nest using custom loop operation.
627 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
628 /// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops.
629 /// - `outerDestinationTensors` are the init values to use for the outer most
630 /// loop.
631 /// - `generateLoopHeaderFn` is called to create the loops; it provides the tile
632 /// offsets, tile sizes and destination tensors passed to `tiledBodyFn`.
633 /// - `generateLoopTerminatorFn` is called with the results of `tiledBodyFn`.
634 /// - `tiledBodyFn` is called to generate the loop body of the innermost loop.
635 /// Returns the loops reported by `generateLoopHeaderFn` on success.
636 static FailureOr<SmallVector<LoopLikeOpInterface>>
638  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
639  ArrayRef<OpFoldResult> givenTileSizes, ValueRange outerDestinationTensors,
640  const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn,
642  &generateLoopTerminatorFn,
643  GenerateTiledBodyFn tiledBodyFn) {
644  assert(!loopRanges.empty() && "unexpected empty loop ranges");
645  assert(loopRanges.size() == givenTileSizes.size() &&
646  "expected as many tile sizes as loop ranges");
647  assert(generateLoopHeaderFn && generateLoopTerminatorFn &&
648  "expected loop header/terminator generation function");
649  OpBuilder::InsertionGuard guard(rewriter);
650 
651  FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> loopHeaderInfo =
652  generateLoopHeaderFn(rewriter, loc, loopRanges, givenTileSizes,
653  outerDestinationTensors);
654  if (failed(loopHeaderInfo)) {
655  return failure();
656  }
657 
658  SmallVector<Value> ivs;
659  SmallVector<Value> tiledResults;
660  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
661  if (failed(tiledBodyFn(rewriter, loc, ivs, loopHeaderInfo->tileOffset,
662  loopHeaderInfo->tileSizes,
663  loopHeaderInfo->destinationTensors, tiledResults,
664  resultOffsets, resultSizes))) {
665  return failure();
666  }
667 
668  if (failed(generateLoopTerminatorFn(rewriter, loc, tiledResults,
669  resultOffsets, resultSizes,
670  loopHeaderInfo->destinationTensors))) {
671  return failure();
672  }
673 
674  return loopHeaderInfo->loops;
675 }
676 
677 /// Generate the tile-loop nest using the loop construct specified in
678 /// `options`.
679 /// - `options`: Tiling options specified.
680 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
681 /// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops.
682 /// - `destinationTensors` are the init values to use for the outer most
683 /// loop.
684 /// - `tiledBodyFn` is called to generate the loop body of the innermost
685 /// loop.
686 /// Returns the generated loops on success.
687 static FailureOr<SmallVector<LoopLikeOpInterface>> generateLoopNest(
688  RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
689  ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> givenTileSizes,
690  ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
691  GenerateTiledBodyFn tiledBodyFn) {
692  // If the tile sizes are all zero, no loops are generated. Just call the
693  // callback function to handle untiled case.
694  if (llvm::all_of(givenTileSizes, isZeroInteger)) {
695  SmallVector<Value> tiledResults;
696  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
697  auto tileOffsets =
698  llvm::map_to_vector(loopRanges, [](Range r) { return r.offset; });
699  auto tileSizes =
700  llvm::map_to_vector(loopRanges, [](Range r) { return r.size; });
701  if (failed(tiledBodyFn(rewriter, loc, ValueRange{}, tileOffsets, tileSizes,
702  destinationTensors, tiledResults, resultOffsets,
703  resultSizes))) {
704  return failure();
705  }
707  }
709  return generateLoopNestUsingForOp(rewriter, loc, loopRanges, givenTileSizes,
710  destinationTensors, tiledBodyFn);
711  }
714  rewriter, loc, loopRanges, givenTileSizes, numThreads,
715  options.mappingVector, destinationTensors, tiledBodyFn);
716  }
719  rewriter, loc, loopRanges, givenTileSizes, destinationTensors,
720  options.generateLoopHeaderFn, options.generateLoopTerminatorFn,
721  tiledBodyFn);
722  }
723  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
724 }
725 
726 static FailureOr<SmallVector<Value>> createInitialTensorsForTiling(
727  RewriterBase &rewriter, TilingInterface op,
728  ReductionTilingStrategy reductionStrategy, ArrayRef<Range> iterationDomain,
729  ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> givenTileSizes,
730  const SetVector<unsigned> &reductionDims) {
731  SmallVector<Value> initTensors;
732  Location loc = op->getLoc();
733  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
734  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
735  return failure();
736  return initTensors;
737  }
738 
739  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
740  if (!redOp) {
741  return op->emitOpError(
742  "PartialReductionOuterReduction tiling strategy is only supported for "
743  "operations implementing PartialReductionOpInterface");
744  }
745  SmallVector<OpFoldResult> sizes(iterationDomain.size());
746  AffineExpr s0, s1, s2;
747  bindSymbols(rewriter.getContext(), s0, s1, s2);
748  AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
749  AffineExpr divExpr = s0.ceilDiv(s1);
750  for (auto [index, domain, tileSize] :
751  llvm::enumerate(iterationDomain, givenTileSizes)) {
752  if (!numThreads.empty()) {
753  // Untiled case.
754  if (isConstantIntValue(numThreads[index], 0)) {
756  rewriter, op.getLoc(), sizeExpr,
757  {domain.size, domain.offset, domain.stride});
758  continue;
759  }
760  sizes[index] = numThreads[index];
761  continue;
762  }
763 
764  // Non reduction dimensions/non-tiled dimensions.
765  if (!reductionDims.contains(index) || isConstantIntValue(tileSize, 0)) {
767  rewriter, op.getLoc(), sizeExpr,
768  {domain.size, domain.offset, domain.stride});
769  continue;
770  }
771 
772  if (reductionStrategy ==
774  sizes[index] = tileSize;
775  continue;
776  }
777 
778  assert(reductionStrategy ==
781  rewriter, op.getLoc(), sizeExpr,
782  {domain.size, domain.offset, domain.stride});
784  rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
785  }
786  return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
787  reductionDims);
788 }
789 
790 /// For the case of `ReductionTilingStrategy::PartialReductionOuterParallel`
791 /// the `PartialReductionOpInterface` methods need the index of the parallel
792 /// split reduction being executed.
795  ReductionTilingStrategy reductionStrategy, ValueRange ivs,
796  ArrayRef<OpFoldResult> numThreads,
797  ArrayRef<OpFoldResult> givenTileSizes,
798  const SetVector<unsigned> &reductionDims) {
799  SmallVector<OpFoldResult> splitReductionIvs;
800  splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
801  AffineExpr s0, s1;
802  bindSymbols(rewriter.getContext(), s0, s1);
803  AffineExpr divExpr = s0.floorDiv(s1);
804  int ivIndex = 0;
805  if (reductionStrategy ==
807  for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
808  if (!numThreads.empty()) {
809  splitReductionIvs[index] = ivs[ivIndex++];
810  continue;
811  }
812  splitReductionIvs[index] = affine::makeComposedFoldedAffineApply(
813  rewriter, loc, divExpr,
814  ArrayRef<OpFoldResult>{ivs[ivIndex++], givenTileSizes[reductionDim]});
815  }
816  }
817  return splitReductionIvs;
818 }
819 
820 static FailureOr<TilingResult>
821 getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
822  ReductionTilingStrategy reductionStrategy,
823  ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
825  ArrayRef<OpFoldResult> numThreads,
826  ArrayRef<OpFoldResult> givenTileSizes,
827  const SetVector<unsigned> &reductionDims) {
828  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
829  return op.getTiledImplementation(rewriter, offsets, sizes);
830  }
831 
832  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
833  if (!redOp) {
834  return rewriter.notifyMatchFailure(
835  op, "PartialReductionOuterReduction tiling strategy is only "
836  "supported for operations "
837  "implementing PartialReductionOpInterface");
838  }
839 
840  SmallVector<OpFoldResult> splitReductionIvs =
841  getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
842  numThreads, givenTileSizes, reductionDims);
843  return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
844  regionIterArg, offsets, sizes,
845  reductionDims, splitReductionIvs);
846 }
847 
848 static LogicalResult getResultTilePosition(
849  RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy,
850  int64_t index, Value tiledResult, TilingInterface op,
852  ValueRange ivs, ArrayRef<OpFoldResult> numThreads,
853  ArrayRef<OpFoldResult> givenTileSizes,
854  const SetVector<unsigned> &reductionDims,
855  SmallVector<OpFoldResult> &resultOffset,
856  SmallVector<OpFoldResult> &resultSize) {
857 
858  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
859  return op.getResultTilePosition(rewriter, index, offsets, sizes,
860  resultOffset, resultSize);
861  }
862  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
863  if (!redOp) {
864  return rewriter.notifyMatchFailure(
865  op, "PartialReductionOuterReduction tiling strategy is only supported"
866  "for operations implementing PartialReductionOpInterface");
867  }
868  SmallVector<OpFoldResult> splitReductionIvs =
869  getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
870  numThreads, givenTileSizes, reductionDims);
871  return redOp.getPartialResultTilePosition(
872  rewriter, index, reductionStrategy, offsets, sizes, reductionDims,
873  splitReductionIvs, resultOffset, resultSize);
874 }
875 
876 static FailureOr<MergeResult>
877 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
878  ReductionTilingStrategy reductionStrategy,
879  const SetVector<unsigned> &reductionDims,
880  ValueRange partialResults) {
881  assert(reductionStrategy != ReductionTilingStrategy::FullReduction &&
882  "expected merge to be called for only partial reduction cases");
883 
884  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
885  if (!redOp) {
886  return rewriter.notifyMatchFailure(
887  op, "PartialReductionOuterReduction tiling strategy is only "
888  "supported for operations "
889  "implementing PartialReductionOpInterface");
890  }
891  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
892  reductionDims);
893 }
894 
895 /// Append the specified additional `newInitOperands` operands to the
896 /// loops existing `init` operands (or similar), and replace `loopOp` with
897 /// the new loop that has the additional init operands. The loop body of
898 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
899 /// is called to get the new tiled values returned, and the offset
900 /// and sizes at which the tiled value is inserted into the
901 /// new region iter_args that correspond to the newly added init operands.
902 template <typename LoopType>
903 FailureOr<LoopLikeOpInterface>
904 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
905  ValueRange newInitOperands,
906  YieldTiledValuesFn yieldTiledValuesFn) {
907  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
908 }
909 
910 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
911 template <>
912 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
913  scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
914  YieldTiledValuesFn yieldTiledValuesFn) {
915  OpBuilder::InsertionGuard g(rewriter);
916  Location loc = loopOp.getLoc();
917  rewriter.setInsertionPoint(loopOp);
918 
919  auto inits = llvm::to_vector(loopOp.getInitArgs());
920  inits.append(newInitOperands.begin(), newInitOperands.end());
921  auto newLoop = scf::ForOp::create(
922  rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
923  loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
924  loopOp.getUnsignedCmp());
925 
926  // Move the loop body to the new op.
927  Block *loopBody = loopOp.getBody();
928  Block *newLoopBody = newLoop.getBody();
929  rewriter.mergeBlocks(
930  loopBody, newLoopBody,
931  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
932 
933  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
934  rewriter.setInsertionPoint(yieldOp);
935 
936  SmallVector<Value> tiledValues;
937  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
938  ValueRange newRegionIterArgs =
939  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
940  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
941  newRegionIterArgs, tiledValues, resultOffsets,
942  resultSizes))) {
943  rewriter.eraseOp(newLoop);
944  return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
945  }
946 
947  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
948  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
949  llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
950  resultSizes)) {
951  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
952  rewriter.getIndexAttr(1));
953  Value insert = tensor::InsertSliceOp::create(
954  rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset,
955  resultSize, resultStride);
956  newYieldValues.push_back(insert);
957  }
958 
959  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
960  rewriter.replaceOp(loopOp,
961  newLoop->getResults().take_front(loopOp.getNumResults()));
962  return cast<LoopLikeOpInterface>(newLoop.getOperation());
963 }
964 
965 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
966 template <>
967 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
968  scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
969  YieldTiledValuesFn yieldTiledValuesFn) {
970  OpBuilder::InsertionGuard g(rewriter);
971  Location loc = loopOp.getLoc();
972  rewriter.setInsertionPoint(loopOp);
973  auto inits = llvm::to_vector(loopOp.getOutputs());
974  inits.append(newInitOperands.begin(), newInitOperands.end());
975  auto newLoop = scf::ForallOp::create(
976  rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
977  loopOp.getMixedStep(), inits, loopOp.getMapping(),
978  [](OpBuilder &, Location, ValueRange) {});
979 
980  // Move the region of the current block to the newly created op.
981  Block *loopBody = loopOp.getBody();
982  Block *newLoopBody = newLoop.getBody();
983  rewriter.mergeBlocks(
984  loopBody, newLoopBody,
985  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
986 
987  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
988  rewriter.setInsertionPoint(terminator);
989  SmallVector<Value> tiledValues;
990  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
991  ValueRange regionIterArgs =
992  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
993  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
994  regionIterArgs, tiledValues, resultOffsets,
995  resultSizes))) {
996  rewriter.eraseOp(newLoop);
997  return rewriter.notifyMatchFailure(loopOp,
998  "failed to get yielded tiled values");
999  }
1000 
1001  // Update the terminator.
1002  rewriter.setInsertionPointToEnd(terminator.getBody());
1003 
1004  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
1005  tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
1006  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
1007  rewriter.getIndexAttr(1));
1008  tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(),
1009  tiledValue, iterArg, resultOffset,
1010  resultSize, resultStride);
1011  }
1012 
1013  rewriter.replaceOp(loopOp,
1014  newLoop->getResults().take_front(loopOp.getNumResults()));
1015  return cast<LoopLikeOpInterface>(newLoop.getOperation());
1016 }
1017 
1018 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
1019 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
1020 /// supported loop type.
1021 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
1022  LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
1023  ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
1025  loopLikeOp.getOperation())
1026  .Case<scf::ForOp, scf::ForallOp>(
1027  [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1029  loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
1030  })
1031  .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
1032  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
1033  });
1034 }
1035 
1036 /// Method to add new init values to a loop nest. Updates `loops` in-place
1037 /// with new loops that use the `newInitValues`. The outer-loops are updated
1038 /// to yield the new result values of the inner loop. For the innermost loop,
1039 /// the call back `getNewYields` is invoked to get the additional values to
1040 /// yield form the innermost loop.
1041 static LogicalResult addInitOperandsToLoopNest(
1043  ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
1044  if (loops.empty())
1045  return success();
1046  OpBuilder::InsertionGuard g(rewriter);
1047  rewriter.setInsertionPoint(loops.front());
1048 
1049  SmallVector<Value> ivs;
1050  for (auto &loop : loops.drop_back()) {
1051  rewriter.setInsertionPoint(loop);
1052 
1053  // if loops.size() > 1 we assume that scf.for is used for the loops.
1054  auto forLoop = cast<scf::ForOp>(loop.getOperation());
1055 
1056  // Create a new loop with the new init values for this loop.
1057  SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
1058  newInits.append(newInitValues.begin(), newInitValues.end());
1059  auto newLoop = scf::ForOp::create(
1060  rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
1061  forLoop.getUpperBound(), forLoop.getStep(), newInits,
1062  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {},
1063  forLoop.getUnsignedCmp());
1064 
1065  // Merge the body of the new loop with the body of the old loops.
1066  SmallVector<Value> sourceBlockArgs;
1067  sourceBlockArgs.push_back(newLoop.getInductionVar());
1068  auto newRegionIterArgs = newLoop.getRegionIterArgs();
1069  sourceBlockArgs.append(
1070  newRegionIterArgs.begin(),
1071  std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
1072  rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
1073  rewriter.replaceOp(
1074  forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
1075  loop = newLoop;
1076  ivs.push_back(newLoop.getInductionVar());
1077  newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
1078  }
1079 
1080  // Update the loop body of the innermost loop to get new yield values.
1081  LoopLikeOpInterface innerMostLoop = loops.back();
1082  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
1083  yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
1084  getNewTiledYieldsFn);
1085 
1086  if (failed(newInnerMostLoop))
1087  return innerMostLoop.emitOpError("failed to return additional yields");
1088  loops.back() = newInnerMostLoop.value();
1089 
1090  // Make all other loops except the innermost loops yield the values returned
1091  // by the inner loop.
1092  for (auto [outerLoop, innerLoop] :
1093  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1094  // Again assume that all the outer loops are scf.for operations.
1095  auto outerForLoop = cast<scf::ForOp>(outerLoop);
1096  auto outerLoopYield =
1097  cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
1098  SmallVector<Value> newYields =
1099  llvm::to_vector(outerLoopYield.getOperands());
1100  ValueRange additionalYields =
1101  innerLoop->getResults().take_back(newInitValues.size());
1102  newYields.append(additionalYields.begin(), additionalYields.end());
1103  rewriter.setInsertionPoint(outerLoopYield);
1104  rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
1105  }
1106  return success();
1107 }
1108 
1109 /// Implementation of tiling transformation of `op` that implements the
1110 /// `TilingInterface` using `scf.for` to iterate over the tiles.
1111 FailureOr<scf::SCFTilingResult>
1112 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
1113  const scf::SCFTilingOptions &options) {
1114  if (failed(verifyOptions(rewriter, op.getLoc(), options))) {
1115  return failure();
1116  }
1117 
1118  OpBuilder::InsertionGuard guard(rewriter);
1119  rewriter.setInsertionPointAfter(op);
1120 
1121  // 1. Get the range of the loops that are represented by the operation.
1122  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
1123 
1124  // 2. Materialize the tile sizes and/or number of threads;
1125  SmallVector<OpFoldResult> givenTileSizes, numThreads;
1126  std::tie(givenTileSizes, numThreads) =
1127  getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
1128 
1129  // Check if it is safe to tile. This is hold over from previous iterations
1130  // of tile to for-all. Consider dropping it.
1131  if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy,
1132  givenTileSizes, numThreads))) {
1133  return failure();
1134  }
1135 
1136  // Get the reduction dims
1137  SetVector<unsigned> reductionDims =
1138  getSanitizedReductionDims(givenTileSizes, options);
1139 
1140  // 3. If there is an interchange specified, permute the iteration domain and
1141  // the tile sizes.
1142  SmallVector<int64_t> interchangeVector;
1143  if (!options.interchangeVector.empty()) {
1144  interchangeVector = fillInterchangeVector(options.interchangeVector,
1145  iterationDomain.size());
1146  assert(isPermutationVector(interchangeVector) &&
1147  "expected interchange vector to be a permutation");
1148 
1149  applyPermutationToVector(iterationDomain, interchangeVector);
1150  applyPermutationToVector(givenTileSizes, interchangeVector);
1151  if (!numThreads.empty())
1152  applyPermutationToVector(numThreads, interchangeVector);
1153  }
1154 
1155  FailureOr<TilingResult> tilingResult;
1156  // 4. Define the lambda function used later to generate the body of the
1157  // innermost tiled loop.
1158  GenerateTiledBodyFn innerYieldTiledValuesFn =
1159  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
1160  ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes,
1161  ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
1162  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
1164  -> LogicalResult {
1165  // 4b. If interchange was provided, apply inverse of the interchange
1166  // to get back the offsets/sizes in the order to be specified.
1167  SmallVector<OpFoldResult> tileOffsetsVec = llvm::to_vector(tileOffsets);
1168  SmallVector<OpFoldResult> tileSizesVec = llvm::to_vector(tileSizes);
1169  if (!interchangeVector.empty()) {
1170  auto inversePermutation = invertPermutationVector(interchangeVector);
1173  }
1174 
1175  // 5. Generate the tiled implementation within the inner most loop.
1176 
1177  // 5a. Clone the operation within the loop body.
1178  auto clonedOp = cast<TilingInterface>(
1179  cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
1180 
1181  // 5b. Early return cloned op if tiling is not happening. We can not
1182  // return the original op because it could lead to `rewriter.replaceOp(op,
1183  // op->getResults())` and users would get crash.
1184  if (llvm::all_of(givenTileSizes, isZeroInteger)) {
1185  tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1186  tilingResult =
1187  TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
1188  /*generatedSlices=*/{}};
1189  return success();
1190  }
1191 
1192  // 5c. Tile the cloned operation.
1193  tilingResult =
1194  getTiledImplementation(rewriter, clonedOp, options.reductionStrategy,
1195  regionIterArgs, tileOffsetsVec, tileSizesVec,
1196  ivs, numThreads, givenTileSizes, reductionDims);
1197  if (failed(tilingResult)) {
1198  rewriter.eraseOp(clonedOp);
1199  return op.emitOpError("faild to tile operation");
1200  }
1201 
1202  // 5d. Delete the cloned operation.
1203  rewriter.eraseOp(clonedOp);
1204 
1205  // 5e. Compute the offsets at which the result values are to be inserted
1206  // back into its destinations.
1207  for (auto [index, tiledValue] :
1208  llvm::enumerate(tilingResult->tiledValues)) {
1209  tiledResults.push_back(tiledValue);
1210  SmallVector<OpFoldResult> resultOffset, resultSize;
1212  rewriter, options.reductionStrategy, index, tiledValue, op,
1213  tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes,
1214  reductionDims, resultOffset, resultSize))) {
1215  for (auto op : tilingResult->tiledOps) {
1216  rewriter.eraseOp(op);
1217  }
1218  return rewriter.notifyMatchFailure(
1219  op, "failed to get slice of result produced");
1220  }
1221  resultOffsets.emplace_back(std::move(resultOffset));
1222  resultSizes.emplace_back(std::move(resultSize));
1223  }
1224 
1225  return success();
1226  };
1227 
1228  // 6. Find the destination tensors to use for the operation.
1229  FailureOr<SmallVector<Value>> maybeInits = createInitialTensorsForTiling(
1230  rewriter, op, options.reductionStrategy, iterationDomain, numThreads,
1231  givenTileSizes, reductionDims);
1232  if (failed(maybeInits)) {
1233  return rewriter.notifyMatchFailure(
1234  op, "unable to create initial tensors for tiling");
1235  }
1236  SmallVector<Value> &initTensors = maybeInits.value();
1237 
1238  // 7. Generate the tiled loops nest using the callback defined above.
1240  {
1241  FailureOr<SmallVector<LoopLikeOpInterface>> loopsOr = generateLoopNest(
1242  rewriter, op.getLoc(), options, iterationDomain, givenTileSizes,
1243  numThreads, initTensors, innerYieldTiledValuesFn);
1244  if (failed(loopsOr))
1245  return op.emitOpError("failed to generate tiling loops");
1246  assert(succeeded(tilingResult) &&
1247  "expected tiling result to be computed after loop generation");
1248  std::swap(loops, loopsOr.value());
1249  }
1250 
1251  if (loops.empty()) {
1252  // If loops are empty, the tiled op is used as the replacement for the
1253  // untiled op.
1254  return scf::SCFTilingResult{tilingResult->tiledOps,
1255  initTensors,
1256  loops,
1257  tilingResult->tiledValues,
1258  tilingResult->generatedSlices,
1259  {}};
1260  }
1261 
1262  auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1263  [](OpResult r) -> Value { return r; });
1264 
1265  // For the full reduction case, there is nothing more to do.
1266  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
1267  return scf::SCFTilingResult{
1268  tilingResult->tiledOps, initTensors, loops, loopResults,
1269  tilingResult->generatedSlices, {}};
1270  }
1271 
1272  // The results of the loop needs to be merged.
1273  FailureOr<MergeResult> mergeResult = mergeTilingResults(
1274  rewriter, op, options.reductionStrategy, reductionDims, loopResults);
1275  if (failed(mergeResult)) {
1276  return rewriter.notifyMatchFailure(
1277  op, "Failed to merge partial results from tiling");
1278  }
1279  return scf::SCFTilingResult{tilingResult->tiledOps,
1280  initTensors,
1281  loops,
1282  mergeResult->replacements,
1283  tilingResult->generatedSlices,
1284  mergeResult->mergeOps};
1285 }
1286 
1287 FailureOr<scf::SCFTilingResult>
1289  PartialReductionOpInterface op,
1290  ArrayRef<OpFoldResult> tileSize) {
1293  options.setReductionTilingStrategy(
1295  options.setTileSizes(tileSize);
1296  SmallVector<unsigned> reductionDims;
1297  for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
1298  if (iteratorType == utils::IteratorType::reduction)
1299  reductionDims.push_back(index);
1300  options.setReductionDims(reductionDims);
1301  return tileUsingSCF(b, op, options);
1302 }
1303 
1304 //===----------------------------------------------------------------------===//
1305 // tileConsumerAndFuseProducersUsingSCF implementation.
1306 //===----------------------------------------------------------------------===//
1307 
1308 /// Return the untiled producer whose slice is used in a tiled consumer. The
1309 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1310 /// `iter_args` of the outer most that is encountered. Traversing the
1311 /// iter_args indicates that this is a destination operand of the consumer. If
1312 /// there was no loop traversal needed, the second value of the returned tuple
1313 /// is empty.
1314 static std::tuple<OpResult, std::optional<OpOperand *>>
1317  std::optional<OpOperand *> destinationIterArg;
1318  assert(!loops.empty() && "expected non empty loops container");
1319  auto loopIt = loops.rbegin();
1320  while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
1321  auto iterArg = cast<BlockArgument>(source->get());
1322  auto loop = *loopIt;
1323  if (iterArg.getOwner()->getParentOp() != loop)
1324  break;
1325  source = loop.getTiedLoopInit(iterArg);
1326  loopIt++;
1327  }
1328  if (loopIt == loops.rend())
1329  destinationIterArg = source;
1330  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1331 }
1332 
1333 /// Implementation of fusing producer of a single slice by computing the
1334 /// slice of the producer in-place.
1335 std::optional<scf::SCFFuseProducerOfSliceResult>
1337  RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1339  // 1. Get the producer of the source (potentially walking through
1340  // `iter_args` of nested `scf.for`)
1341  auto [fusableProducer, destinationInitArg] =
1342  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1343  loops);
1344  if (!fusableProducer)
1345  return std::nullopt;
1346  unsigned resultNumber = fusableProducer.getResultNumber();
1347 
1348  OpBuilder::InsertionGuard g(rewriter);
1349  rewriter.setInsertionPoint(candidateSliceOp);
1350 
1351  // 2. Clone the fused producer
1352  // 2a. Compute the destination operands to use for the cloned operation.
1353  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1354  Operation *fusableProducerOp = fusableProducer.getOwner();
1355  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1357  rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1358  origDestinationTensors)))
1359  return std::nullopt;
1360 
1361  clonedOpDestinationTensors = origDestinationTensors;
1362  if (destinationInitArg &&
1363  isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1364  // 2b. If the producer is also destination style, then to maintain the
1365  // destination passing style, update the destination of the producer to be
1366  // the source of the slice.
1367  clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1368  }
1369  // 2c. Clone the fused producer.
1370  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1371  rewriter, fusableProducerOp, clonedOpDestinationTensors);
1372  // 2d. Update the source of the candidateSlice to be the cloned producer.
1373  // Easier to just clone the slice with different source since
1374  // replacements and DCE of cloned ops becomes easier
1375  SmallVector<Value> candidateSliceOpOperands =
1376  llvm::to_vector(candidateSliceOp->getOperands());
1377  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1378  tensor::ExtractSliceOp clonedCandidateSliceOp =
1379  mlir::clone(rewriter, candidateSliceOp,
1380  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1381 
1382  // 3. Generate the tiled implementation of the producer of the source
1383  FailureOr<TilingResult> tileAndFuseResult =
1385  rewriter, clonedCandidateSliceOp,
1386  clonedProducerOp->getResult(resultNumber));
1387  if (failed(tileAndFuseResult))
1388  return std::nullopt;
1389  // Note: Do not delete the candidateSliceOp, since its passed in from the
1390  // caller.
1391  rewriter.replaceAllUsesWith(candidateSliceOp,
1392  tileAndFuseResult->tiledValues[0]);
1393  rewriter.eraseOp(clonedCandidateSliceOp);
1394  rewriter.eraseOp(clonedProducerOp);
1395 
1396  // 3. If the slice is for a destination operand, for example,
1397  //
1398  // ```mlir
1399  // %0 = linalg.init
1400  // %1 = linalg.fill .. outs(%0 : )
1401  // %2 = scf.for .. iter_args(%arg0 = %1) {
1402  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1403  // %4 = tensor.extract_slice %arg1 [..]
1404  // .. = linalg.matmul .. outs(%4 : )
1405  // }
1406  // }
1407  // ```
1408  //
1409  // the IR is currently
1410  //
1411  // ```
1412  // %0 = linalg.init
1413  // %1 = linalg.fill
1414  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1415  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1416  // %4 = tensor.extract_slice %arg1[..]
1417  // %5 = linalg.fill .. outs(%4 : )
1418  // .. = linalg.matmul .. outs(%5 : )
1419  // }
1420  // }
1421  // ```
1422  //
1423  // The untiled `linalg.fill` is still used as the `init_value` since it
1424  // was originally a destination operand of the untiled `linalg.matmul`.
1425  // When fusing an operand that is a destination operand, the iter_arg of
1426  // the outer most loop should be changed to use the destination of the
1427  // fused operation. With this the IR will be.
1428  //
1429  // ```
1430  // %0 = linalg.init
1431  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1432  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
1433  // %3 = tensor.extract_slice %arg1[..]
1434  // %4 = linalg.fill .. outs(%3 : )
1435  // .. = linalg.matmul .. outs(%4 : )
1436  // }
1437  // }
1438  // ```
1439  if (destinationInitArg &&
1440  isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1441  loops.front()
1442  ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1443  .set(origDestinationTensors[resultNumber]);
1444  }
1446  fusableProducer, tileAndFuseResult->tiledValues[0],
1447  tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1448 }
1449 
1450 /// Reconstruct the fused producer from within the tiled-and-fused code.
1451 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1452  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1453  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1455  ArrayRef<unsigned> yieldResultNumber) {
1456  if (loops.empty())
1457  return success();
1458 
1459  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1460  *tiledOwner = fusedProducerInfo.tiledOps[0];
1461 
1462  Location loc = originalOwner->getLoc();
1463  // a. collect all init Value to be appended
1464  SmallVector<unsigned> initNumberList =
1465  yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1466  0, originalOwner->getNumResults()))
1467  : llvm::to_vector(yieldResultNumber);
1468  SmallVector<Value> initValueList;
1469  for (const auto &resultNumber : initNumberList) {
1470  FailureOr<Value> initValue = tensor::getOrCreateDestination(
1471  rewriter, loc, originalOwner->getResult(resultNumber));
1472  if (succeeded(initValue)) {
1473  initValueList.push_back(initValue.value());
1474  } else {
1475  return failure();
1476  }
1477  }
1478 
1479  SmallVector<Operation *> generatedSlices;
1480  YieldTiledValuesFn newYieldValuesFn =
1481  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1482  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1484  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1485  OpBuilder::InsertionGuard g(innerRewriter);
1486 
1487  // get sliceOp tile information
1488  SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1489  sliceSizes = sliceOp.getMixedSizes();
1490 
1491  // expect all strides of sliceOp being 1
1492  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
1493  return failure();
1494 
1495  unsigned sliceResultNumber =
1496  fusedProducerInfo.origProducer.getResultNumber();
1497 
1498  auto tilableOp = cast<TilingInterface>(originalOwner);
1499  // b. get iterDomain Offset and Sizes based on sliceOp tile
1500  SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1501  // skip tensor.pack/unpack/pad, which expects single opResult
1502  if (tilableOp->getNumResults() > 1 &&
1503  failed(tilableOp.getIterationDomainTileFromResultTile(
1504  rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1505  iterDomainOffset, iterDomainSizes))) {
1506  // In theory, it is unnecessary to raise an error here. Actually
1507  // although it fails to reconstruct the result tensor, it should not
1508  // broke current fusion anyway. The reason why we must return failure
1509  // currently is that the callback function `newYieldValuesFn` will be
1510  // called after new init operand(s) has already been appended. It will
1511  // take more refactoring to make sure the init operands are added
1512  // consistently in the future. For more details, please refer to:
1513  // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1514  return failure();
1515  }
1516 
1517  // c. calculate offsets and sizes info of all OpResults respectively based
1518  // on iteration Domain Tile
1519  SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1520  for (const auto &resultNumber : initNumberList) {
1521  if (resultNumber == sliceResultNumber) {
1522  offsetList.push_back(sliceOffset);
1523  sizesList.push_back(sliceSizes);
1524  } else {
1525  assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1526  // infer result tile according to the iteration domain tile
1527  SmallVector<OpFoldResult> offset, sizes;
1528  if (failed(tilableOp.getResultTilePosition(
1529  rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1530  offset, sizes))) {
1531  return failure();
1532  }
1533  offsetList.push_back(offset);
1534  sizesList.push_back(sizes);
1535  }
1536  }
1537 
1538  // d. create `extract_slice` for `iter_args` for DPS operation if
1539  // necessary
1540  if (auto tiledDestStyleOp =
1541  dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1542  rewriter.setInsertionPoint(tiledDestStyleOp);
1543  for (const auto &&[index, newRegionArg] :
1544  llvm::enumerate(newRegionIterArgs)) {
1545  auto destSlice = tensor::ExtractSliceOp::create(
1546  rewriter, loc, newRegionArg, offsetList[index], sizesList[index],
1547  SmallVector<OpFoldResult>(offsetList[index].size(),
1548  rewriter.getIndexAttr(1)));
1549  generatedSlices.push_back(destSlice);
1550  unsigned resultNumber = initNumberList[index];
1551  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1552  tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1553  });
1554  }
1555  }
1556 
1557  // e. prepare tiled offset and sizes for later `insert_slice` creation by
1558  // caller
1559  Block *block = rewriter.getInsertionPoint()->getBlock();
1560  rewriter.setInsertionPoint(block->getTerminator());
1561  for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1562  tiledResult.push_back(tiledOwner->getResult(resultNumber));
1563  tiledOffset.emplace_back(offsetList[index]);
1564  tiledSizes.emplace_back(sizesList[index]);
1565  }
1566  return success();
1567  };
1568 
1569  if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1570  newYieldValuesFn))) {
1571  return failure();
1572  }
1573  return generatedSlices;
1574 }
1575 
1576 namespace {
1577 
1578 //===----------------------------------------------------------------------===//
1579 // SliceTrackingListener
1580 //===----------------------------------------------------------------------===//
1581 
1582 /// This class is a listener for tracking the insertion and removal of
1583 /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1584 /// fusion algorithm to apply cleanup patterns in between fusion steps.
1585 class SliceTrackingListener : public RewriterBase::Listener {
1586 public:
1587  explicit SliceTrackingListener(
1588  std::optional<FrozenRewritePatternSet> patterns);
1589  SliceTrackingListener() = default;
1590 
1591  /// Adds the given list of operations to the worklist, and if present,
1592  /// applies the list of `patterns` to the newly added operations. This only
1593  /// processes the given operations and any newly inserted ones by the
1594  /// pattern set.
1595  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1596 
1597  /// Add to the new operation worklist if it is an extract_slice.
1598  void notifyOperationInserted(Operation *op,
1599  OpBuilder::InsertPoint previous) override;
1600 
1601  /// Shared helper for operation removal from the worklist.
1602  void removeOp(Operation *op);
1603 
1604  /// Remove the operation from the worklist.
1605  void notifyOperationErased(Operation *op) override;
1606 
1607  /// Remove the operation from the worklist.
1608  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1609 
1610  /// The worklist for this transformation keeps track of the slices to visit
1611  /// next for fusion.
1612  std::deque<tensor::ExtractSliceOp> worklist;
1613 
1614 private:
1615  /// Optional pattern set to apply when adding new operations to the
1616  /// worklist.
1617  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1618 };
1619 
1620 SliceTrackingListener::SliceTrackingListener(
1621  std::optional<FrozenRewritePatternSet> p) {
1622  patterns = std::move(p);
1623 }
1624 
1625 LogicalResult
1626 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1627  for (Operation *op : ops) {
1628  if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1629  worklist.push_back(slice);
1630  }
1631 
1632  if (!patterns)
1633  return success();
1634 
1635  return applyOpPatternsGreedily(
1636  ops, patterns.value(),
1637  GreedyRewriteConfig().setListener(this).setStrictness(
1639 }
1640 
1641 void SliceTrackingListener::notifyOperationInserted(
1642  Operation *op, OpBuilder::InsertPoint previous) {
1643  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1644  if (!slice)
1645  return;
1646  worklist.push_back(slice);
1647 }
1648 
1649 // Scan the worklist for the given op and remove it if present. The
1650 // expectation is for the worklist to be small and for removal to be
1651 // relatively rare.
1652 void SliceTrackingListener::removeOp(Operation *op) {
1653  if (!isa<tensor::ExtractSliceOp>(op))
1654  return;
1655  auto iter = worklist.begin();
1656  while (iter != worklist.end()) {
1657  if (*iter == op)
1658  break;
1659  iter++;
1660  }
1661  if (iter == worklist.end())
1662  return;
1663 
1664  worklist.erase(iter);
1665 }
1666 
1667 void SliceTrackingListener::notifyOperationErased(Operation *op) {
1668  removeOp(op);
1669 }
1670 
1671 void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1672  ValueRange replacement) {
1673  removeOp(op);
1674 }
1675 
1676 //===----------------------------------------------------------------------===//
1677 // ReplacementListener
1678 //===----------------------------------------------------------------------===//
1679 
1680 /// Listener that tracks updates replacements for values which can be mutated.
1681 /// This listener runs on top of the existing listener for the rewriter,
1682 /// to make sure external users can still run listeners.
1683 class ReplacementListener : public RewriterBase::ForwardingListener {
1684 public:
1685  ReplacementListener(DenseMap<Value, Value> &replacements,
1686  OpBuilder::Listener *listener)
1687  : ForwardingListener(listener), replacements(replacements) {}
1688 
1689  void updateReplacementValues(ValueRange origValues,
1690  ValueRange replaceValues) {
1691  // This can probably be written better, but just iterates over the map
1692  // and the new replacements for now.
1693  for (auto &[key, val] : replacements) {
1694  for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1695  if (val == orig) {
1696  val = replace;
1697  }
1698  }
1699  }
1700  }
1701 
1702  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1703  ForwardingListener::notifyOperationReplaced(op, newOp);
1704  updateReplacementValues(op->getResults(), newOp->getResults());
1705  }
1706 
1707  void notifyOperationReplaced(Operation *op, ValueRange values) override {
1708  ForwardingListener::notifyOperationReplaced(op, values);
1709  updateReplacementValues(op->getResults(), values);
1710  }
1711 
1712 private:
1713  DenseMap<Value, Value> &replacements;
1714 };
1715 
1716 } // namespace
1717 
1718 /// Implementation of tile consumer and fuse producer greedily.
1719 FailureOr<scf::SCFTileAndFuseResult>
1721  RewriterBase &rewriter, TilingInterface consumer,
1723  // This transformation is only valid for ops that return values (i.e. not
1724  // valid to use with operations that have memref operands).
1725  if (!consumer->getNumResults()) {
1726  return rewriter.notifyMatchFailure(
1727  consumer, "invalid pattern for op with no results");
1728  }
1729 
1730  // 1. First tile the consumer.
1731  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1732 
1733  FailureOr<scf::SCFTilingResult> tilingResult =
1734  tileUsingSCF(rewriter, consumer, options.tilingOptions);
1735 
1736  if (failed(tilingResult))
1737  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1738  tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1739 
1740  DenseMap<Value, Value> replacements;
1741  for (auto [origVal, replacement] :
1742  llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1743  replacements[origVal] = replacement;
1744  }
1745 
1746  // If there are no loops generated, fusion is immaterial.
1747  auto &loops = tilingResult->loops;
1748  if (loops.empty()) {
1749  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1750  replacements};
1751  }
1752 
1753  // Since the loop gets potentially replaced during fusion, we need to track
1754  // the mutation of replacement values. To do this, we attach a listener to
1755  // update the replacements as they happen.
1756  OpBuilder::Listener *previousListener = rewriter.getListener();
1757  auto resetListener =
1758  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1759  ReplacementListener replaceListener(replacements, previousListener);
1760  rewriter.setListener(&replaceListener);
1761 
1762  // 2. Typically, the operands of the tiled operation are slices of the
1763  // operands of the untiled operation. These are expressed in IR using
1764  // `tensor.extract_slice` operations with source being the operands of
1765  // the untiled operation. Create a worklist of these
1766  // `tensor.extract_slice` operations. If the producers of the source of
1767  // the `tensor.extract_slice` can be tiled such that the tiled value is
1768  // generated in-place, that effectively tiles + fuses the operations.
1769  struct WorklistItem {
1770  tensor::ExtractSliceOp candidateSlice;
1772  };
1773 
1774  SliceTrackingListener sliceTracker =
1775  SliceTrackingListener(options.cleanupPatterns);
1776 
1777  if (failed(
1778  sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1779  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1780  }
1781  OpBuilder::InsertionGuard g(rewriter);
1782  while (!sliceTracker.worklist.empty()) {
1783  auto candidateSlice = sliceTracker.worklist.front();
1784  sliceTracker.worklist.pop_front();
1785 
1786  auto [fusableProducer, destinationInitArg] =
1787  getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1788  loops);
1789  if (!fusableProducer)
1790  continue;
1791 
1792  std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1793  options.fusionControlFn(candidateSlice, fusableProducer,
1794  destinationInitArg.has_value());
1795  if (!controlFnResult)
1796  continue;
1797 
1798  WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1799 
1800  // The operands of the fused producer might themselved be slices of
1801  // values produced by operations that implement the `TilingInterface`.
1802  // Add these operations to the worklist.
1803  std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1804  tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1805  loops);
1806  if (!fusedResult)
1807  continue;
1808 
1809  SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1810 
1811  if (worklistItem.controlFnResult.yieldProducerReplacement) {
1812  // Reconstruct and yield all opResult of fusableProducerOp by default.
1813  // The caller can specific which one to yield by designating optional
1814  // argument named `yieldResultNumber` of
1815  // `yieldReplacementForFusedProducer`.
1816  Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1817  FailureOr<SmallVector<Operation *>> newSlices =
1819  worklistItem.candidateSlice,
1820  fusedResult.value(), loops);
1821  if (failed(newSlices)) {
1822  return rewriter.notifyMatchFailure(
1823  fusableProducerOp, "failed to replacement value for this "
1824  "operation from within the tiled loop");
1825  }
1826  worklistCandidates.append(newSlices.value());
1827  for (auto [index, result] :
1828  llvm::enumerate(fusableProducerOp->getResults())) {
1829  replacements[result] = loops.front()->getResult(
1830  loops.front()->getNumResults() -
1831  fusableProducerOp->getNumResults() + index);
1832  }
1833  }
1834  if (Operation *tiledAndFusedOp =
1835  fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1836  fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1837  tiledAndFusedOps.insert(tiledAndFusedOp);
1838  }
1839 
1840  if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1841  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1842  }
1843  }
1844 
1845  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1846  replacements};
1847 }
1848 
1849 //===----------------------------------------------------------------------===//
1850 // tileAndFuseConsumerUsingSCF implementation.
1851 //===----------------------------------------------------------------------===//
1852 
1853 /// A utility function that checks whether the only use of the result of a
1854 /// tensor.insert_slice op is in a scf.yield op.
1855 static LogicalResult
1856 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1857  Value result = candidateSliceOp.getResult();
1858  Value::use_range uses = result.getUses();
1859  if (!llvm::hasSingleElement(uses)) {
1860  LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1861  return failure();
1862  }
1863  OpOperand &operandUse = (*uses.begin());
1864  Operation *userOp = operandUse.getOwner();
1865  if (!isa<scf::YieldOp>(userOp)) {
1866  LLVM_DEBUG(llvm::dbgs()
1867  << "Expected scf.yield to be the only user, but got -> "
1868  << (*userOp));
1869  return failure();
1870  }
1871  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1872  LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1873  "be in the same block\n");
1874  return failure();
1875  }
1876  return success();
1877 }
1878 
1879 /// An utility to get the first user of the given loopOp. If any of user stay
1880 /// in different block of loopOp, return failure.
1881 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1882  if (!isa<LoopLikeOpInterface>(loopOp))
1883  return failure();
1884  Operation *firstUserOfLoop = nullptr;
1885  for (Operation *userOp : loopOp->getUsers()) {
1886  // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1887  // block with any other types of operation. Thus, just redirecting to its
1888  // parent `InParallelOp`. E.g.
1889  //
1890  // ```
1891  // %1 = scf.for {
1892  // ...
1893  // }
1894  // %2 = consumerOp ins(%1, ...)
1895  // scf.forall.in_parallel {
1896  // tensor.parallel_insert_slice %1
1897  // }
1898  // ```
1899  // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
1900  // same block with `consumerOp`.
1901  if (isa<tensor::ParallelInsertSliceOp>(userOp))
1902  userOp = userOp->getParentOfType<scf::InParallelOp>();
1903 
1904  if (loopOp->getBlock() != userOp->getBlock())
1905  return failure();
1906 
1907  if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1908  firstUserOfLoop = userOp;
1909  }
1910  return firstUserOfLoop;
1911 }
1912 
1913 /// This utility currently checks whether the first userOp of loop is NOT
1914 /// before the last defineOp of consumer operand. Because that we need to move
1915 /// the whole loop structure right before the `firstUserOfLoop`. This utility
1916 /// thus helps ensuring that no invalid IR is formed, i.e. no backward slice
1917 /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
1918 ///
1919 /// ```
1920 /// %0 = scf.for() {
1921 /// ...
1922 /// }
1923 /// ...
1924 /// %1 = firstUserOfLoop(%0)
1925 /// ...
1926 /// %2 = lastDefOfConsumerOperand
1927 /// ...
1928 /// %3 = consumerOp(%2)
1929 /// ```
1930 ///
1931 /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
1932 /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
1933 /// a.k.a. use-def chain violation:
1934 ///
1935 /// ```
1936 /// %0:2 = scf.for() {
1937 /// // use before define error
1938 /// %3 = tiledConsumerOp(%2)
1939 /// }
1940 /// %1 = firstUserOfLoop(%0)
1941 /// ...
1942 /// %2 = lastDefOfConsumerOperand
1943 /// ```
1944 ///
1945 /// @param loopOp: loop operation
1946 /// @param consumerOp: consumer operation
1947 /// @param reorderOperations: the flag controls whether to reorder the
1948 /// backward slice w.r.t. the defineOp of `consumerOp` operands.
1949 /// @return: computed backward slice of consumerOp, but excluding those
1950 /// already dominates `firstUserOfLoop`.
1951 static FailureOr<llvm::SetVector<Operation *>>
1953  bool reorderOperations) {
1954  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1955  if (failed(firstUserOfLoop))
1956  return failure();
1957 
1959  DominanceInfo dominanceInfo;
1960  options.inclusive = true;
1961  options.omitBlockArguments = true;
1962  bool includeLoopOp = false;
1963  options.filter = [&](Operation *op) {
1964  if (op == loopOp) {
1965  includeLoopOp = true;
1966  return false;
1967  }
1968  // Cut off the slice to not include any operation that already dominates
1969  // firstUserOfLoop.
1970  return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1971  };
1973  for (auto operand : consumerOp->getOperands()) {
1974  LogicalResult result = getBackwardSlice(operand, &slice, options);
1975  assert(result.succeeded() && "expected a backward slice");
1976  (void)result;
1977  }
1978 
1979  if (!slice.empty()) {
1980  // If consumerOp has one producer, which is also the user of loopOp.
1981  // E.g.
1982  // ```
1983  // %0 = %loopOp
1984  // %1 = consumerOp1 ins(%0)
1985  // %2 = consumerOp2 ins(%0, %1)
1986  // ```
1987  // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1988  // consumerOp1 has already been fused into loopOp before.
1989  if (includeLoopOp || !reorderOperations)
1990  return failure();
1991  }
1992 
1993  return slice;
1994 }
1995 
1996 /// Fetches the OpOperand of the first valid user (and use) of the value `val`
1997 /// which implements `TilingInterface` and `DestinationStyleOpInterface`.
1998 /// Returns failure otherwise.
1999 static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
2000  Operation *loopOp,
2001  unsigned resultNumber) {
2002  if (!isa<LoopLikeOpInterface>(loopOp))
2003  return failure();
2004  Value val = loopOp->getResult(resultNumber);
2005  Block *loopBlock = loopOp->getBlock();
2006  for (OpOperand &opOperand : val.getUses()) {
2007  Operation *consumerOp = opOperand.getOwner();
2008  // Step 1. Check if the user is tilable.
2009  if (!isa<TilingInterface>(consumerOp) ||
2010  !isa<DestinationStyleOpInterface>(consumerOp)) {
2011  // TODO: We have to init result of consumer before scf.for, use
2012  // DestinationStyleOpInterface to get result shape from init for now.
2013  // Add support for other op such as op has InferTypeOpInterface.
2014  continue;
2015  }
2016  // Step 2. Check if user stay in the same block.
2017  if (loopBlock != consumerOp->getBlock())
2018  continue;
2019  // Step 3. Check if user has succeeding user. Otherwise, it usually
2020  // represents already tiled.
2021  if (consumerOp->use_empty())
2022  continue;
2023  // Step 4. Check assumption for loop with `reorderOperations` enabled.
2024  FailureOr<llvm::SetVector<Operation *>> slice =
2025  checkAssumptionForLoop(loopOp, consumerOp, true);
2026  if (failed(slice))
2027  continue;
2028  // Step 5. If backward sice is not empty, move them before
2029  // firstUserOfLoop.
2030  if (!slice->empty()) {
2031  mlir::topologicalSort(*slice);
2032  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
2033  assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
2034  for (auto op : *slice) {
2035  rewriter.moveOpBefore(op, *firstUserOfLoop);
2036  }
2037  }
2038  return &opOperand;
2039  }
2040  return failure();
2041 }
2042 
2043 /// Fetch the untiled consumer of the outermost scf.for's result which is
2044 /// yielded by a tensor.insert_slice from the innermost scf.for. This function
2045 /// makes the following assumptions :
2046 /// 1. tensor.insert_slice has scf.yield as its only user.
2047 /// 2. scf.for's corresponding result has only one use.
2048 /// 3. The `loops` passed in are perfectly nested `scf.for` operations.
2049 static FailureOr<OpOperand *>
2051  tensor::InsertSliceOp candidateSliceOp,
2053  assert(!loops.empty() && "unexpected loops to be empty");
2054  // 1. Expect slice to be part of the body of the inner most loop.
2055  Operation *containingOp = candidateSliceOp->getParentOp();
2056  if (containingOp != loops.back()) {
2057  return rewriter.notifyMatchFailure(
2058  candidateSliceOp,
2059  "expected slice to be within body of inner-most loop");
2060  }
2061 
2062  // 2. Check that the loop is perfectly nested.
2063  if (!isPerfectlyNestedForLoops(loops)) {
2064  return rewriter.notifyMatchFailure(
2065  candidateSliceOp, "expected passed loops to be perfectly nested.");
2066  }
2067 
2068  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
2069  return failure();
2070  Value sliceResult = candidateSliceOp.getResult();
2071 
2072  // 3. Fetch the corresponding output.
2073  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
2074  unsigned resultNumber = yieldOpOperand.getOperandNumber();
2075 
2076  scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
2077 
2078  return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
2079 }
2080 
2081 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
2082 /// by a tensor.parallel_insert_slice.
2083 static FailureOr<OpOperand *>
2085  tensor::ParallelInsertSliceOp candidateSliceOp,
2087  assert(!loops.empty() && "unexpected loops to be empty");
2088  // 1. Check that the surrounding loop is a single scf.forall loop.
2089  if (loops.size() != 1) {
2090  return rewriter.notifyMatchFailure(
2091  candidateSliceOp, "expected single surrounding scf.forall");
2092  }
2093  auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
2094  if (!forallOp) {
2095  return rewriter.notifyMatchFailure(
2096  candidateSliceOp, "expected single surrounding scf.forall");
2097  }
2098 
2099  // 2. Fetch the corresponding output
2100  Value sliceDest = candidateSliceOp.getDest();
2101  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
2102  if (!iterArg)
2103  return failure();
2104  if (iterArg.getOwner()->getParentOp() != forallOp)
2105  return failure();
2106 
2107  unsigned resultNumber =
2108  forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
2109  .getResultNumber();
2110 
2111  return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
2112 }
2113 
2114 /// A utility to fetch an untiled consumer of
2115 /// tensor.insert_slice/tensor.parallel_insert_slice.
2116 static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
2117  RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
2119  assert(!loops.empty() && "unexpected empty loops");
2120  assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
2121  SmallVector<OpOperand *> fusedOperands;
2122  for (auto sliceOp : sliceOps) {
2123  FailureOr<OpOperand *> fusedOperand =
2125  .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2126  [&](auto op) {
2127  return getUntiledConsumerFromSlice(rewriter, op, loops);
2128  })
2129  .Default([&](Operation *op) {
2130  return rewriter.notifyMatchFailure(op, "unhandled slice type");
2131  });
2132  if (failed(fusedOperand)) {
2133  return failure();
2134  }
2135  if (!fusedOperands.empty() &&
2136  fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
2137  return rewriter.notifyMatchFailure(
2138  fusedOperand.value()->getOwner(),
2139  "all candidate slices must be to the same consumer");
2140  }
2141  fusedOperands.push_back(fusedOperand.value());
2142  }
2143  return fusedOperands;
2144 }
2145 
2146 template <typename InsertSliceOpTy>
2147 static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
2148  InsertSliceOpTy sliceOp);
2149 
2150 template <>
2151 tensor::InsertSliceOp
2152 cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
2153  tensor::InsertSliceOp insertSliceOp) {
2154  return cast<tensor::InsertSliceOp>(
2155  rewriter.clone(*insertSliceOp.getOperation()));
2156 }
2157 
2158 template <>
2159 tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
2160  RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
2161  return tensor::InsertSliceOp::create(
2162  rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(),
2163  insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
2164  insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2165 }
2166 
2169  ArrayRef<Operation *> candidateSlices) {
2170  assert(!candidateSlices.empty() &&
2171  "unexpected empty list of slices to clone");
2173  for (auto sliceOp : candidateSlices) {
2174  TypeSwitch<Operation *>(sliceOp)
2175  .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2176  [&](auto op) {
2177  auto clonedOp = cloneAsInsertSlice(rewriter, op);
2178  clonedSlices.push_back(clonedOp);
2179  })
2180  .Default([&](Operation *op) {
2181  // Assert here assuming this has already been checked.
2182  assert(0 && "unexpected slice type while cloning as insert slice");
2183  });
2184  }
2185  return clonedSlices;
2186 }
2187 
2188 /// Implementation of fusing consumer of a single slice by computing the
2189 /// slice of the consumer in-place for scf loop.
2190 FailureOr<scf::SCFFuseConsumerOfSliceResult>
2192  RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2194  if (candidateSlices.empty()) {
2195  return rewriter.notifyMatchFailure(
2196  rewriter.getUnknownLoc(),
2197  "no candidate slices provided for consumer fusion");
2198  }
2199  // Return if `loops` is empty, return an error for now. Caller is expected
2200  // to handle this case.
2201  if (loops.empty()) {
2202  return rewriter.notifyMatchFailure(
2203  candidateSlices.front(),
2204  "cannot call tile and fuse consumer with an empty loop nest");
2205  }
2206 
2207  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2208  llvm::all_of(candidateSlices,
2209  llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2210  return rewriter.notifyMatchFailure(
2211  candidateSlices.front(),
2212  "candidates slices need to be all `tensor.extract_slice`s or "
2213  "`tensor.parallel_insert_slice`s");
2214  }
2215 
2216  // 1. Get the consumer of scf.for for the result yielded by
2217  // tensor.insert_slice/parallel_insert_slice.
2218  SmallVector<OpOperand *> consumerOpOperands;
2219  Operation *consumerOp;
2220  {
2221  FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2222  getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
2223  if (failed(maybeConsumerOpOperand)) {
2224  return rewriter.notifyMatchFailure(candidateSlices.front(),
2225  "could not fetch consumer to fuse");
2226  }
2227  std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
2228  consumerOp = consumerOpOperands.front()->getOwner();
2229  }
2230 
2231  LoopLikeOpInterface outerMostLoop = loops.front();
2232  LoopLikeOpInterface innerMostLoop = loops.back();
2233 
2234  // Check assumption for loop with `reorderOperations` disabled.
2235  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
2236  return rewriter.notifyMatchFailure(
2237  outerMostLoop, "the first user of loop should not dominate any define "
2238  "of consumer operand(s)");
2239  }
2240 
2241  OpBuilder::InsertionGuard g(rewriter);
2242 
2243  // 2. Check consumer is not using scf loop's output as init.
2244  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2245  if (!dstOp)
2246  return rewriter.notifyMatchFailure(consumerOp,
2247  "consumer op is not DPS operation");
2248  if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
2249  return dstOp.isDpsInit(opOperand);
2250  })) {
2251  return rewriter.notifyMatchFailure(
2252  consumerOp,
2253  "consumer op taking the result of scf.for as init is not supported");
2254  }
2255  SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
2256 
2257  // 3. Move the whole loop structure right before firstUserOfLoop, the
2258  // dominance should be already ensured by `checkAssumptionForLoop`.
2259  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
2260  if (failed(firstUserOfLoop)) {
2261  return rewriter.notifyMatchFailure(
2262  outerMostLoop, "could not find the first user of outer most loop");
2263  }
2264  rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2265 
2266  // 4. Set insertion point before terminator op of the loop and create a new
2267  // tensor.insert_slice. In the scf.for case this is a clone of the
2268  // candidateSliceOp whereas in the scf.forall case this is created from the
2269  // operands of tensor.parallel_insert_slice.
2270  if (auto sliceOp =
2271  dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
2272  auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2273  rewriter.setInsertionPoint(newForallOp.getTerminator());
2274  } else {
2275  rewriter.setInsertionPoint(candidateSlices.front());
2276  }
2277  // 5.a. Clone all the candidate slices as equivalent insert slice ops.
2278  SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
2279  cloneAsInsertSlices(rewriter, candidateSlices);
2280 
2281  // 5.b. Clone consumer op.
2282  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2283  SmallVector<unsigned> operandNumbers =
2284  llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
2285  return opOperand->getOperandNumber();
2286  });
2287  SmallVector<OpOperand *> clonedOpFusedOperandsList =
2288  llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2289  return &clonedConsumerOp->getOpOperand(operandNum);
2290  });
2291 
2292  // 5.c. Replace all uses of the loop result with the result of the cloned
2293  // tensor.insert_slice.
2294  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
2295  for (auto [operandToReplace, clonedSliceOp] :
2296  llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
2297  operandToReplace->set(clonedSliceOp.getResult());
2298  }
2299  });
2300 
2301  // 6. Perform tiling of the cloned consumer and replace the operand at
2302  // `operandNumber` with the source of the cloned tensor.insert_slice op.
2303  FailureOr<TilingResult> tileAndFuseResult =
2304  tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
2305  clonedOpFusedOperandsList);
2306  if (failed(tileAndFuseResult)) {
2307  return failure();
2308  }
2309 
2310  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2311  for (auto [operandNum, clonedSliceOp] :
2312  llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
2313  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
2314  clonedSliceOp.getSource());
2315  }
2316 
2317  // 7. Reconstruct [nested] loop with new inits.
2318  YieldTiledValuesFn newYieldValuesFn =
2319  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2320  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2322  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2323  OpBuilder::InsertionGuard g(innerRewriter);
2324  // 8. Set inner insertPoint right before tiled consumer op.
2325  innerRewriter.setInsertionPoint(tiledConsumerOp);
2326 
2327  SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
2328  for (auto candidateSliceOp : clonedInsertSlices) {
2329  SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
2330  SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
2331  SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
2332 
2333  // 9. Check all insert stride is 1.
2334  if (!llvm::all_of(strides, isOneInteger)) {
2335  return rewriter.notifyMatchFailure(
2336  candidateSliceOp, "containingOp's result yield with stride");
2337  }
2338 
2339  allOffsets.emplace_back(std::move(offsets));
2340  allSizes.emplace_back(std::move(sizes));
2341  }
2342 
2343  // 10. Try to get iter domain position from input position. Use
2344  // clonedConsumerOp instead of tiledConsumerOp, because the iteration
2345  // domain may require index computation based on the result size. The
2346  // sizes and offsets should be the same either way, but using
2347  // tiledConsumerOp could lead to some chained unnecessary extra index
2348  // computation.
2349  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2350  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
2351  rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
2352  iterDomainSizes))) {
2353  return rewriter.notifyMatchFailure(
2354  clonedConsumerOp,
2355  "can't get iter domain position from input position");
2356  }
2357 
2358  // 11. Try to fetch the offset and size for all results of the cloned
2359  // consumer. This would then be used to form the corresponding
2360  // tensor.insert_slice/parallel_insert_slice later.
2361  unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2363  totalNumResultsOfConsumer);
2365  totalNumResultsOfConsumer);
2366  for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2367  if (failed(tiledConsumerOp.getResultTilePosition(
2368  rewriter, idx, iterDomainOffsets, iterDomainSizes,
2369  resultOffsets[idx], resultSizes[idx]))) {
2370  return rewriter.notifyMatchFailure(
2371  tiledConsumerOp,
2372  "can't get result domain position from iter domain position");
2373  }
2374  }
2375 
2376  // 12. Create `extract_slice` for `iter_args` for DPS operation if
2377  // necessary.
2378  if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2379  tiledConsumerOp.getOperation())) {
2380  rewriter.setInsertionPoint(tiledDestStyleOp);
2381  for (const auto &&[index, newRegionArg] :
2382  llvm::enumerate(newRegionIterArgs)) {
2383  auto destSlice = tensor::ExtractSliceOp::create(
2384  rewriter, loc, newRegionArg, resultOffsets[index],
2385  resultSizes[index],
2386  SmallVector<OpFoldResult>(resultOffsets[index].size(),
2387  rewriter.getIndexAttr(1)));
2388  // Make a copy of index to avoid a capturing structured binding, which
2389  // is a C++20 extension.
2390  auto dstNumber = index;
2391  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2392  tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2393  });
2394  }
2395  }
2396 
2397  // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2398  // caller.
2399  Block *block = rewriter.getInsertionPoint()->getBlock();
2400  rewriter.setInsertionPoint(block->getTerminator());
2401  for (const auto &&[index, result] :
2402  llvm::enumerate(tiledConsumerOp->getResults())) {
2403  tiledResult.push_back(result);
2404  tiledOffset.emplace_back(resultOffsets[index]);
2405  tiledSizes.emplace_back(resultSizes[index]);
2406  }
2407  return success();
2408  };
2409  // 14. Add new inits to [nested] loops.
2410  if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
2411  newYieldValuesFn))) {
2412  return rewriter.notifyMatchFailure(tiledConsumerOp,
2413  "unable to add new inits to nest loop");
2414  }
2415 
2416  // 15. Replace the result of scf loop and consumer op with new loop's
2417  // results.
2418 
2419  for (auto &&[oldResult, newResult] :
2420  llvm::zip(consumerOp->getResults(),
2421  loops.front()->getResults().take_back(newInits.size()))) {
2422  rewriter.replaceAllUsesWith(oldResult, newResult);
2423  }
2424 
2425  // 16. Need to erase the old scf loop and the cloned consumer op.
2426  rewriter.eraseOp(clonedConsumerOp);
2427 
2428  SmallVector<OpOperand *> tiledAndFusedOpOperands =
2429  llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2430  return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
2431  });
2433  std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
2434  std::move(tileAndFuseResult->tiledOps)};
2435 }
2436 
2437 //===----------------------------------------------------------------------===//
2438 // lowerToLoopsUsingSCFForOp implementation.
2439 //===----------------------------------------------------------------------===//
2440 
2441 FailureOr<SmallVector<scf::ForOp>>
2443  TilingInterface op) {
2444  // TODO: Handle cases where the op has results if needed.
2445  if (op->getNumResults() > 0) {
2446  return rewriter.notifyMatchFailure(
2447  op, "unable to lower to loops operations with return values");
2448  }
2449 
2450  SmallVector<Range> domain = op.getIterationDomain(rewriter);
2451  SmallVector<Value> ivs;
2453  Location loc = op.getLoc();
2454  for (auto loopRange : domain) {
2455  Value offsetVal =
2456  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2457  Value sizeVal =
2458  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2459  Value strideVal =
2460  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2461  auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal,
2462  strideVal, ValueRange{});
2463  loops.push_back(loop);
2464  ivs.push_back(loop.getInductionVar());
2465  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2466  }
2467  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2468  return failure();
2469  }
2470  return loops;
2471 }
static llvm::ManagedStatic< PassManagerOptions > options
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * (numThreads - 1) is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizesWithForAllOp(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > givenTileSizes)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
static LogicalResult checkTileSizes(TilingInterface op, scf::SCFTilingOptions::LoopType loopType, ReductionTilingStrategy reductionStrategy, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.for operation.
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
Typedef for function that allows returning additional yielded values during yieldTiledValuesAndReplac...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes)
Function to return the bounds of the loops to be generated.
static SmallVector< OpFoldResult > getSplitReductionIvs(RewriterBase &rewriter, Location loc, ReductionTilingStrategy reductionStrategy, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
For the case of ReductionTilingStrategy::PartialReductionOuterParallel the PartialReductionOpInterfac...
std::function< LogicalResult(RewriterBase &rewriter, Location Loc, ValueRange ivs, ArrayRef< OpFoldResult > tileOffsets, ArrayRef< OpFoldResult > tileSizes, ValueRange outerDestinationTensors, SmallVector< Value > &tiledResults, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> GenerateTiledBodyFn
Typedef for function that implements the body of a tiled loop.
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using the loop construct specified in options.
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, const SetVector< unsigned > &reductionDims, ValueRange partialResults)
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingCustomOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange outerDestinationTensors, const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn, const scf::SCFTilingOptions::GenerateLoopTerminatorFn &generateLoopTerminatorFn, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using custom loop operation.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult givenTileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static SetVector< unsigned > getSanitizedReductionDims(ArrayRef< OpFoldResult > givenTileSizes, const scf::SCFTilingOptions &options)
Get the reduction dims that are tiled.
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< SmallVector< OpOperand * > > getUntiledConsumerOperandsFromSlices(RewriterBase &rewriter, ArrayRef< Operation * > sliceOps, MutableArrayRef< LoopLikeOpInterface > loops)
A utility to fetch an untiled consumer of tensor.insert_slice/tensor.parallel_insert_slice.
static SmallVector< tensor::InsertSliceOp > cloneAsInsertSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices)
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, InsertSliceOpTy sliceOp)
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims)
static FailureOr< SmallVector< LoopLikeOpInterface > > generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange outerDestinationTensors, GenerateTiledBodyFn tiledBodyFn)
Generate the tile-loop nest using scf.forall operation.
static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:959
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:24
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:323
This class allows control over how the GreedyPatternRewriteDriver works.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class represents a saved insertion point.
Definition: Builders.h:327
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:445
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:316
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:436
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:412
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:456
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1439
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1432
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
FailureOr< TilingResult > replaceInsertSlicesWithTiledConsumer(OpBuilder &builder, ArrayRef< tensor::InsertSliceOp > sliceOps, ArrayRef< OpOperand * > consumerOperands)
Method to swap tensor.insert_slices with their consumers when the consumer implements the TilingInter...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:79
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:114
Include the generated interface declarations.
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
Definition: Utils.cpp:1509
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:285
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:431
Container for result values of tiling.
Fuse the consumer candidateSlices by computing the required slice of the consumer in-place.
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Control function to check if a slice needs to be fused or not. The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange tiledResults, ArrayRef< SmallVector< OpFoldResult > > resultOffsets, ArrayRef< SmallVector< OpFoldResult > > resultSizes, ValueRange destinationTensors)> GenerateLoopTerminatorFn
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
LoopType
Specify which loop construct to use for tile and fuse.
std::function< FailureOr< CustomLoopHeaderInfo >(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > givenTileSizes, ValueRange destinationTensors)> GenerateLoopHeaderFn
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.