MLIR  20.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
33 #include <optional>
34 
35 #define DEBUG_TYPE "tile-using-interface"
36 
37 using namespace mlir;
38 
41  assert(!tileSizeComputationFunction && "tile sizes already set");
42  auto tileSizes = llvm::to_vector(ts);
43  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
44  return tileSizes;
45  };
46  return *this;
47 }
48 
51  assert(!numThreadsComputationFunction && "num tiles already set");
52  auto numThreads = llvm::to_vector(nt);
53  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
54  return numThreads;
55  };
56  return *this;
57 }
58 
59 /// Helper method to adjust the interchange vector to match the iteration
60 /// domain.
63  size_t iterationDomainSize) {
64  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
65  if (filledVector.size() < iterationDomainSize) {
66  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
67  filledVector.append(range.begin(), range.end());
68  }
69  if (filledVector.size() > iterationDomainSize)
70  filledVector.resize(iterationDomainSize);
71  return filledVector;
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // tileUsingSCF implementation.
76 //===----------------------------------------------------------------------===//
77 
78 /// Verify the tile size options are set in a consistent manner.
79 static LogicalResult
82  // Specifying number of threads is only supported on `scf.forall` op.
83  if (options.numThreadsComputationFunction &&
85  return rewriter.notifyMatchFailure(
86  loc, "number of threads can only by specified when loop type is "
87  "set to use `scf.forall`");
88  }
89 
90  // If specified, check that the interchange vector is a permutation.
91  if (!options.interchangeVector.empty()) {
92  if (!isPermutationVector(options.interchangeVector)) {
93  return rewriter.notifyMatchFailure(
94  loc, "invalid interchange vector, not a permutation of the entire "
95  "iteration space");
96  }
97  }
98  return success();
99 }
100 
101 /// Method to instantiate the tile sizes and/or number of threads specified
102 /// by the user.
103 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
104 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
105  ArrayRef<Range> iterationDomain,
107  OpFoldResult zero = rewriter.getIndexAttr(0);
108  SmallVector<OpFoldResult> tileSizes, numThreads;
109  size_t numLoops = iterationDomain.size();
110 
111  // Check whether the number of threads to use is specified.
112  if (options.numThreadsComputationFunction) {
113  numThreads = options.numThreadsComputationFunction(rewriter, op);
114  numThreads.resize(numLoops, zero);
115 
116  // If the tile sizes are also specified, use them as-is.
117  if (options.tileSizeComputationFunction) {
118  tileSizes = options.tileSizeComputationFunction(rewriter, op);
119  tileSizes.resize(numLoops, zero);
120  return {tileSizes, numThreads};
121  }
122 
123  // Compute the tile sizes from the iteration domain and number
124  // of tiles as follows
125  // - niters = ceilDiv(ub - lb, step)
126  // - tileSize = ceilDiv(niters, numThreads)
127  AffineExpr s0, s1, s2;
128  bindSymbols(rewriter.getContext(), s0, s1, s2);
129  // TODO: The step here is assumed to be 1.
130  AffineExpr numItersExpr = (s1 - s0);
131  AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
132  tileSizes.resize(numLoops, zero);
133  for (auto [index, range, nt] :
134  llvm::enumerate(iterationDomain, numThreads)) {
135  if (isConstantIntValue(nt, 0))
136  continue;
137 
138  tileSizes[index] = affine::makeComposedFoldedAffineApply(
139  rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
140  }
141  tileSizes.resize(numLoops, zero);
142  return {tileSizes, numThreads};
143  }
144 
145  // Enforce the convention that "tiling by zero"
146  // skips tiling a particular dimension. This convention is significantly
147  // simpler to handle instead of adjusting affine maps to account for missing
148  // dimensions.
149  assert(options.tileSizeComputationFunction &&
150  "expected tile sizes to be specified");
151  tileSizes = options.tileSizeComputationFunction(rewriter, op);
152  tileSizes.resize(numLoops, zero);
153 
154  return {tileSizes, numThreads};
155 }
156 
157 /// Checks if any of the tiled loops are not parallel.
158 static void checkSafeToTileToForall(TilingInterface op,
159  ArrayRef<OpFoldResult> tileSizes,
160  ArrayRef<OpFoldResult> numThreads) {
161  auto iterators = op.getLoopIteratorTypes();
162  assert(iterators.size() == tileSizes.size() &&
163  "expected as many tile size values as number of loops");
164  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
165  "when specified, expected number of threads to use for each loop");
166 
167  for (auto [index, iterator, tileSize] :
168  llvm::enumerate(iterators, tileSizes)) {
169  // If num threads is specified, check that it is greater than one only for
170  // parallel dimensions.
171  if (!numThreads.empty()) {
172  if (std::optional<int64_t> constNumThreads =
173  getConstantIntValue(numThreads[index])) {
174  if (constNumThreads.value() > 1 &&
175  iterator != utils::IteratorType::parallel) {
176  op.emitWarning() << "tiling is not thread safe at axis #" << index;
177  }
178  }
179  continue;
180  }
181 
182  if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
183  if (constTileSize.value() > 0 &&
184  iterator != utils::IteratorType::parallel) {
185  op.emitWarning() << "tiling is not thread safe at axis #" << index;
186  }
187  }
188  }
189 }
190 
191 /// Check if `stride` evenly divides the trip count `size - offset`.
192 static bool tileDividesIterationDomain(Range loopRange) {
193  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
194  if (!offsetAsInt)
195  return false;
196  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
197  if (!sizeAsInt)
198  return false;
199  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
200  if (!strideAsInt)
201  return false;
202  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
203 }
204 
205 /// Returns the bounded tile size given the current `offset`, `loopRange` and
206 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
208  Range loopRange, OpFoldResult offset,
209  OpFoldResult tileSize) {
210  std::optional<int64_t> ts = getConstantIntValue(tileSize);
211  if (ts && ts.value() == 1)
212  return tileSize;
213 
215  Range{loopRange.offset, loopRange.size, tileSize}))
216  return tileSize;
217 
218  // The tile size to use (to avoid out of bounds access) is minimum of
219  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
220  // loop.
221  AffineExpr s0, s1, d0;
222  bindDims(b.getContext(), d0);
223  bindSymbols(b.getContext(), s0, s1);
224  AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
225  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
227  b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
228 }
229 
230 /// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
231 /// statically known to be less than `iterationSize`.
233  OpFoldResult numThreads,
234  OpFoldResult iterationSize) {
235  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
236  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
237  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
238  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
239  return false;
240  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
241 }
242 
243 /// Compute the `OpFoldResult`s that represents the multi-dimensional
244 /// `offset`s and `size`s of the tile of the iteration space that the
245 /// innermost loop body of the generated tiled loops corresponds to.
246 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
248  ArrayRef<Range> iterationDomain,
249  ArrayRef<OpFoldResult> tileSizes,
250  ArrayRef<OpFoldResult> numThreads) {
251  SmallVector<OpFoldResult> offsets, sizes;
252  int materializedLoopNum = 0;
253 
254  if (!numThreads.empty()) {
255  AffineExpr d0, d1, s0, s1;
256  AffineExpr offsetExpr, residualTileSizeExpr;
257  bindDims(rewriter.getContext(), d0, d1);
258  bindSymbols(rewriter.getContext(), s0, s1);
259  offsetExpr = d0 + d1 * s0;
260  residualTileSizeExpr = s1 - (d0 + d1 * s0);
261 
262  for (auto [nt, tileSize, loopRange] :
263  llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
264 
265  // Non-tiled cases, set the offset and size to the
266  // `loopRange.offset/size`.
267  if (isConstantIntValue(nt, 0)) {
268  offsets.push_back(loopRange.offset);
269  sizes.push_back(loopRange.size);
270  continue;
271  }
272 
273  Value iv = ivs[materializedLoopNum++];
275  rewriter, loc, offsetExpr,
276  ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
278  rewriter, loc, residualTileSizeExpr,
279  {loopRange.offset, nt, tileSize, loopRange.size});
280 
281  OpFoldResult size = tileSize;
282  if (!isConstantIntValue(residualTileSize, 0)) {
283  OpFoldResult sizeMinusOffsetPerThread =
284  affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
285  {offset, loopRange.size});
287  rewriter, loc,
289  {sizeMinusOffsetPerThread, tileSize});
290  }
291 
292  // Consider the case where the original loop was `[0, 10)`.
293  // If number of threads are `7`, the tile size would be computed as
294  // `ceilDiv(10, 7) = 2`. For the last thread (thread_id = 6)
295  // - `offset = 0 + 6 * 2 = 12`
296  // - `tileSize = min(2, 10 - 12) = -2`
297  // To avoid negative tile sizes, we need to do a further
298  // `nonNegativeTileSize = affine.max(0, tileSize)`.
299  // This `max` can be avoided if
300  // `offset + tileSize * (numThreads - 1) < (ub - lb)`
301  if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
302  AffineMap maxMap =
305  rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
306  }
307 
308  offsets.push_back(offset);
309  sizes.push_back(size);
310  }
311  return {offsets, sizes};
312  } else {
313  for (auto [tileSize, loopRange] :
314  llvm::zip_equal(tileSizes, iterationDomain)) {
315 
316  // Non-tiled cases, set the offset and size to the
317  // `loopRange.offset/size`.
318  if (isConstantIntValue(tileSize, 0)) {
319  offsets.push_back(loopRange.offset);
320  sizes.push_back(loopRange.size);
321  continue;
322  }
323 
324  Value iv = ivs[materializedLoopNum++];
325  OpFoldResult offset = getAsOpFoldResult(iv);
326  offsets.push_back(offset);
327  OpFoldResult size =
328  getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
329  sizes.push_back(size);
330  }
331  return {offsets, sizes};
332  }
333 }
334 
335 /// Function to return the bounds of the loops to be generated.
336 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
339  ArrayRef<OpFoldResult> tileSizes) {
340  SmallVector<OpFoldResult> lbs, ubs, steps;
341  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
342  // No loop if the tile size is 0.
343  if (isConstantIntValue(tileSize, 0))
344  continue;
345  lbs.push_back(loopRange.offset);
346  ubs.push_back(loopRange.size);
347  steps.push_back(tileSize);
348  }
349  return {lbs, ubs, steps};
350 }
351 
352 /// Callback that computes the additional tiled values to yield from a loop,
353 /// e.g. as invoked by `yieldTiledValuesAndReplaceLoop`.
354 /// - `ivs`: induction variable(s) of the loop(s) surrounding the tiled body.
355 /// - `newBbArgs`: basic block arguments corresponding to newly added iter_args.
356 /// - `tiledValues`: the tiled values to return. Must be of the same size as
357 /// `newBbArgs`; each element of this array is inserted into the corresponding
358 /// element in `newBbArgs`.
359 /// - `resultOffsets`: is of the same size as `tiledValues` and represents
360 /// the offsets to use when inserting corresponding element from `tiledValues`
361 /// into the element from `newBbArgs`.
362 /// - `resultSizes`: is of the same size as `tiledValues` and represents
363 /// the size of the corresponding element from `tiledValues` inserted into
364 /// the element from `newBbArgs`.
365 /// In case the callback needs to return `failure()`, it is expected
366 /// to clean up any operations it inserted.
367 using YieldTiledValuesFn = std::function<LogicalResult(
368  RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
369  SmallVector<Value> &tiledValues,
370  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
371  SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
372 
373 /// Clones the operation and updates the destination if the operation
374 /// implements the `DestinationStyleOpInterface`.
376  Operation *op,
377  ValueRange newDestArgs) {
378  Operation *clonedOp = rewriter.clone(*op);
379  if (newDestArgs.empty())
380  return clonedOp;
381  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
382  destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
383  return clonedOp;
384 }
385 
386 /// Generate the tile-loop nest using `scf.for` operations.
387 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
388 /// - `tileSizes` are the tile sizes to use. Zero represents untiled loops.
389 /// - `destinationTensors` are the init values to use for the outermost loop.
390 /// - `yieldTiledValuesFn` is called to generate the loop body of the
391 /// innermost
392 /// loop.
393 /// - `loops` is an in-out parameter into which the generated loops are
394 /// populated.
395 static LogicalResult generateLoopNestUsingForOp(
396  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
397  ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
398  YieldTiledValuesFn yieldTiledValuesFn,
400  assert(!loopRanges.empty() && "unexpected empty loop ranges");
401  assert(loopRanges.size() == tileSizes.size() &&
402  "expected as many tile sizes as loop ranges");
403  OpBuilder::InsertionGuard guard(rewriter);
404 
405  SmallVector<OpFoldResult> lbs, ubs, steps;
406  std::tie(lbs, ubs, steps) =
407  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
408  SmallVector<Value> lbVals =
409  getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
410  SmallVector<Value> ubVals =
411  getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
412  SmallVector<Value> stepVals =
413  getValueOrCreateConstantIndexOp(rewriter, loc, steps);
414 
415  SmallVector<Value> ivs;
416  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
417  auto loop =
418  rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
419  [](OpBuilder &bodyBuilder, Location bodyLoc,
420  Value iv, ValueRange /*iterArgs*/) {});
421  loops.push_back(loop);
422  ivs.push_back(loop.getInductionVar());
423  rewriter.setInsertionPointToEnd(loop.getBody());
424  destinationTensors = loop.getRegionIterArgs();
425  }
426 
427  SmallVector<Value> tiledResults;
428  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
429  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
430  tiledResults, resultOffsets, resultSizes))) {
431  return rewriter.notifyMatchFailure(
432  loc, "failed to generate inner tile loop body");
433  }
434  if (loops.empty())
435  return success();
436 
437  assert(tiledResults.size() == destinationTensors.size() &&
438  "Number of results of body should be equal to number of iter args");
439 
440  // Yield all the results of the tiled operation.
441  SmallVector<Value> yieldedValues;
442  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
443  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
444  resultSizes)) {
445  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
446  rewriter.getIndexAttr(1));
447  auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
448  loc, tiledValue, destinationTensor, resultOffset, resultSize,
449  resultStride);
450  yieldedValues.push_back(insertSlice);
451  }
452  rewriter.create<scf::YieldOp>(loc, yieldedValues);
453 
454  // Add the scf.yield operations for all the outer loops.
455  for (auto [outerLoop, innerLoop] :
456  llvm::zip_equal(MutableArrayRef(loops).drop_back(),
457  MutableArrayRef(loops).drop_front())) {
458  rewriter.setInsertionPointToEnd(
459  cast<scf::ForOp>(outerLoop.getOperation()).getBody());
460  rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
461  }
462  return success();
463 }
464 
465 /// Generate the tile-loop nest using `scf.forall` operation.
466 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
467 /// - `tileSizes` are the tile sizes to use. Zero represents untiled loops.
468 /// - `destinationTensors` are the init values to use for the outermost loop.
469 /// - `mappingVector` is the mapping attributes to use for loop construction.
470 /// Can be empty.
471 /// - `tiledBodyFn` is called to generate the loop body of the
472 /// innermost
473 /// loop.
474 /// - `loops` is an in-out parameter into which the generated loops are
475 /// populated.
476 static LogicalResult generateLoopNestUsingForallOp(
477  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
478  ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
479  ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
481  assert(!loopRanges.empty() && "unexpected empty loop ranges");
482  assert(loopRanges.size() == tileSizes.size() &&
483  "expected as many tile sizes as loop ranges");
484  OpBuilder::InsertionGuard guard(rewriter);
485  SmallVector<OpFoldResult> offsets(loopRanges.size()),
486  sizes(loopRanges.size());
487 
488  std::optional<ArrayAttr> mappingAttr;
489  if (!mappingVector.empty())
490  mappingAttr = rewriter.getArrayAttr(mappingVector);
491 
492  scf::ForallOp forallOp;
493  bool useNumThreads = !numThreads.empty();
494 
495  if (useNumThreads) {
496  // Prune the zero numthreads.
497  SmallVector<OpFoldResult> nonZeroNumThreads;
498  for (auto nt : numThreads) {
499  if (isConstantIntValue(nt, 0))
500  continue;
501  nonZeroNumThreads.push_back(nt);
502  }
503  forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
504  destinationTensors, mappingAttr);
505  } else {
506  SmallVector<OpFoldResult> lbs, ubs, steps;
507  std::tie(lbs, ubs, steps) =
508  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
509  forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
510  destinationTensors, mappingAttr);
511  }
512  loops.push_back(forallOp);
513 
514  rewriter.setInsertionPoint(forallOp.getTerminator());
515  destinationTensors = forallOp.getRegionOutArgs();
516 
517  SmallVector<Value> tiledResults;
518  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
519  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
520  destinationTensors, tiledResults, resultOffsets,
521  resultSizes)))
522  return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
523 
524  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
525  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
526  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
527  resultSizes)) {
528  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
529  rewriter.getIndexAttr(1));
530 
531  rewriter.create<tensor::ParallelInsertSliceOp>(
532  loc, tiledValue, destinationTensor, resultOffset, resultSize,
533  resultStride);
534  }
535  return success();
536 }
537 
538 /// Generate the tile-loop nest using the loop construct specified in `options`.
539 /// - `options`: Tiling options specified.
540 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
541 /// - `tileSizes` are the tile sizes to use. Zero represents untiled loops.
542 /// - `destinationTensors` are the init values to use for the outermost loop.
543 /// - `tiledBodyFn` is called to generate the loop body of the
544 /// innermost
545 /// loop.
546 /// - `loops` is an in-out parameter into which the generated loops are
547 /// populated.
548 static LogicalResult generateLoopNest(
549  RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
550  ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
551  ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
553  // If the tile sizes are all zero, no loops are generated. Just call the
554  // callback function to handle untiled case.
555  if (llvm::all_of(tileSizes, isZeroIndex)) {
556  SmallVector<Value> tiledResults;
557  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
558  return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
559  tiledResults, resultOffsets, resultSizes);
560  }
562  return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
563  destinationTensors, tiledBodyFn, loops);
564  }
567  rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
568  destinationTensors, tiledBodyFn, loops);
569  }
570  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
571 }
572 
573 /// Append the specified additional `newInitOperands` operands to the
574 /// loops existing `init` operands (or similar), and replace `loopOp` with
575 /// the new loop that has the additional init operands. The loop body of
576 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
577 /// is called to get the new tiled values returned, and the offset
578 /// and sizes at which the tiled value is inserted into the
579 /// new region iter_args that correspond to the newly added init operands.
580 template <typename LoopType>
581 FailureOr<LoopLikeOpInterface>
582 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
583  ValueRange newInitOperands,
584  YieldTiledValuesFn yieldTiledValuesFn) {
585  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
586 }
587 
588 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
589 template <>
590 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
591  scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
592  YieldTiledValuesFn yieldTiledValuesFn) {
593  OpBuilder::InsertionGuard g(rewriter);
594  Location loc = loopOp.getLoc();
595  rewriter.setInsertionPoint(loopOp);
596 
597  auto inits = llvm::to_vector(loopOp.getInitArgs());
598  inits.append(newInitOperands.begin(), newInitOperands.end());
599  auto newLoop = rewriter.create<scf::ForOp>(
600  loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
601  inits, [](OpBuilder &, Location, Value, ValueRange) {});
602 
603  // Move the loop body to the new op.
604  Block *loopBody = loopOp.getBody();
605  Block *newLoopBody = newLoop.getBody();
606  rewriter.mergeBlocks(
607  loopBody, newLoopBody,
608  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
609 
610  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
611  rewriter.setInsertionPoint(yieldOp);
612 
613  SmallVector<Value> tiledValues;
614  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
615  ValueRange newRegionIterArgs =
616  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
617  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
618  newRegionIterArgs, tiledValues, resultOffsets,
619  resultSizes))) {
620  rewriter.eraseOp(newLoop);
621  return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
622  }
623 
624  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
625  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
626  llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
627  resultSizes)) {
628  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
629  rewriter.getIndexAttr(1));
630  Value insert = rewriter.create<tensor::InsertSliceOp>(
631  yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
632  resultStride);
633  newYieldValues.push_back(insert);
634  }
635 
636  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
637  rewriter.replaceOp(loopOp,
638  newLoop->getResults().take_front(loopOp.getNumResults()));
639  return cast<LoopLikeOpInterface>(newLoop.getOperation());
640 }
641 
642 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
643 template <>
644 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
645  scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
646  YieldTiledValuesFn yieldTiledValuesFn) {
647  OpBuilder::InsertionGuard g(rewriter);
648  Location loc = loopOp.getLoc();
649  rewriter.setInsertionPoint(loopOp);
650  auto inits = llvm::to_vector(loopOp.getOutputs());
651  inits.append(newInitOperands.begin(), newInitOperands.end());
652  auto newLoop = rewriter.create<scf::ForallOp>(
653  loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
654  loopOp.getMixedStep(), inits, loopOp.getMapping(),
655  [](OpBuilder &, Location, ValueRange) {});
656 
657  // Move the region of the current block to the newly created op.
658  Block *loopBody = loopOp.getBody();
659  Block *newLoopBody = newLoop.getBody();
660  rewriter.mergeBlocks(
661  loopBody, newLoopBody,
662  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
663 
664  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
665  rewriter.setInsertionPoint(terminator);
666  SmallVector<Value> tiledValues;
667  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
668  ValueRange regionIterArgs =
669  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
670  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
671  regionIterArgs, tiledValues, resultOffsets,
672  resultSizes))) {
673  rewriter.eraseOp(newLoop);
674  return rewriter.notifyMatchFailure(loopOp,
675  "failed to get yielded tiled values");
676  }
677 
678  // Update the terminator.
679  rewriter.setInsertionPointToEnd(terminator.getBody());
680 
681  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
682  tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
683  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
684  rewriter.getIndexAttr(1));
685  rewriter.create<tensor::ParallelInsertSliceOp>(
686  terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
687  resultStride);
688  }
689 
690  rewriter.replaceOp(loopOp,
691  newLoop->getResults().take_front(loopOp.getNumResults()));
692  return cast<LoopLikeOpInterface>(newLoop.getOperation());
693 }
694 
695 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
696 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
697 /// supported loop type.
698 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
699  LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
700  ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
702  loopLikeOp.getOperation())
703  .Case<scf::ForOp, scf::ForallOp>(
704  [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
706  loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
707  })
708  .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
709  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
710  });
711 }
712 
713 /// Method to add new init values to a loop nest. Updates `loops` in-place with
714 /// new loops that use the `newInitValues`.
715 /// The outer-loops are updated to yield the new result values of the inner
716 /// loop. For the innermost loop, the callback `getNewTiledYieldsFn` is invoked
717 /// to get the additional values to yield from the innermost loop.
718 static LogicalResult addInitOperandsToLoopNest(
720  ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
721  SmallVector<scf::ForOp> newLoops;
722  if (loops.empty())
723  return success();
724  OpBuilder::InsertionGuard g(rewriter);
725  rewriter.setInsertionPoint(loops.front());
726 
727  SmallVector<Value> ivs;
728  for (auto &loop : loops.drop_back()) {
729  rewriter.setInsertionPoint(loop);
730 
731  // if loops.size() > 1 we assume that scf.for is used for the loops.
732  auto forLoop = cast<scf::ForOp>(loop.getOperation());
733 
734  // Create a new loop with the new init values for this loop.
735  SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
736  newInits.append(newInitValues.begin(), newInitValues.end());
737  auto newLoop = rewriter.create<scf::ForOp>(
738  forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
739  forLoop.getStep(), newInits,
740  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
741 
742  // Merge the body of the new loop with the body of the old loops.
743  SmallVector<Value> sourceBlockArgs;
744  sourceBlockArgs.push_back(newLoop.getInductionVar());
745  auto newRegionIterArgs = newLoop.getRegionIterArgs();
746  sourceBlockArgs.append(
747  newRegionIterArgs.begin(),
748  std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
749  rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
750  rewriter.replaceOp(
751  forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
752  loop = newLoop;
753  ivs.push_back(newLoop.getInductionVar());
754  newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
755  }
756 
757  // Update the loop body of the innermost loop to get new yield values.
758  LoopLikeOpInterface innerMostLoop = loops.back();
759  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
760  yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
761  getNewTiledYieldsFn);
762 
763  if (failed(newInnerMostLoop))
764  return innerMostLoop.emitOpError("failed to return additional yields");
765  loops.back() = newInnerMostLoop.value();
766 
767  // Make all other loops except the innermost loops yield the values returned
768  // by the inner loop.
769  for (auto [outerLoop, innerLoop] :
770  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
771  // Again assume that all the outer loops are scf.for operations.
772  auto outerForLoop = cast<scf::ForOp>(outerLoop);
773  auto outerLoopYield =
774  cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
775  SmallVector<Value> newYields =
776  llvm::to_vector(outerLoopYield.getOperands());
777  ValueRange additionalYields =
778  innerLoop->getResults().take_back(newInitValues.size());
779  newYields.append(additionalYields.begin(), additionalYields.end());
780  rewriter.setInsertionPoint(outerLoopYield);
781  rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
782  }
783  return success();
784 }
785 
786 /// Implementation of tiling transformation of `op` that implements the
787 /// `TilingInterface` using `scf.for` to iterate over the tiles.
///
/// Outline: verify the options, query the iteration domain of `op`,
/// materialize the tile sizes / number of threads, optionally apply the
/// interchange permutation, and then generate the loop nest whose innermost
/// body is produced by the `innerYieldTiledValuesFn` callback defined below.
///
/// Returns failure if the options do not verify, if destination tensors
/// cannot be created, or if loop nest generation fails. On success the
/// replacements are the results of the outermost generated loop, or the tiled
/// values themselves when no loops were generated.
///
/// NOTE(review): this listing appears to have dropped some source lines (the
/// embedded numbering skips 790, 808, 833-834, 845-846 and 906), among them
/// the parameter line declaring `options`. Confirm against the upstream file
/// before editing the code.
788 FailureOr<scf::SCFTilingResult>
789 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
791  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
792  return failure();
793  }
794 
 // The guard restores the caller's insertion point on every return path.
795  OpBuilder::InsertionGuard guard(rewriter);
796  rewriter.setInsertionPointAfter(op);
797 
798  // 1. Get the range of the loops that are represented by the operation.
799  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
800 
801  // 2. Materialize the tile sizes and/or number of threads.
802  SmallVector<OpFoldResult> tileSizes, numThreads;
803  std::tie(tileSizes, numThreads) =
804  getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
805 
806  // Check if it is safe to tile. This is a holdover from previous iterations
807  // of tile to for-all. Consider dropping it.
 // NOTE(review): the line opening this conditional (808) is missing from the
 // listing; only its body and closing brace are visible.
809  checkSafeToTileToForall(op, tileSizes, numThreads);
810  }
811 
812  // 3. If there is an interchange specified, permute the iteration domain and
813  // the tile sizes.
814  SmallVector<int64_t> interchangeVector;
815  if (!options.interchangeVector.empty()) {
816  interchangeVector = fillInterchangeVector(options.interchangeVector,
817  iterationDomain.size());
818  assert(isPermutationVector(interchangeVector) &&
819  "expected interchange vector to be a permutation");
820 
821  applyPermutationToVector(iterationDomain, interchangeVector);
822  applyPermutationToVector(tileSizes, interchangeVector);
823  if (!numThreads.empty())
824  applyPermutationToVector(numThreads, interchangeVector);
825  }
826 
 // Written by the callback below and read after loop generation.
827  FailureOr<TilingResult> tilingResult;
828  // 4. Define the lambda function used later to generate the body of the
829  // innermost tiled loop.
 // NOTE(review): the trailing lambda parameters (lines 833-834) are missing
 // from the listing; the body uses `resultOffsets` and `resultSizes`, which
 // are presumably declared there.
830  YieldTiledValuesFn innerYieldTiledValuesFn =
831  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
832  ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
835  -> LogicalResult {
836  // 4a. Compute the `offsets` and `sizes` to use for tiling.
837  SmallVector<OpFoldResult> offsets, sizes;
838  std::tie(offsets, sizes) = getTileOffsetAndSizes(
839  rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
840 
841  // 4b. If interchange was provided, apply inverse of the interchange
842  // to get back the offsets/sizes in the order to be specified.
 // NOTE(review): the statements that apply `inversePermutation` (lines
 // 845-846) are missing from the listing.
843  if (!interchangeVector.empty()) {
844  auto inversePermutation = invertPermutationVector(interchangeVector);
847  }
848 
849  // 5. Generate the tiled implementation within the inner most loop.
850 
851  // 5a. Clone the operation within the loop body.
852  auto clonedOp = cast<TilingInterface>(
853  cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
854 
855  // 5b. Early return cloned op if tiling is not happening. We can not return
856  // the original op because it could lead to
857  // `rewriter.replaceOp(op, op->getResults())` and users would get a crash.
858  if (llvm::all_of(tileSizes, isZeroIndex)) {
859  tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
860  tilingResult =
861  TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
862  /*generatedSlices=*/{}};
863  return success();
864  }
865 
866  // 5c. Tile the cloned operation.
867  tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
868  if (failed(tilingResult)) {
869  rewriter.eraseOp(clonedOp);
870  return op.emitOpError("faild to tile operation");
871  }
872 
873  // 5d. Delete the cloned operation.
874  rewriter.eraseOp(clonedOp);
875 
876  // 5e. Compute the offsets at which the result values are to be inserted
877  // back into its destinations.
878  for (auto [index, tiledValue] :
879  llvm::enumerate(tilingResult->tiledValues)) {
880  tiledResults.push_back(tiledValue);
881  SmallVector<OpFoldResult> resultOffset, resultSize;
882  if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
883  resultOffset, resultSize))) {
 // On failure, erase the tiled ops created in 5c before bailing out.
 // Note that the loop variable `op` shadows the outer `op` here.
884  for (auto op : tilingResult->tiledOps) {
885  rewriter.eraseOp(op);
886  }
887  return rewriter.notifyMatchFailure(
888  op, "failed to get slice of result produced");
889  }
890  resultOffsets.emplace_back(std::move(resultOffset));
891  resultSizes.emplace_back(std::move(resultSize));
892  }
893 
894  return success();
895  };
896 
897  // 6. Find the destination tensors to use for the operation.
898  SmallVector<Value> destinationTensors;
899  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
900  destinationTensors))) {
901  return rewriter.notifyMatchFailure(op,
902  "unable to create destination tensors");
903  }
904 
905  // 7. Generate the tiled loops nest using the callback defined above.
 // NOTE(review): the declaration of `loops` (line 906) is missing from the
 // listing.
907  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
908  tileSizes, numThreads, destinationTensors,
909  innerYieldTiledValuesFn, loops)))
910  return op.emitOpError("failed to generate tiling loops");
911  assert(succeeded(tilingResult) &&
912  "expected tiling result to be computed after loop generation");
913 
914  // If loops are empty, the tiled op is used as the replacement for the untiled
915  // op.
916  if (loops.empty()) {
917  return scf::SCFTilingResult{tilingResult->tiledOps, loops,
918  tilingResult->tiledValues,
919  tilingResult->generatedSlices};
920  }
921 
 // Otherwise the results of the outermost loop replace the untiled op.
922  SmallVector<Value> replacements = llvm::map_to_vector(
923  loops.front()->getResults(), [](OpResult r) -> Value { return r; });
924  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
925  tilingResult->generatedSlices};
926 }
927 
/// Tiles an op implementing `PartialReductionOpInterface` into a loop nest
/// that computes partial reductions, followed by a merge of the partial
/// results.
///
/// `tileSizes` is padded with zeros up to the size of the iteration domain; a
/// zero tile size means that the loop is not tiled. On success `op` is
/// replaced by the merged values, and the returned struct carries the tiled
/// ops, the merge ops, the initial values, the generated loops and the
/// replacements. Failure is returned if the initial tensors cannot be
/// created, if loop generation fails, or if merging the partial results
/// fails.
///
/// NOTE(review): this listing appears to have dropped some source lines (the
/// embedded numbering skips 929, 943, 967-968 and 1022-1024), among them the
/// line carrying the function name and the rewriter parameter `b`, and the
/// declarations of `options` and `loops`. Confirm against the upstream file.
928 FailureOr<scf::SCFReductionTilingResult>
930  PartialReductionOpInterface op,
931  ArrayRef<OpFoldResult> tileSizes) {
932  Location loc = op.getLoc();
933  // Ops implementing PartialReductionOpInterface are expected to implement
934  // TilingInterface.
935  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
936  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
 // 1. Pad the tile sizes with zeros so that there is one per loop.
937  auto tileSizesVector = llvm::to_vector(tileSizes);
938  if (tileSizesVector.size() < iterationDomain.size()) {
939  auto zero = b.getIndexAttr(0);
940  tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
941  zero);
942  }
 // NOTE(review): the left-hand side of this statement (line 943) is missing
 // from the listing.
944  tilingInterfaceOp.getLoopIteratorTypes();
945 
 // Collect the positions of all loops with a reduction iterator type.
946  SmallVector<int> reductionDims;
947  for (auto [idx, iteratorType] :
948  llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
949  if (iteratorType == utils::IteratorType::reduction)
950  reductionDims.push_back(idx);
951  }
952 
953  // 2. Create the initial tensor value.
954  FailureOr<SmallVector<Value>> maybeInitTensors =
955  op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
956  reductionDims);
957  if (failed(maybeInitTensors)) {
958  return b.notifyMatchFailure(op, "Failed to create initial tensors.");
959  }
960  SmallVector<Value> &initTensors = maybeInitTensors.value();
961 
962  // 3. Define the callback to use for generating the inner most tile loop body.
 // NOTE(review): the trailing lambda parameters (lines 967-968) are missing
 // from the listing; the body uses `resultOffsets` and `resultSizes`, which
 // are presumably declared there.
963  SmallVector<Operation *> parallelTiledOps;
964  auto innerYieldTiledValuesFn =
965  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
966  ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
969  -> LogicalResult {
970  SmallVector<OpFoldResult> offsets, sizes;
971  {
 // Untiled loops (tile size 0) use the full loop range; tiled loops
 // consume the next induction variable and a bounded tile size.
972  int materializedLoopNum = 0;
973  for (auto [tileSize, loopRange] :
974  llvm::zip_equal(tileSizesVector, iterationDomain)) {
975  if (isConstantIntValue(tileSize, 0)) {
976  offsets.push_back(loopRange.offset);
977  sizes.push_back(loopRange.size);
978  continue;
979  }
980  Value iv = ivs[materializedLoopNum++];
981  offsets.push_back(iv);
982  sizes.push_back(
983  getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
984  }
985  }
986 
987  // 4a. Clone the operation.
988  {
989  auto clonedOp = cast<PartialReductionOpInterface>(
990  cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
991 
992  // 4b. Tile the cloned operation.
993  FailureOr<TilingResult> partialTilingResult =
994  clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
995  sizes, reductionDims);
996  if (failed(partialTilingResult)) {
997  return failure();
998  }
999  std::swap(parallelTiledOps, partialTilingResult->tiledOps);
1000  std::swap(tiledResult, partialTilingResult->tiledValues);
1001 
1002  // 4c. Delete the cloned operation.
1003  b.eraseOp(clonedOp);
1004  }
1005 
1006  // 4d. Compute the offsets and sizes needed to insert the result of the
1007  // tiled value back into destination before yielding the destination.
 // Each tiled result is inserted at offset zero with its full size.
1008  for (auto result : tiledResult) {
1009  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
1010  resultOffsets.emplace_back(std::move(outOffsets));
1011 
1012  SmallVector<OpFoldResult> outSizes;
1013  for (size_t i = 0; i < offsets.size(); i++) {
1014  outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
1015  }
1016  resultSizes.emplace_back(std::move(outSizes));
1017  }
1018  return success();
1019  };
1020 
1021  // 5. Generate the tiled implementation using the destination tensors.
 // NOTE(review): lines 1022-1024, presumably declaring `options` and `loops`,
 // are missing from the listing.
1025  if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
1026  /*numThreads=*/ArrayRef<OpFoldResult>{},
1027  initTensors, innerYieldTiledValuesFn, loops)))
1028  return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
1029 
1030  SmallVector<Value> replacements = llvm::map_to_vector(
1031  loops.front()->getResults(), [](OpResult r) -> Value { return r; });
1032 
1033  // 6. Apply the merge reduction to combine all the partial values.
1034  b.setInsertionPointAfter(*loops.begin());
1035  FailureOr<MergeResult> mergeResult =
1036  op.mergeReductions(b, loc, replacements, reductionDims);
1037  if (failed(mergeResult)) {
1038  return failure();
1039  }
1040  b.replaceOp(op, mergeResult->replacements);
1041 
 // Move the collected pieces into the result without copying.
1042  SCFReductionTilingResult reductionTilingResult;
1043  std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
1044  std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
1045  std::swap(reductionTilingResult.initialValues, initTensors);
1046  std::swap(reductionTilingResult.loops, loops);
1047  std::swap(reductionTilingResult.replacements, mergeResult->replacements);
1048 
1049  return reductionTilingResult;
1050 }
1051 
1052 //===----------------------------------------------------------------------===//
1053 // tileConsumerAndFuseProducersUsingSCF implementation.
1054 //===----------------------------------------------------------------------===//
1055 
1056 /// Return the untiled producer whose slice is used in a tiled consumer. The
1057 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1058 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
1059 /// indicates that this is a destination operand of the consumer. If there was
1060 /// no loop traversal needed, the second value of the returned tuple is empty.
1061 static std::tuple<OpResult, std::optional<OpOperand *>>
1064  std::optional<OpOperand *> destinationIterArg;
1065  auto loopIt = loops.rbegin();
1066  while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
1067  auto loop = *loopIt;
1068  if (iterArg.getOwner()->getParentOp() != loop)
1069  break;
1070  source = loop.getTiedLoopInit(iterArg);
1071  loopIt++;
1072  }
1073  if (loopIt == loops.rend())
1074  destinationIterArg = source;
1075  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1076 }
1077 
1078 /// Implementation of fusing producer of a single slice by computing the
1079 /// slice of the producer in-place.
///
/// The producer of the slice source is cloned next to `candidateSliceOp`, a
/// tiled implementation of the clone is generated for the slice, all uses of
/// `candidateSliceOp` are redirected to the tiled value, and the temporary
/// clones are erased. `candidateSliceOp` itself is left for the caller.
///
/// Returns std::nullopt if the slice source has no fusable producer, or if
/// one of the fallible steps below fails.
///
/// NOTE(review): this listing appears to have dropped some source lines (the
/// embedded numbering skips 1081, 1083, 1101, 1129 and 1190), among them the
/// line carrying the function name, the `loops` parameter, and the names of
/// two callees. Confirm against the upstream file before editing the code.
1080 std::optional<scf::SCFFuseProducerOfSliceResult>
1082  RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1084  // 1. Get the producer of the source (potentially walking through
1085  // `iter_args` of nested `scf.for`)
1086  auto [fusableProducer, destinationInitArg] =
1087  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1088  loops);
1089  if (!fusableProducer)
1090  return std::nullopt;
1091  unsigned resultNumber = fusableProducer.getResultNumber();
1092 
1093  OpBuilder::InsertionGuard g(rewriter);
1094  rewriter.setInsertionPoint(candidateSliceOp);
1095 
1096  // 2. Clone the fused producer
1097  // 2a. Compute the destination operands to use for the cloned operation.
 // NOTE(review): the callee that fills `origDestinationTensors` (line 1101)
 // is missing from the listing.
1098  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1099  Operation *fusableProducerOp = fusableProducer.getOwner();
1100  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1102  rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1103  origDestinationTensors)))
1104  return std::nullopt;
1105 
1106  clonedOpDestinationTensors = origDestinationTensors;
1107  if (destinationInitArg &&
1108  isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1109  // 2b. If the producer is also destination style, then to maintain the
1110  // destination passing style, update the destination of the producer to be
1111  // the source of the slice.
1112  clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1113  }
1114  // 2c. Clone the fused producer.
1115  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1116  rewriter, fusableProducerOp, clonedOpDestinationTensors);
1117  // 2d. Update the source of the candidateSlice to be the cloned producer.
1118  // Easier to just clone the slice with different source since replacements
1119  // and DCE of cloned ops becomes easier
1120  SmallVector<Value> candidateSliceOpOperands =
1121  llvm::to_vector(candidateSliceOp->getOperands());
1122  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1123  tensor::ExtractSliceOp clonedCandidateSliceOp =
1124  mlir::clone(rewriter, candidateSliceOp,
1125  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1126 
1127  // 3. Generate the tiled implementation of the producer of the source
 // NOTE(review): the callee name (line 1129) is missing from the listing.
 // Also note that on failure the two clones created above are not erased
 // before returning; confirm whether callers clean them up.
1128  FailureOr<TilingResult> tileAndFuseResult =
1130  rewriter, clonedCandidateSliceOp,
1131  clonedProducerOp->getResult(resultNumber));
1132  if (failed(tileAndFuseResult))
1133  return std::nullopt;
1134  // Note: Do not delete the candidateSliceOp, since it is passed in from the
1135  // caller.
1136  rewriter.replaceAllUsesWith(candidateSliceOp,
1137  tileAndFuseResult->tiledValues[0]);
1138  rewriter.eraseOp(clonedCandidateSliceOp);
1139  rewriter.eraseOp(clonedProducerOp);
1140 
1141  // 4. If the slice is for a destination operand, for example,
1142  //
1143  // ```mlir
1144  // %0 = linalg.init
1145  // %1 = linalg.fill .. outs(%0 : )
1146  // %2 = scf.for .. iter_args(%arg0 = %1) {
1147  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1148  // %4 = tensor.extract_slice %arg1 [..]
1149  // .. = linalg.matmul .. outs(%4 : )
1150  // }
1151  // }
1152  // ```
1153  //
1154  // the IR is currently
1155  //
1156  // ```
1157  // %0 = linalg.init
1158  // %1 = linalg.fill
1159  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1160  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1161  // %4 = tensor.extract_slice %arg1[..]
1162  // %5 = linalg.fill .. outs(%4 : )
1163  // .. = linalg.matmul .. outs(%5 : )
1164  // }
1165  // }
1166  // ```
1167  //
1168  // The untiled `linalg.fill` is still used as the `init_value` since it
1169  // was originally a destination operand of the untiled `linalg.matmul`.
1170  // When fusing an operand that is a destination operand, the iter_arg of
1171  // the outer most loop should be changed to use the destination of the
1172  // fused operation. With this the IR will be.
1173  //
1174  // ```
1175  // %0 = linalg.init
1176  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1177  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
1178  // %3 = tensor.extract_slice %arg1[..]
1179  // %4 = linalg.fill .. outs(%3 : )
1180  // .. = linalg.matmul .. outs(%4 : )
1181  // }
1182  // }
1183  // ```
1184  if (destinationInitArg &&
1185  isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1186  loops.front()
1187  ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1188  .set(origDestinationTensors[resultNumber]);
1189  }
 // NOTE(review): the line opening the return statement (1190) is missing
 // from the listing; the lines below are the fields of the returned result.
1191  fusableProducer, tileAndFuseResult->tiledValues[0],
1192  tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1193 }
1194 
1195 /// Reconstruct the fused producer from within the tiled-and-fused code.
1196 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1197  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1198  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1200  ArrayRef<unsigned> yieldResultNumber) {
1201  if (loops.empty())
1202  return success();
1203 
1204  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1205  *tiledOwner = fusedProducerInfo.tiledOps[0];
1206 
1207  Location loc = originalOwner->getLoc();
1208  // a. collect all init Value to be appended
1209  SmallVector<unsigned> initNumberList =
1210  yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1211  0, originalOwner->getNumResults()))
1212  : llvm::to_vector(yieldResultNumber);
1213  SmallVector<Value> initValueList;
1214  for (const auto &resultNumber : initNumberList) {
1215  FailureOr<Value> initValue = tensor::getOrCreateDestination(
1216  rewriter, loc, originalOwner->getResult(resultNumber));
1217  if (succeeded(initValue)) {
1218  initValueList.push_back(initValue.value());
1219  } else {
1220  return failure();
1221  }
1222  }
1223 
1224  SmallVector<Operation *> generatedSlices;
1225  YieldTiledValuesFn newYieldValuesFn =
1226  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1227  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1229  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1230  OpBuilder::InsertionGuard g(innerRewriter);
1231 
1232  // get sliceOp tile information
1233  SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1234  sliceSizes = sliceOp.getMixedSizes();
1235 
1236  // expect all strides of sliceOp being 1
1237  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1238  return !isConstantIntValue(ofr, 1);
1239  }))
1240  return failure();
1241 
1242  unsigned sliceResultNumber =
1243  fusedProducerInfo.origProducer.getResultNumber();
1244 
1245  auto tilableOp = cast<TilingInterface>(originalOwner);
1246  // b. get iterDomain Offset and Sizes based on sliceOp tile
1247  SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1248  // skip tensor.pack/unpack/pad, which expects single opResult
1249  if (tilableOp->getNumResults() > 1 &&
1250  failed(tilableOp.getIterationDomainTileFromResultTile(
1251  rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1252  iterDomainOffset, iterDomainSizes))) {
1253  // In theory, it is unnecessary to raise an error here. Actually although
1254  // it fails to reconstruct the result tensor, it should not broke current
1255  // fusion anyway. The reason why we must return failure currently is that
1256  // the callback function `newYieldValuesFn` will be called after new init
1257  // operand(s) has already been appended. It will take more refactoring to
1258  // make sure the init operands are added consistently in the future. For
1259  // more details, please refer to:
1260  // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1261  return failure();
1262  }
1263 
1264  // c. calculate offsets and sizes info of all OpResults respectively based
1265  // on iteration Domain Tile
1266  SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1267  for (const auto &resultNumber : initNumberList) {
1268  if (resultNumber == sliceResultNumber) {
1269  offsetList.push_back(sliceOffset);
1270  sizesList.push_back(sliceSizes);
1271  } else {
1272  assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1273  // infer result tile according to the iteration domain tile
1274  SmallVector<OpFoldResult> offset, sizes;
1275  if (failed(tilableOp.getResultTilePosition(
1276  rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1277  offset, sizes))) {
1278  return failure();
1279  }
1280  offsetList.push_back(offset);
1281  sizesList.push_back(sizes);
1282  }
1283  }
1284 
1285  // d. create `extract_slice` for `iter_args` for DPS operation if necessary
1286  if (auto tiledDestStyleOp =
1287  dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1288  rewriter.setInsertionPoint(tiledDestStyleOp);
1289  for (const auto &&[index, newRegionArg] :
1290  llvm::enumerate(newRegionIterArgs)) {
1291  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1292  loc, newRegionArg, offsetList[index], sizesList[index],
1293  SmallVector<OpFoldResult>(offsetList[index].size(),
1294  rewriter.getIndexAttr(1)));
1295  generatedSlices.push_back(destSlice);
1296  unsigned resultNumber = initNumberList[index];
1297  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1298  tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1299  });
1300  }
1301  }
1302 
1303  // e. prepare tiled offset and sizes for later `insert_slice` creation by
1304  // caller
1305  Block *block = rewriter.getInsertionPoint()->getBlock();
1306  rewriter.setInsertionPoint(block->getTerminator());
1307  for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1308  tiledResult.push_back(tiledOwner->getResult(resultNumber));
1309  tiledOffset.emplace_back(offsetList[index]);
1310  tiledSizes.emplace_back(sizesList[index]);
1311  }
1312  return success();
1313  };
1314 
1315  if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1316  newYieldValuesFn))) {
1317  return failure();
1318  }
1319  return generatedSlices;
1320 }
1321 
1322 namespace {
1323 
1324 //===----------------------------------------------------------------------===//
1325 // SliceTrackingListener
1326 //===----------------------------------------------------------------------===//
1327 
1328 /// This class is a listener for tracking the insertion and removal of
1329 /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1330 /// fusion algorithm to apply cleanup patterns in between fusion steps.
1331 class SliceTrackingListener : public RewriterBase::Listener {
1332 public:
1333  explicit SliceTrackingListener(
1334  std::optional<FrozenRewritePatternSet> patterns);
1335  SliceTrackingListener() = default;
1336 
1337  /// Adds the given list of operations to the worklist, and if present, applies
1338  /// the list of `patterns` to the newly added operations. This only processes
1339  /// the given operations and any newly inserted ones by the pattern set.
1340  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1341 
1342  /// Add to the new operation worklist if it is an extract_slice.
1343  void notifyOperationInserted(Operation *op,
1344  OpBuilder::InsertPoint previous) override;
1345 
1346  /// Shared helper for operation removal from the worklist.
1347  void removeOp(Operation *op);
1348 
1349  /// Remove the operation from the worklist.
1350  void notifyOperationErased(Operation *op) override;
1351 
1352  /// Remove the operation from the worklist.
1353  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1354 
1355  /// The worklist for this transformation keeps track of the slices to visit
1356  /// next for fusion.
1357  std::deque<tensor::ExtractSliceOp> worklist;
1358 
1359 private:
1360  /// Optional pattern set to apply when adding new operations to the worklist.
1361  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1362 };
1363 
1364 SliceTrackingListener::SliceTrackingListener(
1365  std::optional<FrozenRewritePatternSet> p) {
1366  patterns = std::move(p);
1367 }
1368 
/// Appends every `tensor.extract_slice` op in `ops` to the worklist. If no
/// pattern set was provided this is all that happens and success is returned;
/// otherwise the patterns are applied to `ops` with this object registered as
/// the listener, and the result of that application is returned.
///
/// NOTE(review): the embedded numbering skips line 1381, so one statement
/// that configures `config` is presumably missing from this listing.
1369 LogicalResult
1370 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1371  for (Operation *op : ops) {
1372  if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1373  worklist.push_back(slice);
1374  }
1375 
1376  if (!patterns)
1377  return success();
1378 
 // Register this listener so that slices created or removed by the patterns
 // are reflected in the worklist.
1379  GreedyRewriteConfig config;
1380  config.listener = this;
1382  return applyOpPatternsAndFold(ops, patterns.value(), config);
1383 }
1384 
1385 void SliceTrackingListener::notifyOperationInserted(
1386  Operation *op, OpBuilder::InsertPoint previous) {
1387  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1388  if (!slice)
1389  return;
1390  worklist.push_back(slice);
1391 }
1392 
1393 // Scan the worklist for the given op and remove it if present. The expectation
1394 // is for the worklist to be small and for removal to be relatively rare.
1395 void SliceTrackingListener::removeOp(Operation *op) {
1396  if (!isa<tensor::ExtractSliceOp>(op))
1397  return;
1398  auto iter = worklist.begin();
1399  while (iter != worklist.end()) {
1400  if (*iter == op)
1401  break;
1402  iter++;
1403  }
1404  if (iter == worklist.end())
1405  return;
1406 
1407  worklist.erase(iter);
1408 }
1409 
1410 void SliceTrackingListener::notifyOperationErased(Operation *op) {
1411  removeOp(op);
1412 }
1413 
1414 void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1415  ValueRange replacement) {
1416  removeOp(op);
1417 }
1418 } // namespace
1419 
1420 /// Implementation of tile consumer and fuse producer greedily.
///
/// The consumer is tiled first. The slices generated by tiling seed a
/// worklist; for every slice taken from the worklist the untiled producer of
/// its source is looked up, `options.fusionControlFn` decides whether it is
/// fused, and slices generated by a fusion are added back to the worklist.
///
/// Returns failure if `consumer` has no results, if tiling the consumer
/// fails, if the cleanup patterns fail, or if a requested producer
/// replacement cannot be yielded. On success the result carries the fused
/// producers, the tiled and fused ops, the loops, and a map from original
/// values to their replacements.
///
/// NOTE(review): this listing appears to have dropped some source lines (the
/// embedded numbering skips 1422, 1424, 1473 and 1519), among them the line
/// carrying the function name and the `options` parameter. Confirm against
/// the upstream file before editing the code.
1421 FailureOr<scf::SCFTileAndFuseResult>
1423  RewriterBase &rewriter, TilingInterface consumer,
1425  // This transformation is only valid for ops that return values (i.e. not
1426  // valid to use with operations that have memref operands).
1427  if (!consumer->getNumResults()) {
1428  return rewriter.notifyMatchFailure(
1429  consumer, "invalid pattern for op with no results");
1430  }
1431 
1432  // 1. First tile the consumer.
 // NOTE(review): `origProducerToLoopResultNum` is not referenced anywhere
 // else in the visible body.
1433  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1434  llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1435 
1436  FailureOr<scf::SCFTilingResult> tilingResult =
1437  tileUsingSCF(rewriter, consumer, options.tilingOptions);
1438 
1439  if (failed(tilingResult))
1440  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1441  for (auto *tiledOp : tilingResult->tiledOps)
1442  tiledAndFusedOps.insert(tiledOp);
1443 
1444  // If there are no loops generated, fusion is immaterial.
1445  auto &loops = tilingResult->loops;
1446  if (loops.empty()) {
1447  DenseMap<Value, Value> replacements;
1448  for (auto [origVal, replacement] :
1449  llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1450  replacements[origVal] = replacement;
1451  }
1452  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1453  replacements};
1454  }
1455 
1456  // To keep track of replacements for now just record the map from the original
1457  // untiled value to the result number of the for loop. Since the loop gets
1458  // potentially replaced during fusion, keeping the value directly won't work.
1459  DenseMap<Value, size_t> origValToResultNumber;
1460  for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1461  origValToResultNumber[result] = index;
1462  }
1463 
1464  // 2. Typically, the operands of the tiled operation are slices of the
1465  // operands of the untiled operation. These are expressed in IR using
1466  // `tensor.extract_slice` operations with source being the operands of the
1467  // untiled operation. Create a worklist of these `tensor.extract_slice`
1468  // operations. If the producers of the source of the `tensor.extract_slice`
1469  // can be tiled such that the tiled value is generated in-place, that
1470  // effectively tiles + fuses the operations.
 // NOTE(review): the second field of this struct (line 1473) is missing from
 // the listing; the code below reads it as `controlFnResult`.
1471  struct WorklistItem {
1472  tensor::ExtractSliceOp candidateSlice;
1474  };
1475 
1476  SliceTrackingListener sliceTracker =
1477  SliceTrackingListener(options.cleanupPatterns);
1478 
1479  if (failed(
1480  sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1481  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1482  }
1483  OpBuilder::InsertionGuard g(rewriter);
1484  while (!sliceTracker.worklist.empty()) {
1485  auto candidateSlice = sliceTracker.worklist.front();
1486  sliceTracker.worklist.pop_front();
1487 
 // Slices whose source is not produced by an operation cannot be fused.
1488  auto [fusableProducer, destinationInitArg] =
1489  getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1490  loops);
1491  if (!fusableProducer)
1492  continue;
1493 
 // Let the caller-provided control function decide whether to fuse.
1494  std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1495  options.fusionControlFn(candidateSlice, fusableProducer,
1496  destinationInitArg.has_value());
1497  if (!controlFnResult)
1498  continue;
1499 
1500  WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1501 
1502  // The operands of the fused producer might themselves be slices of
1503  // values produced by operations that implement the `TilingInterface`.
1504  // Add these operations to the worklist.
1505  std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1506  tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1507  loops);
1508  if (!fusedResult)
1509  continue;
1510 
1511  SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1512 
1513  if (worklistItem.controlFnResult.yieldProducerReplacement) {
1514  // Reconstruct and yield all opResult of fusableProducerOp by default. The
1515  // caller can specify which one to yield by designating the optional
1516  // argument named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
 // NOTE(review): the callee name (line 1519) is missing from the listing.
1517  Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1518  FailureOr<SmallVector<Operation *>> newSlices =
1520  worklistItem.candidateSlice,
1521  fusedResult.value(), loops);
1522  if (failed(newSlices)) {
1523  return rewriter.notifyMatchFailure(
1524  fusableProducerOp, "failed to replacement value for this "
1525  "operation from within the tiled loop");
1526  }
1527  worklistCandidates.append(newSlices.value());
 // The yielded producer results are the trailing results of the
 // outermost loop.
1528  for (auto [index, result] :
1529  llvm::enumerate(fusableProducerOp->getResults())) {
1530  origValToResultNumber[result] = loops.front()->getNumResults() -
1531  fusableProducerOp->getNumResults() +
1532  index;
1533  }
1534  }
1535  if (Operation *tiledAndFusedOp =
1536  fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1537  fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1538  tiledAndFusedOps.insert(tiledAndFusedOp);
1539  }
1540 
1541  if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1542  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1543  }
1544  }
1545 
 // Resolve the recorded result numbers against the final outermost loop.
1546  DenseMap<Value, Value> replacements;
1547  for (auto [origVal, resultNumber] : origValToResultNumber) {
1548  replacements[origVal] = loops.front()->getResult(resultNumber);
1549  }
1550 
1551  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1552  replacements};
1553 }
1554 
1555 //===----------------------------------------------------------------------===//
1556 // tileAndFuseConsumerUsingSCF implementation.
1557 //===----------------------------------------------------------------------===//
1558 
1559 /// A utility function that checks whether the only use of the result of a
1560 /// tensor.insert_slice op is in a scf.yield op.
1561 static LogicalResult
1562 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1563  Value result = candidateSliceOp.getResult();
1564  Value::use_range uses = result.getUses();
1565  if (!llvm::hasSingleElement(uses)) {
1566  LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1567  return failure();
1568  }
1569  OpOperand &operandUse = (*uses.begin());
1570  Operation *userOp = operandUse.getOwner();
1571  if (!isa<scf::YieldOp>(userOp)) {
1572  LLVM_DEBUG(llvm::dbgs()
1573  << "Expected scf.yield to be the only user, but got -> "
1574  << (*userOp));
1575  return failure();
1576  }
1577  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1578  LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1579  "be in the same block\n");
1580  return failure();
1581  }
1582  return success();
1583 }
1584 
1585 /// A utility to get the user of `loopOp` that comes first in the block of `loopOp`. Returns failure if `loopOp` is not a LoopLikeOpInterface or if any
1586 /// user is in a different block than `loopOp`. If `loopOp` has no users at all, the successful result holds nullptr.
1587 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1588  if (!isa<LoopLikeOpInterface>(loopOp))
1589  return failure();
1590  Operation *firstUserOfLoop = nullptr;
1591  for (Operation *userOp : loopOp->getUsers()) {
1592  // A `ParallelInsertSlice` located inside an `InParallelOp` does not share a
1593  // parent block with any other type of operation, so such a user is
1594  // redirected to its parent `InParallelOp`. E.g.
1595  //
1596  // ```
1597  // %1 = scf.for {
1598  // ...
1599  // }
1600  // %2 = consumerOp ins(%1, ...)
1601  // scf.forall.in_parallel {
1602  // tensor.parallel_insert_slice %1
1603  // }
1604  // ```
1605  // where the `InParallelOp`, but not the `ParallelInsertSlice`, is in the
1606  // same block as `consumerOp`.
1607  if (isa<tensor::ParallelInsertSliceOp>(userOp))
1608  userOp = userOp->getParentOfType<scf::InParallelOp>(); // NOTE(review): assumes an enclosing InParallelOp exists; a null result would be dereferenced below - confirm.
1609 
1610  if (loopOp->getBlock() != userOp->getBlock())
1611  return failure();
1612 
1613  if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop)) // Keep the earliest user seen so far.
1614  firstUserOfLoop = userOp;
1615  }
1616  return firstUserOfLoop;
1617 }
1618 
1619 /// This utility checks that the first user of the loop is NOT before the last
1620 /// defining op of the consumer's operands. This is needed because the whole
1621 /// loop structure is moved right before the `firstUserOfLoop`. The check thus
1622 /// helps ensure that no invalid IR is formed, i.e. that no op in the backward
1623 /// slice of consumerOp is dominated by the `firstUserOfLoop`. For example:
1624 ///
1625 /// ```
1626 /// %0 = scf.for() {
1627 /// ...
1628 /// }
1629 /// ...
1630 /// %1 = firstUserOfLoop(%0)
1631 /// ...
1632 /// %2 = lastDefOfConsumerOperand
1633 /// ...
1634 /// %3 = consumerOp(%2)
1635 /// ```
1636 ///
1637 /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it would
1638 /// be invalid to move the `loopOp` right before the `firstUserOfLoop`, a.k.a.
1639 /// use-def chain violation:
1640 ///
1641 /// ```
1642 /// %0:2 = scf.for() {
1643 /// // use before define error
1644 /// %3 = tiledConsumerOp(%2)
1645 /// }
1646 /// %1 = firstUserOfLoop(%0)
1647 /// ...
1648 /// %2 = lastDefOfConsumerOperand
1649 /// ```
1650 ///
1651 /// @param loopOp: loop operation
1652 /// @param consumerOp: consumer operation
1653 /// @param reorderOperations: whether the backward slice of the defining ops of
1654 /// the `consumerOp` operands may be reordered; if false, a non-empty slice is reported as failure.
1655 /// @return: computed backward slice of consumerOp, excluding the ops that
1656 /// already properly dominate `firstUserOfLoop`.
1657 static FailureOr<llvm::SetVector<Operation *>>
1659  bool reorderOperations) {
1660  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1661  if (failed(firstUserOfLoop))
1662  return failure();
1663 
1665  DominanceInfo dominanceInfo;
1666  options.inclusive = true;
1667  options.omitBlockArguments = true;
1668  bool includeLoopOp = false;
1669  options.filter = [&](Operation *op) {
1670  if (op == loopOp) {
1671  includeLoopOp = true;
1672  return false;
1673  }
1674  // Cut off the slice to not include any operation that already dominates
1675  // firstUserOfLoop.
1676  return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1677  };
1679  for (auto operand : consumerOp->getOperands()) {
1680  getBackwardSlice(operand, &slice, options);
1681  }
1682 
1683  if (!slice.empty()) {
1684  // If consumerOp has one producer, which is also the user of loopOp.
1685  // E.g.
1686  // ```
1687  // %0 = %loopOp
1688  // %1 = consumerOp1 ins(%0)
1689  // %2 = consumerOp2 ins(%0, %1)
1690  // ```
1691  // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1692  // consumerOp1 has already been fused into loopOp before.
1693  if (includeLoopOp || !reorderOperations)
1694  return failure();
1695  }
1696 
1697  return slice;
1698 }
1699 
1700 /// Walks the uses of result `resultNumber` of `loopOp` and returns the OpOperand
1701 /// of the first use whose owner implements `TilingInterface` and `DestinationStyleOpInterface`, is in the block of `loopOp`, has uses of its own and passes `checkAssumptionForLoop`.
1702 /// Side effect: the ops in the backward slice computed for that consumer are moved before the first user of the loop. Returns failure if `loopOp` is not a LoopLikeOpInterface or if no use qualifies.
1703 static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1704  Operation *loopOp,
1705  unsigned resultNumber) {
1706  if (!isa<LoopLikeOpInterface>(loopOp))
1707  return failure();
1708  Value val = loopOp->getResult(resultNumber);
1709  Block *loopBlock = loopOp->getBlock();
1710  for (OpOperand &opOperand : val.getUses()) {
1711  Operation *consumerOp = opOperand.getOwner();
1712  // Step 1. Check if the user is tilable.
1713  if (!isa<TilingInterface>(consumerOp) ||
1714  !isa<DestinationStyleOpInterface>(consumerOp)) {
1715  // TODO: We have to init result of consumer before scf.for, use
1716  // DestinationStyleOpInterface to get result shape from init for now. Add
1717  // support for other ops, such as ops that have InferTypeOpInterface.
1718  continue;
1719  }
1720  // Step 2. Check if the user stays in the same block as the loop.
1721  if (loopBlock != consumerOp->getBlock())
1722  continue;
1723  // Step 3. Skip users whose results are unused; that usually means the
1724  // user has already been tiled.
1725  if (consumerOp->use_empty())
1726  continue;
1727  // Step 4. Check assumption for loop with `reorderOperations` enabled.
1728  FailureOr<llvm::SetVector<Operation *>> slice =
1729  checkAssumptionForLoop(loopOp, consumerOp, true);
1730  if (failed(slice))
1731  continue;
1732  // Step 5. If the backward slice is not empty, move its ops before firstUserOfLoop.
1733  if (!slice->empty()) {
1734  *slice = mlir::topologicalSort(*slice); // topologicalSort returns the sorted set instead of sorting in place, so the result must be stored for the moves below to follow def-before-use order.
1735  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1736  assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1737  for (auto op : *slice) {
1738  rewriter.moveOpBefore(op, *firstUserOfLoop);
1739  }
1740  }
1741  return &opOperand;
1742  }
1743  return failure();
1744 }
1745 
1746 /// Find the perfectly nested loops enclosing the given loop (the loop itself
1747 /// included), sorted from outer to inner.
1748 ///
1749 /// E.g.
1750 ///
1751 /// ```
1752 /// %0 = scf.for()
1753 /// %1 = scf.for()
1754 /// %2 = scf.for()
1755 /// %3 = ...
1756 /// yield %3
1757 /// yield %2
1758 /// yield %1
1759 /// ```
1760 ///
1761 /// This function will return three perfectly nested loops: %0 + %1 + %2, when
1762 /// target inner loop is %2.
1765  SmallVector<scf::ForOp> nestLoops = {loop};
1766  auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1767 
1768  // Check if it is the ForOp that yield the result of inner loop.
1769  auto isForOpYieldResultOfInnerLoop =
1770  [](scf::ForOp outerLoop) -> LogicalResult {
1771  Block *body = outerLoop.getBody();
1772  if (!llvm::hasSingleElement(body->without_terminator()))
1773  return failure();
1774  auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1775  auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1776  if (!innerForOp)
1777  return failure();
1778  // All of the innerForOp results should be yielded; this is checked only by comparing the result count with the yield operand count.
1779  return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1780  };
1781 
1782  while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1783  nestLoops.push_back(outerLoop);
1784  outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1785  }
1786  // sorted from outer to inner
1787  return {nestLoops.rbegin(), nestLoops.rend()};
1788 }
1789 
1790 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
1791 /// tensor.insert_slice. This function makes the following assumptions :
1792 /// 1. tensor.insert_slice has scf.yield as its only user.
1793 /// 2. scf.for's corresponding result has only one use.
1794 static FailureOr<OpOperand *>
1796  tensor::InsertSliceOp candidateSliceOp) {
1797  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1798  return failure();
1799  Value sliceResult = candidateSliceOp.getResult();
1800  // Step 1. Fetch the corresponding output.
1801  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1802  unsigned resultNumber = yieldOpOperand.getOperandNumber();
1803  // Step 2. Check containing op is scf.for.
1804  Operation *containingOp = candidateSliceOp->getParentOp();
1805  auto forOp = dyn_cast<scf::ForOp>(containingOp);
1806  if (!forOp)
1807  return failure();
1808  scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
1809 
1810  return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
1811 }
1812 
1813 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1814 /// by a tensor.parallel_insert_slice.
1815 static FailureOr<OpOperand *>
1817  tensor::ParallelInsertSliceOp candidateSliceOp) {
1818  // Step 1. Fetch the corresponding output
1819  Value sliceDest = candidateSliceOp.getDest();
1820  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1821  if (!iterArg)
1822  return failure();
1823  Operation *containingOp = iterArg.getOwner()->getParentOp();
1824  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1825  return failure();
1826  // Step 2. Check that the containing op is scf.forall.
1827  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1828  if (!forallOp)
1829  return failure();
1830  unsigned resultNumber =
1831  forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1832  .getResultNumber();
1833 
1834  return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
1835 }
1836 
1837 /// A utility to fetch an untiled consumer of
1838 /// tensor.insert_slice/tensor.parallel_insert_slice.
1839 static FailureOr<OpOperand *>
1841  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1842  return getUntiledConsumerFromSlice(rewriter, insertSlice);
1843  } else if (auto parallelInsertSlice =
1844  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1845  return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
1846  } else {
1847  return failure();
1848  }
1849 }
1850 
1851 /// Implementation of fusing consumer of a single slice by computing the
1852 /// slice of the consumer in-place for scf loop.
1853 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1855  Operation *candidateSliceOp) {
1856  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1857  candidateSliceOp))
1858  return failure();
1859 
1860  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1861 
1862  // 1. Get the consumer of scf.for for the result yielded by
1863  // tensor.insert_slice/parallel_insert_slice.
1864  FailureOr<OpOperand *> maybeConsumerOpOperand =
1865  getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
1866  if (failed(maybeConsumerOpOperand)) {
1867  return rewriter.notifyMatchFailure(candidateSliceOp,
1868  "could not fetch consumer to fuse");
1869  }
1870  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1871  Operation *consumerOp = consumerOpOperand->getOwner();
1872  unsigned operandNumber = consumerOpOperand->getOperandNumber();
1873  unsigned resultNumber = 0;
1874  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1875  resultNumber = producerResult.getResultNumber();
1876  } else {
1877  return rewriter.notifyMatchFailure(
1878  consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1879  }
1880 
1881  // There are two possible cases regarding `innerMostLoop` here:
1882  // 1. a single `scf.forall` or `scf.for`.
1883  // 2. the inner-most `scf.for` inside a nested `scf.for` structure, where the
1884  // top-level loop is the outer-most one of these nested loops.
1885  LoopLikeOpInterface innerMostLoop =
1886  candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1888  if (isInsertSliceOp) {
1889  nestedLoops = llvm::map_to_vector(
1891  cast<scf::ForOp>(innerMostLoop.getOperation())),
1892  [](scf::ForOp forOp) {
1893  return cast<LoopLikeOpInterface>(forOp.getOperation());
1894  });
1895  } else {
1896  nestedLoops = {innerMostLoop};
1897  }
1898 
1899  LoopLikeOpInterface outerMostLoop = nestedLoops.front();
1900 
1901  // Check assumption for loop with `reorderOperations` disabled.
1902  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
1903  return rewriter.notifyMatchFailure(
1904  outerMostLoop, "the first user of loop should not dominate any define "
1905  "of consumer operand(s)");
1906  }
1907 
1908  OpBuilder::InsertionGuard g(rewriter);
1909 
1910  // 2. Check consumer is not using scf loop's output as init.
1911  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1912  if (!dstOp)
1913  return rewriter.notifyMatchFailure(consumerOp,
1914  "consumer op is not DPS operation");
1915  SmallVector<Value> dpsInits =
1916  llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1917  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
1918  return rewriter.notifyMatchFailure(
1919  consumerOp,
1920  "consumer op taking the result of scf.for as init is not supported");
1921  }
1922  SmallVector<Value> newInits = dpsInits;
1923 
1924  Location loc = outerMostLoop->getLoc();
1925 
1926  // 3. Move the whole loop structure right before firstUserOfLoop, the
1927  // dominance should be already ensured by `checkAssumptionForLoop`.
1928  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
1929  if (failed(firstUserOfLoop)) {
1930  return rewriter.notifyMatchFailure(
1931  outerMostLoop, "could not find the first user of outer most loop");
1932  }
1933  rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
1934 
1935  // 4. Set insertion point before terminator op of the loop and create a new
1936  // tensor.insert_slice. In the scf.for case this is a clone of the
1937  // candidateSliceOp whereas in the scf.forall case this is created from the
1938  // operands of tensor.parallel_insert_slice.
1939  tensor::InsertSliceOp clonedInsertSliceOp;
1940  if (auto sliceOp =
1941  dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1942  auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
1943  rewriter.setInsertionPoint(newForallOp.getTerminator());
1944  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
1945  loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1946  sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1947  } else {
1948  rewriter.setInsertionPoint(candidateSliceOp);
1949  clonedInsertSliceOp =
1950  cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
1951  }
1952 
1953  // 5.a. Clone consumer op.
1954  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
1955 
1956  // 5.b. Replace the cloned consumer's operand at `operandNumber` (the use of
1957  // the loop result) with the result of the cloned tensor.insert_slice.
1958  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
1959  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
1960  operandToReplace.set(clonedInsertSliceOp.getResult());
1961  });
1962 
1963  // 6. Perform tiling of the cloned consumer and replace the operand at
1964  // `operandNumber` with the source of the cloned tensor.insert_slice op.
1965  auto ossSliceOp =
1966  cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
1967  FailureOr<TilingResult> tileAndFuseResult =
1969  rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
1970  if (failed(tileAndFuseResult)) {
1971  return failure();
1972  }
1973  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
1974  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
1975  clonedInsertSliceOp.getSource());
1976 
1977  // 7. Reconstruct [nested] loop with new inits.
1978  YieldTiledValuesFn newYieldValuesFn =
1979  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1980  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1982  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1983  OpBuilder::InsertionGuard g(innerRewriter);
1984  // 8. Set inner insertPoint right before tiled consumer op.
1985  innerRewriter.setInsertionPoint(tiledConsumerOp);
1986 
1987  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
1988  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
1989  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
1990 
1991  // 9. Check that all insert strides are 1.
1992  if (llvm::any_of(strides, [](OpFoldResult stride) {
1993  return !isConstantIntValue(stride, 1);
1994  })) {
1995  return rewriter.notifyMatchFailure(
1996  candidateSliceOp, "containingOp's result yield with stride");
1997  }
1998 
1999  // 10. Try to get iter domain position from input position. Use
2000  // clonedConsumerOp instead of tiledConsumerOp, because the iteration domain
2001  // may require index computation based on the result size. The sizes and
2002  // offsets should be the same either way, but using tiledConsumerOp could
2003  // lead to some chained unnecessary extra index computation.
2004  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2005  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2006  rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2007  iterDomainSizes))) {
2008  return rewriter.notifyMatchFailure(
2009  clonedConsumerOp,
2010  "can't get iter domain position from input position");
2011  }
2012 
2013  // 11. Try to fetch the offset and size for all results of the cloned
2014  // consumer. This would then be used to form the corresponding
2015  // tensor.insert_slice/parallel_insert_slice later.
2016  unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2018  totalNumResultsOfConsumer);
2020  totalNumResultsOfConsumer);
2021  for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2022  if (failed(tiledConsumerOp.getResultTilePosition(
2023  rewriter, idx, iterDomainOffsets, iterDomainSizes,
2024  resultOffsets[idx], resultSizes[idx]))) {
2025  return rewriter.notifyMatchFailure(
2026  tiledConsumerOp,
2027  "can't get result domain position from iter domain position");
2028  }
2029  }
2030 
2031  // 12. Create `extract_slice` for `iter_args` for DPS operation if
2032  // necessary.
2033  if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2034  tiledConsumerOp.getOperation())) {
2035  rewriter.setInsertionPoint(tiledDestStyleOp);
2036  for (const auto &&[index, newRegionArg] :
2037  llvm::enumerate(newRegionIterArgs)) {
2038  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2039  loc, newRegionArg, resultOffsets[index], resultSizes[index],
2040  SmallVector<OpFoldResult>(resultOffsets[index].size(),
2041  rewriter.getIndexAttr(1)));
2042  // Make a copy of index to avoid a capturing structured binding, which
2043  // is a C++20 extension.
2044  auto dstNumber = index;
2045  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2046  tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2047  });
2048  }
2049  }
2050 
2051  // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2052  // caller.
2053  Block *block = rewriter.getInsertionPoint()->getBlock();
2054  rewriter.setInsertionPoint(block->getTerminator());
2055  for (const auto &&[index, result] :
2056  llvm::enumerate(tiledConsumerOp->getResults())) {
2057  tiledResult.push_back(result);
2058  tiledOffset.emplace_back(resultOffsets[index]);
2059  tiledSizes.emplace_back(resultSizes[index]);
2060  }
2061  return success();
2062  };
2063  // 14. Add new inits to [nested] loops.
2064  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
2065  newYieldValuesFn))) {
2066  return rewriter.notifyMatchFailure(tiledConsumerOp,
2067  "unable to add new inits to nest loop");
2068  }
2069 
2070  // 15. Replace the result of scf loop and consumer op with new loop's results.
2071 
2072  for (auto &&[oldResult, newResult] : llvm::zip(
2073  consumerOp->getResults(),
2074  nestedLoops.front()->getResults().take_back(newInits.size()))) {
2075  rewriter.replaceAllUsesWith(oldResult, newResult);
2076  }
2077 
2078  // 16. Erase the cloned consumer op, which was only a temporary used to drive the tiling.
2079  rewriter.eraseOp(clonedConsumerOp);
2080 
2082  consumerOpOperand,
2083  &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2084  tileAndFuseResult->tiledOps};
2085 }
2086 
2087 //===----------------------------------------------------------------------===//
2088 // lowerToLoopsUsingSCFForOp implementation.
2089 //===----------------------------------------------------------------------===//
2090 
2091 FailureOr<SmallVector<scf::ForOp>>
2093  TilingInterface op) {
2094  // TODO: Handle cases where the op has results if needed.
2095  if (op->getNumResults() > 0) {
2096  return rewriter.notifyMatchFailure(
2097  op, "unable to lower to loops operations with return values");
2098  }
2099 
2100  SmallVector<Range> domain = op.getIterationDomain(rewriter);
2101  SmallVector<Value> ivs;
2103  Location loc = op.getLoc();
2104  for (auto loopRange : domain) {
2105  Value offsetVal =
2106  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2107  Value sizeVal =
2108  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2109  Value strideVal =
2110  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2111  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2112  strideVal, ValueRange{});
2113  loops.push_back(loop);
2114  ivs.push_back(loop.getInductionVar());
2115  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2116  }
2117  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2118  return failure();
2119  }
2120  return loops;
2121 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp)
Fetch the untiled consumer of a scf.for's result which is yielded by a tensor.insert_slice.
static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specified in options.
static SmallVector< scf::ForOp > getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop)
Find the perfectly nested loops outside of given loop(included) sorted from outer to inner.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:964
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:153
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class represents a saved insertion point.
Definition: Builders.h:335
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:453
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:848
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1307
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1300
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1194
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:56
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:75
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:110
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterf...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute with value 0.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:348
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. all the transitive defs of op), without including that operation.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Container for result values of tiling.
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Transformation information returned after reduction tiling.
SmallVector< Value > replacements
The replacements to use for the results of the tiled operation.
SmallVector< Value > initialValues
Initial values used for reduction.
SmallVector< Operation * > parallelTiledOps
The partial reduction tiled op generated.
SmallVector< LoopLikeOpInterface > loops
The loop operations that iterate over the tiles.
SmallVector< Operation * > mergeOps
The final reduction operation merging all the partial reductions.
Control function to check if a slice needs to be fused or not. The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.