MLIR  20.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/Dominance.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/Debug.h"
31 #include <optional>
32 
33 #define DEBUG_TYPE "tile-using-interface"
34 
35 using namespace mlir;
36 
39  assert(!tileSizeComputationFunction && "tile sizes already set");
40  auto tileSizes = llvm::to_vector(ts);
41  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
42  return tileSizes;
43  };
44  return *this;
45 }
46 
49  assert(!numThreadsComputationFunction && "num tiles already set");
50  auto numThreads = llvm::to_vector(nt);
51  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
52  return numThreads;
53  };
54  return *this;
55 }
56 
57 /// Helper method to adjust the interchange vector to match the iteration
58 /// domain.
61  size_t iterationDomainSize) {
62  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
63  if (filledVector.size() < iterationDomainSize) {
64  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
65  filledVector.append(range.begin(), range.end());
66  }
67  if (filledVector.size() > iterationDomainSize)
68  filledVector.resize(iterationDomainSize);
69  return filledVector;
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // tileUsingSCF implementation.
74 //===----------------------------------------------------------------------===//
75 
76 /// Verify the tile size options are set in a consistent manner.
77 static LogicalResult
80  // Specifying number of threads is only supported on `scf.forall` op.
81  if (options.numThreadsComputationFunction &&
83  return rewriter.notifyMatchFailure(
84  loc, "number of threads can only by specified when loop type is "
85  "set to use `scf.forall`");
86  }
87 
88  // If specified, check that the interchange vector is a permutation.
89  if (!options.interchangeVector.empty()) {
90  if (!isPermutationVector(options.interchangeVector)) {
91  return rewriter.notifyMatchFailure(
92  loc, "invalid interchange vector, not a permutation of the entire "
93  "iteration space");
94  }
95  }
96  return success();
97 }
98 
99 /// Method to instantiate the tile sizes and/or number of threads specified
100 /// by the user.
101 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
102 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
103  ArrayRef<Range> iterationDomain,
105  OpFoldResult zero = rewriter.getIndexAttr(0);
106  SmallVector<OpFoldResult> tileSizes, numThreads;
107  size_t numLoops = iterationDomain.size();
108 
109  // Check whether the number of tiles to use is specified.
110  if (options.numThreadsComputationFunction) {
111  numThreads = options.numThreadsComputationFunction(rewriter, op);
112  numThreads.resize(numLoops, zero);
113 
114  // If the number of tiles is also specified, use that.
115  if (options.tileSizeComputationFunction) {
116  tileSizes = options.tileSizeComputationFunction(rewriter, op);
117  tileSizes.resize(numLoops, zero);
118  return {tileSizes, numThreads};
119  }
120 
121  // Compute the tile sizes from the iteration domain and number
122  // of tiles as follows
123  // - niters = ceilDiv(ub - lb, step)
124  // - tileSize = ceilDiv(niters, numThreads)
125  AffineExpr s0, s1, s2;
126  bindSymbols(rewriter.getContext(), s0, s1, s2);
127  // TODO: The step here is assumed to be 1.
128  AffineExpr numItersExpr = (s1 - s0);
129  AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
130  tileSizes.resize(numLoops, zero);
131  for (auto [index, range, nt] :
132  llvm::enumerate(iterationDomain, numThreads)) {
133  if (isConstantIntValue(nt, 0))
134  continue;
135 
136  tileSizes[index] = affine::makeComposedFoldedAffineApply(
137  rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
138  }
139  tileSizes.resize(numLoops, zero);
140  return {tileSizes, numThreads};
141  }
142 
143  // Enforce the convention that "tiling by zero"
144  // skips tiling a particular dimension. This convention is significantly
145  // simpler to handle instead of adjusting affine maps to account for missing
146  // dimensions.
147  assert(options.tileSizeComputationFunction &&
148  "expected tile sizes to be specified");
149  tileSizes = options.tileSizeComputationFunction(rewriter, op);
150  tileSizes.resize(numLoops, zero);
151 
152  return {tileSizes, numThreads};
153 }
154 
155 /// Checks if any of the tiled loops are not parallel.
156 static void checkSafeToTileToForall(TilingInterface op,
157  ArrayRef<OpFoldResult> tileSizes,
158  ArrayRef<OpFoldResult> numThreads) {
159  auto iterators = op.getLoopIteratorTypes();
160  assert(iterators.size() == tileSizes.size() &&
161  "expected as many tile size values as number of loops");
162  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
163  "when specified, expected number of threads to use for each loop");
164 
165  for (auto [index, iterator, tileSize] :
166  llvm::enumerate(iterators, tileSizes)) {
167  // If num threads is specified, check that it is greater than one only for
168  // parallel dimensions.
169  if (!numThreads.empty()) {
170  if (std::optional<int64_t> constNumThreads =
171  getConstantIntValue(numThreads[index])) {
172  if (constNumThreads.value() > 1 &&
173  iterator != utils::IteratorType::parallel) {
174  op.emitWarning() << "tiling is not thread safe at axis #" << index;
175  }
176  }
177  continue;
178  }
179 
180  if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
181  if (constTileSize.value() > 0 &&
182  iterator != utils::IteratorType::parallel) {
183  op.emitWarning() << "tiling is not thread safe at axis #" << index;
184  }
185  }
186  }
187 }
188 
189 /// Check if `stride` evenly divides the trip count `size - offset`.
190 static bool tileDividesIterationDomain(Range loopRange) {
191  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
192  if (!offsetAsInt)
193  return false;
194  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
195  if (!sizeAsInt)
196  return false;
197  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
198  if (!strideAsInt)
199  return false;
200  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
201 }
202 
203 /// Returns the bounded tile size given the current `offset`, `loopRange` and
204 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
206  Range loopRange, OpFoldResult offset,
207  OpFoldResult tileSize) {
208  std::optional<int64_t> ts = getConstantIntValue(tileSize);
209  if (ts && ts.value() == 1)
210  return tileSize;
211 
213  Range{loopRange.offset, loopRange.size, tileSize}))
214  return tileSize;
215 
216  // The tile size to use (to avoid out of bounds access) is minimum of
217  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
218  // loop.
219  AffineExpr s0, s1, d0;
220  bindDims(b.getContext(), d0);
221  bindSymbols(b.getContext(), s0, s1);
222  AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
223  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
225  b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
226 }
227 
228 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
229 /// than `iterationSize`.
231  OpFoldResult numThreads,
232  OpFoldResult iterationSize) {
233  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
234  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
235  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
236  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
237  return false;
238  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
239 }
240 
241 /// Compute the `OpFoldResult`s that represents the multi-dimensional
242 /// `offset`s and `size`s of the tile of the iteration space that the
243 /// innermost loop body of the generated tiled loops corresponds to.
244 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
246  ArrayRef<Range> iterationDomain,
247  ArrayRef<OpFoldResult> tileSizes,
248  ArrayRef<OpFoldResult> numThreads) {
249  SmallVector<OpFoldResult> offsets, sizes;
250  int materializedLoopNum = 0;
251 
252  if (!numThreads.empty()) {
253  AffineExpr d0, d1, s0, s1;
254  AffineExpr offsetExpr, residualTileSizeExpr;
255  bindDims(rewriter.getContext(), d0, d1);
256  bindSymbols(rewriter.getContext(), s0, s1);
257  offsetExpr = d0 + d1 * s0;
258  residualTileSizeExpr = s1 - (d0 + d1 * s0);
259 
260  for (auto [nt, tileSize, loopRange] :
261  llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
262 
263  // Non-tiled cases, set the offset and size to the
264  // `loopRange.offset/size`.
265  if (isConstantIntValue(nt, 0)) {
266  offsets.push_back(loopRange.offset);
267  sizes.push_back(loopRange.size);
268  continue;
269  }
270 
271  Value iv = ivs[materializedLoopNum++];
273  rewriter, loc, offsetExpr,
274  ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
276  rewriter, loc, residualTileSizeExpr,
277  {loopRange.offset, nt, tileSize, loopRange.size});
278 
279  OpFoldResult size = tileSize;
280  if (!isConstantIntValue(residualTileSize, 0)) {
281  OpFoldResult sizeMinusOffsetPerThread =
282  affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
283  {offset, loopRange.size});
285  rewriter, loc,
287  {sizeMinusOffsetPerThread, tileSize});
288  }
289 
290  // Consider the case where the original loop was `[0, 100)`.
291  // If number of threads are `7`, the tile size would be computed as
292  // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
293  // - `offset = 0 + 6 * 15 = 105`
294  // - `tileSize = min(15, 100 - 105) = -5`
295  // To avoid negative tile sizes, we need to do a further
296  // `nonNegativeTileSize = affine.max(0, tileSize)`.
297  // This `max` can be avoided if
298  // `offset + tileSize * (numThreads - 1) < (ub - lb)`
299  if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
300  AffineMap maxMap =
303  rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
304  }
305 
306  offsets.push_back(offset);
307  sizes.push_back(size);
308  }
309  return {offsets, sizes};
310  } else {
311  for (auto [tileSize, loopRange] :
312  llvm::zip_equal(tileSizes, iterationDomain)) {
313 
314  // Non-tiled cases, set the offset and size to the
315  // `loopRange.offset/size`.
316  if (isConstantIntValue(tileSize, 0)) {
317  offsets.push_back(loopRange.offset);
318  sizes.push_back(loopRange.size);
319  continue;
320  }
321 
322  Value iv = ivs[materializedLoopNum++];
323  OpFoldResult offset = getAsOpFoldResult(iv);
324  offsets.push_back(offset);
325  OpFoldResult size =
326  getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
327  sizes.push_back(size);
328  }
329  return {offsets, sizes};
330  }
331 }
332 
333 /// Function to return the bounds of the loops to be generated.
334 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
337  ArrayRef<OpFoldResult> tileSizes) {
338  SmallVector<OpFoldResult> lbs, ubs, steps;
339  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
340  // No loop if the tile size is 0.
341  if (isConstantIntValue(tileSize, 0))
342  continue;
343  lbs.push_back(loopRange.offset);
344  ubs.push_back(loopRange.size);
345  steps.push_back(tileSize);
346  }
347  return {lbs, ubs, steps};
348 }
349 
350 /// A function that allows returning additional yielded values during
351 /// `yieldTiledValuesAndReplace`.
352 /// - `ivs` induction variable for the loop.
353 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
354 /// - `tiledValues` the tiled values to return. Must be of same size as
355 /// `newbbArgs`, each element of this array is inserted into the corresponding
356 /// element in `newbbArgs`.
357 /// - `resultOffsets` is of the same size as `tiledValues` and represents
358 /// the offsets to use when inserting corresponding element from `tiledValues`
359 /// into the element from `newBbArgs`.
360 /// - `resultSizes` is of the same size as `tiledValues` and represents
361 /// the size of the corresponding element from `tiledValues` inserted into
362 /// the element from `newBbArgs`.
363 /// In case the method needs to return `failure()` the method is expected
364 /// to clean up any inserted operations.
365 using YieldTiledValuesFn = std::function<LogicalResult(
366  RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
367  SmallVector<Value> &tiledValues,
368  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
369  SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
370 
371 /// Clones the operation and updates the destination if the operation
372 /// implements the `DestinationStyleOpInterface`.
374  Operation *op,
375  ValueRange newDestArgs) {
376  Operation *clonedOp = rewriter.clone(*op);
377  if (newDestArgs.empty())
378  return clonedOp;
379  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
380  destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
381  return clonedOp;
382 }
383 
384 /// Generate the tile-loop nest using `scf.for` operation.
385 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
386 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
387 /// - `destinationTensors` are the init values to use for the outer most loop.
388 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
389 /// most
390 /// loop.
391 /// - `loops` is an in-out parameter into which the generated loops are
392 /// populated.
393 static LogicalResult generateLoopNestUsingForOp(
394  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
395  ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
396  YieldTiledValuesFn yieldTiledValuesFn,
398  assert(!loopRanges.empty() && "unexpected empty loop ranges");
399  assert(loopRanges.size() == tileSizes.size() &&
400  "expected as many tile sizes as loop ranges");
401  OpBuilder::InsertionGuard guard(rewriter);
402 
403  SmallVector<OpFoldResult> lbs, ubs, steps;
404  std::tie(lbs, ubs, steps) =
405  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
406  SmallVector<Value> lbVals =
407  getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
408  SmallVector<Value> ubVals =
409  getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
410  SmallVector<Value> stepVals =
411  getValueOrCreateConstantIndexOp(rewriter, loc, steps);
412 
413  SmallVector<Value> ivs;
414  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
415  auto loop =
416  rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
417  [](OpBuilder &bodyBuilder, Location bodyLoc,
418  Value iv, ValueRange /*iterArgs*/) {});
419  loops.push_back(loop);
420  ivs.push_back(loop.getInductionVar());
421  rewriter.setInsertionPointToEnd(loop.getBody());
422  destinationTensors = loop.getRegionIterArgs();
423  }
424 
425  SmallVector<Value> tiledResults;
426  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
427  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
428  tiledResults, resultOffsets, resultSizes))) {
429  return rewriter.notifyMatchFailure(
430  loc, "failed to generate inner tile loop body");
431  }
432  if (loops.empty())
433  return success();
434 
435  assert(tiledResults.size() == destinationTensors.size() &&
436  "Number of results of body should be equal to number of iter args");
437 
438  // 6. Yield all the results of the tiled operation.
439  SmallVector<Value> yieldedValues;
440  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
441  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
442  resultSizes)) {
443  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
444  rewriter.getIndexAttr(1));
445  auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
446  loc, tiledValue, destinationTensor, resultOffset, resultSize,
447  resultStride);
448  yieldedValues.push_back(insertSlice);
449  }
450  rewriter.create<scf::YieldOp>(loc, yieldedValues);
451 
452  // Add the scf.yield operations for all the outer loops.
453  for (auto [outerLoop, innerLoop] :
454  llvm::zip_equal(MutableArrayRef(loops).drop_back(),
455  MutableArrayRef(loops).drop_front())) {
456  rewriter.setInsertionPointToEnd(
457  cast<scf::ForOp>(outerLoop.getOperation()).getBody());
458  rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
459  }
460  return success();
461 }
462 
463 /// Generate the tile-loop nest using `scf.forall` operation.
464 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
465 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
466 /// - `destinationTensors` are the init values to use for the outer most loop.
467 /// - `mappingVector` is the mapping attributes to use for loop construction.
468 /// Can be empty.
469 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
470 /// most
471 /// loop.
472 /// - `loops` is an in-out parameter into which the generated loops are
473 /// populated.
474 static LogicalResult generateLoopNestUsingForallOp(
475  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
476  ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
477  ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
479  assert(!loopRanges.empty() && "unexpected empty loop ranges");
480  assert(loopRanges.size() == tileSizes.size() &&
481  "expected as many tile sizes as loop ranges");
482  OpBuilder::InsertionGuard guard(rewriter);
483  SmallVector<OpFoldResult> offsets(loopRanges.size()),
484  sizes(loopRanges.size());
485 
486  std::optional<ArrayAttr> mappingAttr;
487  if (!mappingVector.empty())
488  mappingAttr = rewriter.getArrayAttr(mappingVector);
489 
490  scf::ForallOp forallOp;
491  bool useNumThreads = !numThreads.empty();
492 
493  if (useNumThreads) {
494  // Prune the zero numthreads.
495  SmallVector<OpFoldResult> nonZeroNumThreads;
496  for (auto nt : numThreads) {
497  if (isConstantIntValue(nt, 0))
498  continue;
499  nonZeroNumThreads.push_back(nt);
500  }
501  forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
502  destinationTensors, mappingAttr);
503  } else {
504  SmallVector<OpFoldResult> lbs, ubs, steps;
505  std::tie(lbs, ubs, steps) =
506  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
507  forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
508  destinationTensors, mappingAttr);
509  }
510  loops.push_back(forallOp);
511 
512  rewriter.setInsertionPoint(forallOp.getTerminator());
513  destinationTensors = forallOp.getRegionOutArgs();
514 
515  SmallVector<Value> tiledResults;
516  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
517  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
518  destinationTensors, tiledResults, resultOffsets,
519  resultSizes)))
520  return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
521 
522  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
523  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
524  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
525  resultSizes)) {
526  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
527  rewriter.getIndexAttr(1));
528 
529  rewriter.create<tensor::ParallelInsertSliceOp>(
530  loc, tiledValue, destinationTensor, resultOffset, resultSize,
531  resultStride);
532  }
533  return success();
534 }
535 
536 /// Generate the tile-loop nest using the loop construct specifed in `options`.
537 /// - `options`: Tiling options specified.
538 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
539 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
540 /// - `destinationTensors` are the init values to use for the outer most loop.
541 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
542 /// most
543 /// loop.
544 /// - `loops` is an in-out parameter into which the generated loops are
545 /// populated.
546 static LogicalResult generateLoopNest(
547  RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
548  ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
549  ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
551  // If the tile sizes are all zero, no loops are generated. Just call the
552  // callback function to handle untiled case.
553  if (llvm::all_of(tileSizes, isZeroIndex)) {
554  SmallVector<Value> tiledResults;
555  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
556  return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
557  tiledResults, resultOffsets, resultSizes);
558  }
560  return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
561  destinationTensors, tiledBodyFn, loops);
562  }
565  rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
566  destinationTensors, tiledBodyFn, loops);
567  }
568  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
569 }
570 
571 /// Append the specified additional `newInitOperands` operands to the
572 /// loops existing `init` operands (or similar), and replace `loopOp` with
573 /// the new loop that has the additional init operands. The loop body of
574 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
575 /// is called to get the new tiled values returned, and the offset
576 /// and sizes at which the tiled value is inserted into the
577 /// new region iter_args that correspond to the newly added init operands.
578 template <typename LoopType>
579 FailureOr<LoopLikeOpInterface>
580 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
581  ValueRange newInitOperands,
582  YieldTiledValuesFn yieldTiledValuesFn) {
583  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
584 }
585 
586 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
587 template <>
588 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
589  scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
590  YieldTiledValuesFn yieldTiledValuesFn) {
591  OpBuilder::InsertionGuard g(rewriter);
592  Location loc = loopOp.getLoc();
593  rewriter.setInsertionPoint(loopOp);
594 
595  auto inits = llvm::to_vector(loopOp.getInitArgs());
596  inits.append(newInitOperands.begin(), newInitOperands.end());
597  auto newLoop = rewriter.create<scf::ForOp>(
598  loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
599  inits, [](OpBuilder &, Location, Value, ValueRange) {});
600 
601  // Move the loop body to the new op.
602  Block *loopBody = loopOp.getBody();
603  Block *newLoopBody = newLoop.getBody();
604  rewriter.mergeBlocks(
605  loopBody, newLoopBody,
606  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
607 
608  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
609  rewriter.setInsertionPoint(yieldOp);
610 
611  SmallVector<Value> tiledValues;
612  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
613  ValueRange newRegionIterArgs =
614  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
615  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
616  newRegionIterArgs, tiledValues, resultOffsets,
617  resultSizes))) {
618  rewriter.eraseOp(newLoop);
619  return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
620  }
621 
622  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
623  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
624  llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
625  resultSizes)) {
626  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
627  rewriter.getIndexAttr(1));
628  Value insert = rewriter.create<tensor::InsertSliceOp>(
629  yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
630  resultStride);
631  newYieldValues.push_back(insert);
632  }
633 
634  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
635  rewriter.replaceOp(loopOp,
636  newLoop->getResults().take_front(loopOp.getNumResults()));
637  return cast<LoopLikeOpInterface>(newLoop.getOperation());
638 }
639 
640 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
641 template <>
642 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
643  scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
644  YieldTiledValuesFn yieldTiledValuesFn) {
645  OpBuilder::InsertionGuard g(rewriter);
646  Location loc = loopOp.getLoc();
647  rewriter.setInsertionPoint(loopOp);
648  auto inits = llvm::to_vector(loopOp.getOutputs());
649  inits.append(newInitOperands.begin(), newInitOperands.end());
650  auto newLoop = rewriter.create<scf::ForallOp>(
651  loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
652  loopOp.getMixedStep(), inits, loopOp.getMapping(),
653  [](OpBuilder &, Location, ValueRange) {});
654 
655  // Move the region of the current block to the newly created op.
656  Block *loopBody = loopOp.getBody();
657  Block *newLoopBody = newLoop.getBody();
658  rewriter.mergeBlocks(
659  loopBody, newLoopBody,
660  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
661 
662  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
663  rewriter.setInsertionPoint(terminator);
664  SmallVector<Value> tiledValues;
665  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
666  ValueRange regionIterArgs =
667  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
668  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
669  regionIterArgs, tiledValues, resultOffsets,
670  resultSizes))) {
671  rewriter.eraseOp(newLoop);
672  return rewriter.notifyMatchFailure(loopOp,
673  "failed to get yielded tiled values");
674  }
675 
676  // Update the terminator.
677  rewriter.setInsertionPointToEnd(terminator.getBody());
678 
679  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
680  tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
681  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
682  rewriter.getIndexAttr(1));
683  rewriter.create<tensor::ParallelInsertSliceOp>(
684  terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
685  resultStride);
686  }
687 
688  rewriter.replaceOp(loopOp,
689  newLoop->getResults().take_front(loopOp.getNumResults()));
690  return cast<LoopLikeOpInterface>(newLoop.getOperation());
691 }
692 
693 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
694 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
695 /// supported loop type.
696 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
697  LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
698  ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
700  loopLikeOp.getOperation())
701  .Case<scf::ForOp, scf::ForallOp>(
702  [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
704  loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
705  })
706  .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
707  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
708  });
709 }
710 
711 /// Method to add new init values to a loop nest. Updates `loops` in-place with
712 /// new loops that use the `newInitValues`.
713 /// The outer-loops are updated to yield the new result values of the inner
714 /// loop. For the innermost loop, the call back `getNewYields` is invoked to get
715 /// the additional values to yield form the innermost loop.
716 static LogicalResult addInitOperandsToLoopNest(
718  ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
719  SmallVector<scf::ForOp> newLoops;
720  if (loops.empty())
721  return success();
722  OpBuilder::InsertionGuard g(rewriter);
723  rewriter.setInsertionPoint(loops.front());
724 
725  SmallVector<Value> ivs;
726  for (auto &loop : loops.drop_back()) {
727  rewriter.setInsertionPoint(loop);
728 
729  // if loops.size() > 1 we assume that scf.for is used for the loops.
730  auto forLoop = cast<scf::ForOp>(loop.getOperation());
731 
732  // Create a new loop with the new init values for this loop.
733  SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
734  newInits.append(newInitValues.begin(), newInitValues.end());
735  auto newLoop = rewriter.create<scf::ForOp>(
736  forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
737  forLoop.getStep(), newInits,
738  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
739 
740  // Merge the body of the new loop with the body of the old loops.
741  SmallVector<Value> sourceBlockArgs;
742  sourceBlockArgs.push_back(newLoop.getInductionVar());
743  auto newRegionIterArgs = newLoop.getRegionIterArgs();
744  sourceBlockArgs.append(
745  newRegionIterArgs.begin(),
746  std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
747  rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
748  rewriter.replaceOp(
749  forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
750  loop = newLoop;
751  ivs.push_back(newLoop.getInductionVar());
752  newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
753  }
754 
755  // Update the loop body of the innermost loop to get new yield values.
756  LoopLikeOpInterface innerMostLoop = loops.back();
757  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
758  yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
759  getNewTiledYieldsFn);
760 
761  if (failed(newInnerMostLoop))
762  return innerMostLoop.emitOpError("failed to return additional yields");
763  loops.back() = newInnerMostLoop.value();
764 
765  // Make all other loops except the innermost loops yield the values returned
766  // by the inner loop.
767  for (auto [outerLoop, innerLoop] :
768  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
769  // Again assume that all the outer loops are scf.for operations.
770  auto outerForLoop = cast<scf::ForOp>(outerLoop);
771  auto outerLoopYield =
772  cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
773  SmallVector<Value> newYields =
774  llvm::to_vector(outerLoopYield.getOperands());
775  ValueRange additionalYields =
776  innerLoop->getResults().take_back(newInitValues.size());
777  newYields.append(additionalYields.begin(), additionalYields.end());
778  rewriter.setInsertionPoint(outerLoopYield);
779  rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
780  }
781  return success();
782 }
783 
784 /// Implementation of tiling transformation of `op` that implements the
785 /// `TilingInterface` using `scf.for` to iterate over the tiles.
786 FailureOr<scf::SCFTilingResult>
787 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
789  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
790  return failure();
791  }
792 
793  OpBuilder::InsertionGuard guard(rewriter);
794  rewriter.setInsertionPointAfter(op);
795 
796  // 1. Get the range of the loops that are represented by the operation.
797  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
798 
799  // 2. Materialize the tile sizes and/or number of threads;
800  SmallVector<OpFoldResult> tileSizes, numThreads;
801  std::tie(tileSizes, numThreads) =
802  getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
803 
804  // Check if it is safe to tile. This is hold over from previous iterations
805  // of tile to for-all. Consider dropping it.
807  checkSafeToTileToForall(op, tileSizes, numThreads);
808  }
809 
810  // 3. If there is an interchange specified, permute the iteration domain and
811  // the tile sizes.
812  SmallVector<int64_t> interchangeVector;
813  if (!options.interchangeVector.empty()) {
814  interchangeVector = fillInterchangeVector(options.interchangeVector,
815  iterationDomain.size());
816  assert(isPermutationVector(interchangeVector) &&
817  "expected interchange vector to be a permutation");
818 
819  applyPermutationToVector(iterationDomain, interchangeVector);
820  applyPermutationToVector(tileSizes, interchangeVector);
821  if (!numThreads.empty())
822  applyPermutationToVector(numThreads, interchangeVector);
823  }
824 
825  FailureOr<TilingResult> tilingResult;
826  // 4. Define the lambda function used later to generate the body of the
827  // innermost tiled loop.
828  YieldTiledValuesFn innerYieldTiledValuesFn =
829  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
830  ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
833  -> LogicalResult {
834  // 4a. Compute the `offsets` and `sizes` to use for tiling.
835  SmallVector<OpFoldResult> offsets, sizes;
836  std::tie(offsets, sizes) = getTileOffsetAndSizes(
837  rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
838 
839  // 4b. If interchange was provided, apply inverse of the interchange
840  // to get back the offsets/sizes in the order to be specified.
841  if (!interchangeVector.empty()) {
842  auto inversePermutation = invertPermutationVector(interchangeVector);
845  }
846 
847  // 5. Generate the tiled implementation within the inner most loop.
848 
849  // 5a. Clone the operation within the loop body.
850  auto clonedOp = cast<TilingInterface>(
851  cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
852 
853  // 5b. Early return cloned op if tiling is not happening. We can not return
854  // the original op because it could lead to
855  // `rewriter.replaceOp(op, op->getResults())` and users would get crash.
856  if (llvm::all_of(tileSizes, isZeroIndex)) {
857  tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
858  tilingResult =
859  TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
860  /*generatedSlices=*/{}};
861  return success();
862  }
863 
864  // 5c. Tile the cloned operation.
865  tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
866  if (failed(tilingResult)) {
867  rewriter.eraseOp(clonedOp);
868  return op.emitOpError("faild to tile operation");
869  }
870 
871  // 5d. Delete the cloned operation.
872  rewriter.eraseOp(clonedOp);
873 
874  // 5e. Compute the offsets at which the result values are to be inserted
875  // back into its destinations.
876  for (auto [index, tiledValue] :
877  llvm::enumerate(tilingResult->tiledValues)) {
878  tiledResults.push_back(tiledValue);
879  SmallVector<OpFoldResult> resultOffset, resultSize;
880  if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
881  resultOffset, resultSize))) {
882  for (auto op : tilingResult->tiledOps) {
883  rewriter.eraseOp(op);
884  }
885  return rewriter.notifyMatchFailure(
886  op, "failed to get slice of result produced");
887  }
888  resultOffsets.emplace_back(std::move(resultOffset));
889  resultSizes.emplace_back(std::move(resultSize));
890  }
891 
892  return success();
893  };
894 
895  // 6. Find the destination tensors to use for the operation.
896  SmallVector<Value> destinationTensors;
897  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
898  destinationTensors))) {
899  return rewriter.notifyMatchFailure(op,
900  "unable to create destination tensors");
901  }
902 
903  // 7. Generate the tiled loops nest using the callback defined above.
905  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
906  tileSizes, numThreads, destinationTensors,
907  innerYieldTiledValuesFn, loops)))
908  return op.emitOpError("failed to generate tiling loops");
909  assert(succeeded(tilingResult) &&
910  "expected tiling result to be computed after loop generation");
911 
912  // If loops are empty, the tiled op is used as the replacement for the untiled
913  // op.
914  if (loops.empty()) {
915  return scf::SCFTilingResult{tilingResult->tiledOps, loops,
916  tilingResult->tiledValues,
917  tilingResult->generatedSlices};
918  }
919 
920  SmallVector<Value> replacements = llvm::map_to_vector(
921  loops.front()->getResults(), [](OpResult r) -> Value { return r; });
922  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
923  tilingResult->generatedSlices};
924 }
925 
926 FailureOr<scf::SCFReductionTilingResult>
// NOTE(review): this is a doxygen-extracted listing; original line 927 —
// presumably the declarator `mlir::scf::tileReductionUsingScf(RewriterBase
// &b,` — is absent from this capture. Verify against upstream.
928  PartialReductionOpInterface op,
929  ArrayRef<OpFoldResult> tileSizes) {
930  Location loc = op.getLoc();
931  // Ops implementing PartialReductionOpInterface are expected to implement
932  // TilingInterface.
933  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
// Pad the tile sizes with zeros so the vector covers the whole iteration
// domain; a zero tile size means "do not tile this dimension".
934  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
935  auto tileSizesVector = llvm::to_vector(tileSizes);
936  if (tileSizesVector.size() < iterationDomain.size()) {
937  auto zero = b.getIndexAttr(0);
938  tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
939  zero);
940  }
// NOTE(review): original line 941 (the left-hand side of this statement) is
// missing from this capture; as shown, the call result is dropped.
942  tilingInterfaceOp.getLoopIteratorTypes();
943 
// Collect the indices of all reduction dimensions of the op.
944  SmallVector<int> reductionDims;
945  for (auto [idx, iteratorType] :
946  llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
947  if (iteratorType == utils::IteratorType::reduction)
948  reductionDims.push_back(idx);
949  }
950 
951  // 2. create the initial tensor value.
952  FailureOr<SmallVector<Value>> maybeInitTensors =
953  op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
954  reductionDims);
955  if (failed(maybeInitTensors)) {
956  return b.notifyMatchFailure(op, "Failed to create initial tensors.");
957  }
958  SmallVector<Value> &initTensors = maybeInitTensors.value();
959 
960  // 3. Define the callback to use for generating the inner most tile loop body.
961  SmallVector<Operation *> parallelTiledOps;
962  auto innerYieldTiledValuesFn =
963  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
964  ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
// NOTE(review): original lines 965-966 are missing; they declared the
// `resultOffsets` / `resultSizes` out-parameters used further below.
967  -> LogicalResult {
// Compute per-dimension offsets/sizes: untiled dims (tile size 0) take the
// whole loop range; tiled dims consume the next induction variable.
968  SmallVector<OpFoldResult> offsets, sizes;
969  {
970  int materializedLoopNum = 0;
971  for (auto [tileSize, loopRange] :
972  llvm::zip_equal(tileSizesVector, iterationDomain)) {
973  if (isConstantIntValue(tileSize, 0)) {
974  offsets.push_back(loopRange.offset);
975  sizes.push_back(loopRange.size);
976  continue;
977  }
978  Value iv = ivs[materializedLoopNum++];
979  offsets.push_back(iv);
980  sizes.push_back(
981  getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
982  }
983  }
984 
985  // 4a. Clone the operation.
986  {
987  auto clonedOp = cast<PartialReductionOpInterface>(
988  cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
989 
990  // 4b. Tile the cloned operation.
991  FailureOr<TilingResult> partialTilingResult =
992  clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
993  sizes, reductionDims);
994  if (failed(partialTilingResult)) {
995  return failure();
996  }
997  std::swap(parallelTiledOps, partialTilingResult->tiledOps);
998  std::swap(tiledResult, partialTilingResult->tiledValues);
999 
1000  // 4c. Delete the cloned operation.
1001  b.eraseOp(clonedOp);
1002  }
1003 
1004  // 4d. Compute the offsets and sizes needed to insert the result of the
1005  // tiled value back into destination before yielding the destination.
// The partial result is inserted at offset 0 with the full size of the
// tiled value in every dimension.
1006  for (auto result : tiledResult) {
1007  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
1008  resultOffsets.emplace_back(std::move(outOffsets));
1009 
1010  SmallVector<OpFoldResult> outSizes;
1011  for (size_t i = 0; i < offsets.size(); i++) {
1012  outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
1013  }
1014  resultSizes.emplace_back(std::move(outSizes));
1015  }
1016  return success();
1017  };
1018 
1019  // 5. Generate the tiled implementation using the destination tensors.
// NOTE(review): original lines 1020-1022 are missing; they declared the
// `options` and `loops` used in the call below — verify against upstream.
1023  if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
1024  /*numThreads=*/ArrayRef<OpFoldResult>{},
1025  initTensors, innerYieldTiledValuesFn, loops)))
1026  return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
1027 
1028  SmallVector<Value> replacements = llvm::map_to_vector(
1029  loops.front()->getResults(), [](OpResult r) -> Value { return r; });
1030 
1031  // 5. Apply the merge reduction to combine all the partial values.
1032  b.setInsertionPointAfter(*loops.begin());
1033  FailureOr<MergeResult> mergeResult =
1034  op.mergeReductions(b, loc, replacements, reductionDims);
1035  if (failed(mergeResult)) {
1036  return failure();
1037  }
1038  b.replaceOp(op, mergeResult->replacements);
1039 
1040  SCFReductionTilingResult reductionTilingResult;
1041  std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
1042  std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
1043  std::swap(reductionTilingResult.initialValues, initTensors);
1044  std::swap(reductionTilingResult.loops, loops);
1045  std::swap(reductionTilingResult.replacements, mergeResult->replacements);
1046 
1047  return reductionTilingResult;
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // tileConsumerAndFuseProducersUsingSCF implementation.
1052 //===----------------------------------------------------------------------===//
1053 
1054 /// Return the untiled producer whose slice is used in a tiled consumer. The
1055 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1056 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
1057 /// indicates that this is a destination operand of the consumer. If there was
1058 /// no loop traversal needed, the second value of the returned tuple is empty.
1059 static std::tuple<OpResult, std::optional<OpOperand *>>
// NOTE(review): original lines 1060-1061 (the function name and parameter
// list — presumably an `OpOperand *source` and the loop-nest range `loops`)
// are missing from this capture; verify against upstream.
1062  std::optional<OpOperand *> destinationIterArg;
1063  auto loopIt = loops.rbegin();
// Walk from the innermost loop outwards: while the current operand is a
// block argument (iter_arg) of the current loop, follow it to the value tied
// to that iter_arg in the loop's init operands.
1064  while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
1065  auto loop = *loopIt;
1066  if (iterArg.getOwner()->getParentOp() != loop)
1067  break;
1068  source = loop.getTiedLoopInit(iterArg);
1069  loopIt++;
1070  }
// If the whole nest was traversed, the chain reached the outermost loop's
// init operand, i.e. the slice feeds a destination operand of the consumer.
1071  if (loopIt == loops.rend())
1072  destinationIterArg = source;
1073  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1074 }
1075 
1076 /// Implementation of fusing producer of a single slice by computing the
1077 /// slice of the producer in-place.
1078 std::optional<scf::SCFFuseProducerOfSliceResult>
// NOTE(review): original line 1079 — presumably the declarator
// `mlir::scf::tileAndFuseProducerOfSlice(` — is missing from this capture.
1080  RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
// NOTE(review): original line 1081 (the `loops` parameter of this function)
// is also missing; verify against upstream.
1082  // 1. Get the producer of the source (potentially walking through
1083  // `iter_args` of nested `scf.for`)
1084  auto [fusableProducer, destinationInitArg] =
1085  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1086  loops);
1087  if (!fusableProducer)
1088  return std::nullopt;
1089  unsigned resultNumber = fusableProducer.getResultNumber();
1090 
1091  OpBuilder::InsertionGuard g(rewriter);
1092  rewriter.setInsertionPoint(candidateSliceOp);
1093 
1094  // 2. Clone the fused producer
1095  // 2a. Compute the destination operands to use for the cloned operation.
1096  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1097  Operation *fusableProducerOp = fusableProducer.getOwner();
1098  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
// NOTE(review): original line 1099 is missing — presumably the head of a
// `failed(tensor::getOrCreateDestinations(` call; verify against upstream.
1100  rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1101  origDestinationTensors)))
1102  return std::nullopt;
1103 
1104  clonedOpDestinationTensors = origDestinationTensors;
1105  if (destinationInitArg &&
1106  isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1107  // 2b. If the producer is also destination style, then to maintain the
1108  // destination passing style, update the destination of the producer to be
1109  // the source of the slice.
1110  clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1111  }
1112  // 2c. Clone the fused producer.
1113  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1114  rewriter, fusableProducerOp, clonedOpDestinationTensors);
1115  // 2d. Update the source of the candidateSlice to be the cloned producer.
1116  // Easier to just clone the slice with different source since replacements
1117  // and DCE of cloned ops becomes easier
1118  SmallVector<Value> candidateSliceOpOperands =
1119  llvm::to_vector(candidateSliceOp->getOperands());
1120  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1121  tensor::ExtractSliceOp clonedCandidateSliceOp =
1122  mlir::clone(rewriter, candidateSliceOp,
1123  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1124 
1125  // 3. Generate the tiled implementation of the producer of the source
1126  FailureOr<TilingResult> tileAndFuseResult =
// NOTE(review): original line 1127 (the callee of this call) is missing —
// presumably `tensor::replaceExtractSliceWithTiledProducer(`; verify.
1128  rewriter, clonedCandidateSliceOp,
1129  clonedProducerOp->getResult(resultNumber));
1130  if (failed(tileAndFuseResult))
1131  return std::nullopt;
1132  // Note: Do not delete the candidateSliceOp, since it's passed in from the
1133  // caller.
1134  rewriter.replaceAllUsesWith(candidateSliceOp,
1135  tileAndFuseResult->tiledValues[0]);
1136  rewriter.eraseOp(clonedCandidateSliceOp);
1137  rewriter.eraseOp(clonedProducerOp);
1138 
1139  // 3. If the slice is for a destination operand, for example,
1140  //
1141  // ```mlir
1142  // %0 = linalg.init
1143  // %1 = linalg.fill .. outs(%0 : )
1144  // %2 = scf.for .. iter_args(%arg0 = %1) {
1145  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1146  // %4 = tensor.extract_slice %arg1 [..]
1147  // .. = linalg.matmul .. outs(%4 : )
1148  // }
1149  // }
1150  // ```
1151  //
1152  // the IR is currently
1153  //
1154  // ```
1155  // %0 = linalg.init
1156  // %1 = linalg.fill
1157  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1158  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1159  // %4 = tensor.extract_slice %arg1[..]
1160  // %5 = linalg.fill .. outs(%4 : )
1161  // .. = linalg.matmul .. outs(%5 : )
1162  // }
1163  // }
1164  // ```
1165  //
1166  // The untiled `linalg.fill` is still used as the `init_value` since it
1167  // was originally a destination operand of the untiled `linalg.matmul`.
1168  // When fusing an operand that is a destination operand, the iter_arg of
1169  // the outer most loop should be changed to use the destination of the
1170  // fused operation. With this the IR will be.
1171  //
1172  // ```
1173  // %0 = linalg.init
1174  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1175  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
1176  // %3 = tensor.extract_slice %arg1[..]
1177  // %4 = linalg.fill .. outs(%3 : )
1178  // .. = linalg.matmul .. outs(%4 : )
1179  // }
1180  // }
1181  // ```
1182  if (destinationInitArg &&
1183  isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1184  loops.front()
1185  ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1186  .set(origDestinationTensors[resultNumber]);
1187  }
// NOTE(review): original line 1188 is missing — presumably
// `return scf::SCFFuseProducerOfSliceResult{`; the brace-init arguments
// follow on the next two lines.
1189  fusableProducer, tileAndFuseResult->tiledValues[0],
1190  tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1191 }
1192 
1193 /// Reconstruct the fused producer from within the tiled-and-fused code.
1194 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1195  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1196  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
// NOTE(review): original line 1197 (the `loops` parameter of this function)
// is missing from this capture; verify against upstream.
1198  ArrayRef<unsigned> yieldResultNumber) {
1199  if (loops.empty())
1200  return success();
1201 
1202  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1203  *tiledOwner = fusedProducerInfo.tiledOps[0];
1204 
1205  Location loc = originalOwner->getLoc();
1206  // a. collect all init Value to be appended
// An empty `yieldResultNumber` means "yield replacements for all results".
1207  SmallVector<unsigned> initNumberList =
1208  yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1209  0, originalOwner->getNumResults()))
1210  : llvm::to_vector(yieldResultNumber);
1211  SmallVector<Value> initValueList;
1212  for (const auto &resultNumber : initNumberList) {
1213  FailureOr<Value> initValue = tensor::getOrCreateDestination(
1214  rewriter, loc, originalOwner->getResult(resultNumber));
1215  if (succeeded(initValue)) {
1216  initValueList.push_back(initValue.value());
1217  } else {
1218  return failure();
1219  }
1220  }
1221 
1222  SmallVector<Operation *> generatedSlices;
1223  YieldTiledValuesFn newYieldValuesFn =
1224  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1225  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
// NOTE(review): original line 1226 is missing; it declared the
// `tiledOffset` out-parameter used further below.
1227  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1228  OpBuilder::InsertionGuard g(innerRewriter);
1229 
1230  // get sliceOp tile information
1231  SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1232  sliceSizes = sliceOp.getMixedSizes();
1233 
1234  // expect all strides of sliceOp being 1
1235  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1236  return !isConstantIntValue(ofr, 1);
1237  }))
1238  return failure();
1239 
1240  unsigned sliceResultNumber =
1241  fusedProducerInfo.origProducer.getResultNumber();
1242 
1243  auto tilableOp = cast<TilingInterface>(originalOwner);
1244  // b. get iterDomain Offset and Sizes based on sliceOp tile
1245  SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1246  // skip tensor.pack/unpack/pad, which expects single opResult
1247  if (tilableOp->getNumResults() > 1 &&
1248  failed(tilableOp.getIterationDomainTileFromResultTile(
1249  rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1250  iterDomainOffset, iterDomainSizes))) {
1251  // In theory, it is unnecessary to raise an error here. Actually although
1252  // it fails to reconstruct the result tensor, it should not break current
1253  // fusion anyway. The reason why we must return failure currently is that
1254  // the callback function `newYieldValuesFn` will be called after new init
1255  // operand(s) have already been appended. It will take more refactoring to
1256  // make sure the init operands are added consistently in the future. For
1257  // more details, please refer to:
1258  // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1259  return failure();
1260  }
1261 
1262  // c. calculate offsets and sizes info of all OpResults respectively based
1263  // on iteration Domain Tile
1264  SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1265  for (const auto &resultNumber : initNumberList) {
1266  if (resultNumber == sliceResultNumber) {
1267  offsetList.push_back(sliceOffset);
1268  sizesList.push_back(sliceSizes);
1269  } else {
1270  assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1271  // infer result tile according to the iteration domain tile
1272  SmallVector<OpFoldResult> offset, sizes;
1273  if (failed(tilableOp.getResultTilePosition(
1274  rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1275  offset, sizes))) {
1276  return failure();
1277  }
1278  offsetList.push_back(offset);
1279  sizesList.push_back(sizes);
1280  }
1281  }
1282 
1283  // d. create `extract_slice` for `iter_args` for DPS operation if necessary
1284  if (auto tiledDestStyleOp =
1285  dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1286  rewriter.setInsertionPoint(tiledDestStyleOp);
1287  for (const auto &&[index, newRegionArg] :
1288  llvm::enumerate(newRegionIterArgs)) {
// All strides are 1 (asserted above), so the slice uses unit strides.
1289  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1290  loc, newRegionArg, offsetList[index], sizesList[index],
1291  SmallVector<OpFoldResult>(offsetList[index].size(),
1292  rewriter.getIndexAttr(1)));
1293  generatedSlices.push_back(destSlice);
1294  unsigned resultNumber = initNumberList[index];
1295  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1296  tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1297  });
1298  }
1299  }
1300 
1301  // e. prepare tiled offset and sizes for later `insert_slice` creation by
1302  // caller
1303  Block *block = rewriter.getInsertionPoint()->getBlock();
1304  rewriter.setInsertionPoint(block->getTerminator());
1305  for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1306  tiledResult.push_back(tiledOwner->getResult(resultNumber));
1307  tiledOffset.emplace_back(offsetList[index]);
1308  tiledSizes.emplace_back(sizesList[index]);
1309  }
1310  return success();
1311  };
1312 
1313  if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1314  newYieldValuesFn))) {
1315  return failure();
1316  }
1317  return generatedSlices;
1318 }
1319 
1320 namespace {
1321 
1322 //===----------------------------------------------------------------------===//
1323 // SliceTrackingListener
1324 //===----------------------------------------------------------------------===//
1325 
1326 /// This class is a listener for tracking the insertion and removal of
1327 /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1328 /// fusion algorithm to apply cleanup patterns in between fusion steps.
1329 class SliceTrackingListener : public RewriterBase::Listener {
1330 public:
1331  explicit SliceTrackingListener(
1332  std::optional<FrozenRewritePatternSet> patterns);
1333  SliceTrackingListener() = default;
1334 
1335  /// Adds the given list of operations to the worklist, and if present, applies
1336  /// the list of `patterns` to the newly added operations. This only processes
1337  /// the given operations and any newly inserted ones by the pattern set.
1338  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1339 
1340  /// Add to the new operation worklist if it is an extract_slice.
1341  void notifyOperationInserted(Operation *op,
1342  OpBuilder::InsertPoint previous) override;
1343 
1344  /// Shared helper for operation removal from the worklist.
1345  void removeOp(Operation *op);
1346 
1347  /// Remove the operation from the worklist.
1348  void notifyOperationErased(Operation *op) override;
1349 
1350  /// Remove the operation from the worklist.
1351  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1352 
1353  /// The worklist for this transformation keeps track of the slices to visit
1354  /// next for fusion.
1355  std::deque<tensor::ExtractSliceOp> worklist;
1356 
1357 private:
1358  /// Optional pattern set to apply when adding new operations to the worklist.
1359  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1360 };
1361 
1362 SliceTrackingListener::SliceTrackingListener(
1363  std::optional<FrozenRewritePatternSet> p) {
1364  patterns = std::move(p);
1365 }
1366 
1367 LogicalResult
1368 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1369  for (Operation *op : ops) {
1370  if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1371  worklist.push_back(slice);
1372  }
1373 
1374  if (!patterns)
1375  return success();
1376 
1377  GreedyRewriteConfig config;
1378  config.listener = this;
1380  return applyOpPatternsAndFold(ops, patterns.value(), config);
1381 }
1382 
1383 void SliceTrackingListener::notifyOperationInserted(
1384  Operation *op, OpBuilder::InsertPoint previous) {
1385  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1386  if (!slice)
1387  return;
1388  worklist.push_back(slice);
1389 }
1390 
1391 // Scan the worklist for the given op and remove it if present. The expectation
1392 // is for the worklist to be small and for removal to be relatively rare.
1393 void SliceTrackingListener::removeOp(Operation *op) {
1394  if (!isa<tensor::ExtractSliceOp>(op))
1395  return;
1396  auto iter = worklist.begin();
1397  while (iter != worklist.end()) {
1398  if (*iter == op)
1399  break;
1400  iter++;
1401  }
1402  if (iter == worklist.end())
1403  return;
1404 
1405  worklist.erase(iter);
1406 }
1407 
1408 void SliceTrackingListener::notifyOperationErased(Operation *op) {
1409  removeOp(op);
1410 }
1411 
1412 void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1413  ValueRange replacement) {
1414  removeOp(op);
1415 }
1416 } // namespace
1417 
1418 /// Implementation of tile consumer and fuse producer greedily.
1419 FailureOr<scf::SCFTileAndFuseResult>
// NOTE(review): original line 1420 — presumably the declarator
// `mlir::scf::tileConsumerAndFuseProducersUsingSCF(` — is missing here.
1421  RewriterBase &rewriter, TilingInterface consumer,
// NOTE(review): original line 1422 (the `options` parameter of this
// function) is also missing from this capture; verify against upstream.
1423  // This transformation is only valid for ops that return values (i.e. not
1424  // valid to use with operations that have memref operands).
1425  if (!consumer->getNumResults()) {
1426  return rewriter.notifyMatchFailure(
1427  consumer, "invalid pattern for op with no results");
1428  }
1429 
1430  // 1. First tile the consumer.
1431  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1432  llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1433 
1434  FailureOr<scf::SCFTilingResult> tilingResult =
1435  tileUsingSCF(rewriter, consumer, options.tilingOptions);
1436 
1437  if (failed(tilingResult))
1438  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1439  for (auto *tiledOp : tilingResult->tiledOps)
1440  tiledAndFusedOps.insert(tiledOp);
1441 
1442  // If there are no loops generated, fusion is immaterial.
1443  auto &loops = tilingResult->loops;
1444  if (loops.empty()) {
1445  DenseMap<Value, Value> replacements;
1446  for (auto [origVal, replacement] :
1447  llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1448  replacements[origVal] = replacement;
1449  }
1450  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1451  replacements};
1452  }
1453 
1454  // To keep track of replacements for now just record the map from the original
1455  // untiled value to the result number of the for loop. Since the loop gets
1456  // potentially replaced during fusion, keeping the value directly won't work.
1457  DenseMap<Value, size_t> origValToResultNumber;
1458  for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1459  origValToResultNumber[result] = index;
1460  }
1461 
1462  // 2. Typically, the operands of the tiled operation are slices of the
1463  // operands of the untiled operation. These are expressed in IR using
1464  // `tensor.extract_slice` operations with source being the operands of the
1465  // untiled operation. Create a worklist of these `tensor.extract_slice`
1466  // operations. If the producers of the source of the `tensor.extract_slice`
1467  // can be tiled such that the tiled value is generated in-place, that
1468  // effectively tiles + fuses the operations.
1469  struct WorklistItem {
1470  tensor::ExtractSliceOp candidateSlice;
// NOTE(review): original line 1471 is missing; it declared the
// `controlFnResult` member read near line 1511 below.
1472  };
1473 
1474  SliceTrackingListener sliceTracker =
1475  SliceTrackingListener(options.cleanupPatterns);
1476 
1477  if (failed(
1478  sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1479  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1480  }
1481  OpBuilder::InsertionGuard g(rewriter);
// Greedy fixed-point loop: pop a candidate slice, try to fuse its producer,
// and push any newly generated slices back onto the worklist.
1482  while (!sliceTracker.worklist.empty()) {
1483  auto candidateSlice = sliceTracker.worklist.front();
1484  sliceTracker.worklist.pop_front();
1485 
1486  auto [fusableProducer, destinationInitArg] =
1487  getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1488  loops);
1489  if (!fusableProducer)
1490  continue;
1491 
1492  std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1493  options.fusionControlFn(candidateSlice, fusableProducer,
1494  destinationInitArg.has_value());
1495  if (!controlFnResult)
1496  continue;
1497 
1498  WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1499 
1500  // The operands of the fused producer might themselves be slices of
1501  // values produced by operations that implement the `TilingInterface`.
1502  // Add these operations to the worklist.
1503  std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1504  tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1505  loops);
1506  if (!fusedResult)
1507  continue;
1508 
1509  SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1510 
1511  if (worklistItem.controlFnResult.yieldProducerReplacement) {
1512  // Reconstruct and yield all opResult of fusableProducerOp by default. The
1513  // caller can specify which one to yield by designating optional argument
1514  // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
1515  Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1516  FailureOr<SmallVector<Operation *>> newSlices =
// NOTE(review): original line 1517 (the callee of this call — presumably
// `yieldReplacementForFusedProducer(rewriter,`) is missing from this capture.
1518  worklistItem.candidateSlice,
1519  fusedResult.value(), loops);
1520  if (failed(newSlices)) {
// NOTE(review): the diagnostic below reads "failed to replacement value" —
// a wording bug in the runtime string; fixing it requires a code change.
1521  return rewriter.notifyMatchFailure(
1522  fusableProducerOp, "failed to replacement value for this "
1523  "operation from within the tiled loop");
1524  }
1525  worklistCandidates.append(newSlices.value());
// The producer's replacements were appended as the last results of the
// outermost loop; record their result numbers for the final replacement map.
1526  for (auto [index, result] :
1527  llvm::enumerate(fusableProducerOp->getResults())) {
1528  origValToResultNumber[result] = loops.front()->getNumResults() -
1529  fusableProducerOp->getNumResults() +
1530  index;
1531  }
1532  }
1533  if (Operation *tiledAndFusedOp =
1534  fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1535  fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1536  tiledAndFusedOps.insert(tiledAndFusedOp);
1537  }
1538 
1539  if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1540  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1541  }
1542  }
1543 
1544  DenseMap<Value, Value> replacements;
1545  for (auto [origVal, resultNumber] : origValToResultNumber) {
1546  replacements[origVal] = loops.front()->getResult(resultNumber);
1547  }
1548 
1549  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1550  replacements};
1551 }
1552 
1553 //===----------------------------------------------------------------------===//
1554 // tileAndFuseConsumerUsingSCF implementation.
1555 //===----------------------------------------------------------------------===//
1556 
1557 /// A utility function that checks whether the only use of the result of a
1558 /// tensor.insert_slice op is in a scf.yield op.
1559 static LogicalResult
1560 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1561  Value result = candidateSliceOp.getResult();
1562  Value::use_range uses = result.getUses();
1563  if (!llvm::hasSingleElement(uses)) {
1564  LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1565  return failure();
1566  }
1567  OpOperand &operandUse = (*uses.begin());
1568  Operation *userOp = operandUse.getOwner();
1569  if (!isa<scf::YieldOp>(userOp)) {
1570  LLVM_DEBUG(llvm::dbgs()
1571  << "Expected scf.yield to be the only user, but got -> "
1572  << (*userOp));
1573  return failure();
1574  }
1575  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1576  LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1577  "be in the same block\n");
1578  return failure();
1579  }
1580  return success();
1581 }
1582 
1583 /// Fetches the OpOperand of the only user (and use) of the value `val` which
1584 /// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
1585 /// failure otherwise.
1586 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
1587  Block *containingOpBlock) {
1588  // Check that the value has exactly one use which isn't a scf.yield or a
1589  // tensor.parallel_insert_slice op.
1590  OpOperand *operand = nullptr;
1591  for (OpOperand &opOperand : val.getUses()) {
1592  Operation *consumerOp = opOperand.getOwner();
1593  if (isa<scf::YieldOp, tensor::ParallelInsertSliceOp>(consumerOp))
1594  continue;
1595  if (operand)
1596  return failure();
1597  // TODO: We have to init result of consumer before scf.for, use
1598  // DestinationStyleOpInterface to get result shape from init for now.
1599  // Add support for other op such as op has InferTypeOpInterface.
1600  if (!isa<TilingInterface>(consumerOp) ||
1601  !isa<DestinationStyleOpInterface>(consumerOp))
1602  return failure();
1603  if (containingOpBlock != consumerOp->getBlock())
1604  return failure();
1605  operand = &opOperand;
1606  }
1607 
1608  if (operand)
1609  return operand;
1610  return failure();
1611 }
1612 
1613 /// Find the perfectly nested loops outside of given loop(included) sorted from
1614 /// outer to inner.
1615 ///
1616 /// E.g.
1617 ///
1618 /// ```
1619 /// %0 = scf.for()
1620 /// %1 = scf.for()
1621 /// %2 = scf.for()
1622 /// %3 = ...
1623 /// yield %3
1624 /// yield %2
1625 /// yield %1
1626 /// ```
1627 ///
1628 /// This function will return three perfectly nested loops: %0 + %1 + %2, when
1629 /// target inner loop is %2.
1632  SmallVector<scf::ForOp> nestLoops = {loop};
1633  auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1634 
1635  // Check if it is the ForOp that yield the result of inner loop.
1636  auto isForOpYieldResultOfInnerLoop =
1637  [](scf::ForOp outerLoop) -> LogicalResult {
1638  Block *body = outerLoop.getBody();
1639  if (!llvm::hasSingleElement(body->without_terminator()))
1640  return failure();
1641  auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1642  auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1643  if (!innerForOp)
1644  return failure();
1645  // All of innerForOp results should be yielded.
1646  return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1647  };
1648 
1649  while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1650  nestLoops.push_back(outerLoop);
1651  outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1652  }
1653  // sorted from outer to inner
1654  return {nestLoops.rbegin(), nestLoops.rend()};
1655 }
1656 
1657 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
1658 /// tensor.insert_slice. This function makes the following assumptions :
1659 /// 1. tensor.insert_slice has scf.yield as its only user.
1660 /// 2. scf.for's corresponding result has only one use.
1661 static FailureOr<OpOperand *>
1662 getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
1663  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1664  return failure();
1665  Value sliceResult = candidateSliceOp.getResult();
1666  // Step 1. Fetch the corresponding output.
1667  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1668  unsigned resultNumber = yieldOpOperand.getOperandNumber();
1669  // Step 2. Check containing op is scf.for.
1670  Operation *containingOp = candidateSliceOp->getParentOp();
1671  auto forOp = dyn_cast<scf::ForOp>(containingOp);
1672  if (!forOp)
1673  return failure();
1674  scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
1675  Value resultingValue = topLevelForOp->getResult(resultNumber);
1676 
1677  return getConsumerFromUses(resultingValue, topLevelForOp->getBlock());
1678 }
1679 
1680 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1681 /// by a tensor.parallel_insert_slice.
1682 static FailureOr<OpOperand *>
1683 getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
1684  // Step 1. Fetch the corresponding output
1685  Value sliceDest = candidateSliceOp.getDest();
1686  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1687  if (!iterArg)
1688  return failure();
1689  Operation *containingOp = iterArg.getOwner()->getParentOp();
1690  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1691  return failure();
1692  // Step 2. Check that the containing op is scf.forall.
1693  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1694  if (!forallOp)
1695  return failure();
1696  Value resultingValue =
1697  forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
1698 
1699  return getConsumerFromUses(resultingValue, containingOp->getBlock());
1700 }
1701 
1702 /// This utility currently checks whether the loop either :-
1703 /// 1. Yields exactly one result.
1704 /// 2. Has consumer op as its first user and other users to be in the same
1705 /// containing block as that of consumer op's. Currently we clone the loop op
1706 /// right before the consumer op in order to maintain a valid def-use chain.
1707 /// This utility thus helps ensuring that no invalid IR is formed due to the
1708 /// same.
1709 static LogicalResult checkAssumptionForLoop(Operation *loopOp,
1710  Operation *consumerOp) {
1711  // Check if the loop op yields one result.
1712  if (loopOp->getNumResults() == 1)
1713  return success();
1714  // Check if the consumerOp is the first user of the loopOp and if other users
1715  // are in the same containing block as that of consumer op's.
1716  Block *parentBlock = consumerOp->getBlock();
1717  for (Operation *userOp : loopOp->getUsers()) {
1718  if (userOp == consumerOp)
1719  continue;
1720  if (parentBlock != userOp->getBlock() ||
1721  !consumerOp->isBeforeInBlock(userOp))
1722  return failure();
1723  }
1724  return success();
1725 }
1726 
1727 /// A utility to fetch an untiled consumer of
1728 /// tensor.insert_slice/tensor.parallel_insert_slice.
1729 static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
1730  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1731  return getUntiledConsumerFromSlice(insertSlice);
1732  } else if (auto parallelInsertSlice =
1733  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1734  return getUntiledConsumerFromSlice(parallelInsertSlice);
1735  } else {
1736  return failure();
1737  }
1738 }
1739 
1740 /// Implementation of fusing consumer of a single slice by computing the
1741 /// slice of the consumer in-place for scf loop.
1742 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1744  Operation *candidateSliceOp) {
1745  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1746  candidateSliceOp))
1747  return failure();
1748 
1749  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1750 
1751  // 1. Get the consumer of scf.for for the result yielded by
1752  // tensor.insert_slice/parallel_insert_slice.
1753  FailureOr<OpOperand *> maybeConsumerOpOperand =
1754  getUntiledConsumerFromSlice(candidateSliceOp);
1755  if (failed(maybeConsumerOpOperand)) {
1756  return rewriter.notifyMatchFailure(candidateSliceOp,
1757  "could not fetch consumer to fuse");
1758  }
1759  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1760  Operation *consumerOp = consumerOpOperand->getOwner();
1761  unsigned operandNumber = consumerOpOperand->getOperandNumber();
1762  unsigned resultNumber = 0;
1763  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1764  resultNumber = producerResult.getResultNumber();
1765  } else {
1766  return rewriter.notifyMatchFailure(
1767  consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1768  }
1769 
1770  // There are two possible cases regarding `oldLoopOp` here:
1771  // 1. single `scf.forall` or `scf.for`.
1772  // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1773  // top-level loop is the outer-most one of these nested loops.
1774  LoopLikeOpInterface innerMostLoop =
1775  candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1777  if (isInsertSliceOp) {
1778  nestedLoops = llvm::map_to_vector(
1780  cast<scf::ForOp>(innerMostLoop.getOperation())),
1781  [](scf::ForOp forOp) {
1782  return cast<LoopLikeOpInterface>(forOp.getOperation());
1783  });
1784  } else {
1785  nestedLoops = {innerMostLoop};
1786  }
1787 
1788  LoopLikeOpInterface outerMostLoop = nestedLoops.front();
1789 
1790  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp))) {
1791  return rewriter.notifyMatchFailure(
1792  outerMostLoop,
1793  "containing loop op should either yield just one value or "
1794  "have the consumer op as its first user");
1795  }
1796 
1797  OpBuilder::InsertionGuard g(rewriter);
1798 
1799  // 2. Check consumer is not using scf loop's output as init.
1800  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1801  if (!dstOp)
1802  return rewriter.notifyMatchFailure(consumerOp,
1803  "consumer op is not DPS operation");
1804  SmallVector<Value> dpsInits =
1805  llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1806  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
1807  return rewriter.notifyMatchFailure(
1808  consumerOp,
1809  "consumer op taking the result of scf.for as init is not supported");
1810  }
1811  SmallVector<Value> newInits = dpsInits;
1812 
1813  Location loc = outerMostLoop->getLoc();
1814 
1815  // 3. Move the whole loop structure right before consumer Op, the dominance
1816  // should be already ensured by `checkAssumptionForLoop`.
1817  rewriter.moveOpBefore(outerMostLoop, consumerOp);
1818 
1819  // 4. Set insertion point before terminator op of the loop and create a new
1820  // tensor.insert_slice. In the scf.for case this is a clone of the
1821  // candidateSliceOp whereas in the scf.forall case this is created from the
1822  // operands of tensor.parallel_insert_slice.
1823  tensor::InsertSliceOp clonedInsertSliceOp;
1824  if (auto sliceOp =
1825  dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1826  auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
1827  rewriter.setInsertionPoint(newForallOp.getTerminator());
1828  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
1829  loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1830  sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1831  } else {
1832  rewriter.setInsertionPoint(candidateSliceOp);
1833  clonedInsertSliceOp =
1834  cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
1835  }
1836 
1837  // 5.a. Clone consumer op.
1838  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
1839 
1840  // 5.b. Replace all uses of the loop result with the result of the cloned
1841  // tensor.insert_slice.
1842  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
1843  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
1844  operandToReplace.set(clonedInsertSliceOp.getResult());
1845  });
1846 
1847  // 6. Perform tiling of the cloned consumer and replace the operand at
1848  // `operandNumber` with the source of the cloned tensor.insert_slice op.
1849  auto ossSliceOp =
1850  cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
1851  FailureOr<TilingResult> tileAndFuseResult =
1853  rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
1854  if (failed(tileAndFuseResult)) {
1855  return failure();
1856  }
1857  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
1858  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
1859  clonedInsertSliceOp.getSource());
1860 
1861  // 7. Reconstruct [nested] loop with new inits.
1862  YieldTiledValuesFn newYieldValuesFn =
1863  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1864  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1866  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1867  OpBuilder::InsertionGuard g(innerRewriter);
1868  // 8. Set inner insertPoint right before tiled consumer op.
1869  innerRewriter.setInsertionPoint(tiledConsumerOp);
1870 
1871  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
1872  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
1873  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
1874 
1875  // 9. Check all insert stride is 1.
1876  if (llvm::any_of(strides, [](OpFoldResult stride) {
1877  return !isConstantIntValue(stride, 1);
1878  })) {
1879  return rewriter.notifyMatchFailure(
1880  candidateSliceOp, "containingOp's result yield with stride");
1881  }
1882 
1883  // 10. Try to get iter domain position from input position.
1884  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1885  if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
1886  rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1887  iterDomainSizes))) {
1888  return rewriter.notifyMatchFailure(
1889  tiledConsumerOp,
1890  "can't get iter domain position from input position");
1891  }
1892 
1893  // 11. Try to fetch the offset and size for all results of the cloned
1894  // consumer. This would then be used to form the corresponding
1895  // tensor.insert_slice/parallel_insert_slice later.
1896  unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
1898  totalNumResultsOfConsumer);
1900  totalNumResultsOfConsumer);
1901  for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
1902  if (failed(tiledConsumerOp.getResultTilePosition(
1903  rewriter, idx, iterDomainOffsets, iterDomainSizes,
1904  resultOffsets[idx], resultSizes[idx]))) {
1905  return rewriter.notifyMatchFailure(
1906  tiledConsumerOp,
1907  "can't get result domain position from iter domain position");
1908  }
1909  }
1910 
1911  // 12. Create `extract_slice` for `iter_args` for DPS operation if
1912  // necessary.
1913  if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
1914  tiledConsumerOp.getOperation())) {
1915  rewriter.setInsertionPoint(tiledDestStyleOp);
1916  for (const auto &&[index, newRegionArg] :
1917  llvm::enumerate(newRegionIterArgs)) {
1918  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1919  loc, newRegionArg, resultOffsets[index], resultSizes[index],
1920  SmallVector<OpFoldResult>(resultOffsets[index].size(),
1921  rewriter.getIndexAttr(1)));
1922  // Make a copy of index to avoid a capturing structured binding, which
1923  // is a C++20 extension.
1924  auto dstNumber = index;
1925  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1926  tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
1927  });
1928  }
1929  }
1930 
1931  // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
1932  // caller.
1933  Block *block = rewriter.getInsertionPoint()->getBlock();
1934  rewriter.setInsertionPoint(block->getTerminator());
1935  for (const auto &&[index, result] :
1936  llvm::enumerate(tiledConsumerOp->getResults())) {
1937  tiledResult.push_back(result);
1938  tiledOffset.emplace_back(resultOffsets[index]);
1939  tiledSizes.emplace_back(resultSizes[index]);
1940  }
1941  return success();
1942  };
1943  // 14. Add new inits to [nested] loops.
1944  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
1945  newYieldValuesFn))) {
1946  return rewriter.notifyMatchFailure(tiledConsumerOp,
1947  "unable to add new inits to nest loop");
1948  }
1949 
1950  // 15. Replace the result of scf loop and consumer op with new loop's results.
1951 
1952  for (auto &&[oldResult, newResult] : llvm::zip(
1953  consumerOp->getResults(),
1954  nestedLoops.front()->getResults().take_back(newInits.size()))) {
1955  rewriter.replaceAllUsesWith(oldResult, newResult);
1956  }
1957 
1958  // 16. Need to erase the old scf loop and the cloned consumer op.
1959  rewriter.eraseOp(clonedConsumerOp);
1960 
1962  consumerOpOperand,
1963  &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
1964  tileAndFuseResult->tiledOps};
1965 }
1966 
1967 //===----------------------------------------------------------------------===//
1968 // lowerToLoopsUsingSCFForOp implementation.
1969 //===----------------------------------------------------------------------===//
1970 
1971 FailureOr<SmallVector<scf::ForOp>>
1973  TilingInterface op) {
1974  // TODO: Handle cases where the op has results if needed.
1975  if (op->getNumResults() > 0) {
1976  return rewriter.notifyMatchFailure(
1977  op, "unable to lower to loops operations with return values");
1978  }
1979 
1980  SmallVector<Range> domain = op.getIterationDomain(rewriter);
1981  SmallVector<Value> ivs;
1983  Location loc = op.getLoc();
1984  for (auto loopRange : domain) {
1985  Value offsetVal =
1986  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
1987  Value sizeVal =
1988  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
1989  Value strideVal =
1990  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
1991  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
1992  strideVal, ValueRange{});
1993  loops.push_back(loop);
1994  ivs.push_back(loop.getInductionVar());
1995  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
1996  }
1997  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
1998  return failure();
1999  }
2000  return loops;
2001 }
static llvm::ManagedStatic< PassManagerOptions > options
static LogicalResult checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp)
This utility currently checks whether the loop either yields exactly one result or has the consumer op as its first user.
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp)
Fetch the untiled consumer of a scf.for's result which is yielded by a tensor.insert_slice.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specified in options.
static SmallVector< scf::ForOp > getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop)
Find the perfectly nested loops outside of given loop(included) sorted from outer to inner.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
static FailureOr< OpOperand * > getConsumerFromUses(Value val, Block *containingOpBlock)
Fetches the OpOperand of the only user (and use) of the value val which implements TilingInterface an...
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:964
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:207
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
RewriterBase::Listener * listener
An optional listener that should be notified about IR modifications.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class represents a saved insertion point.
Definition: Builders.h:335
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:453
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:420
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:280
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1305
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1298
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap an tensor.extract_slice with its producer when the producer implements the TilingInter...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:55
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:74
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:109
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap an tensor.insert_slice with its consumer when the consumer implements the TilingInterf...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0, or a ConstantIndexOp with value 0.
LogicalResult applyOpPatternsAndFold(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions [0 .. exprs.size() - 1].
Definition: AffineExpr.h:348
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions [0 .. exprs.size() - 1].
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Container for result values of tiling.
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Transformation information returned after reduction tiling.
SmallVector< Value > replacements
The replacements to use for the results of the tiled operation.
SmallVector< Value > initialValues
Initial values used for reduction.
SmallVector< Operation * > parallelTiledOps
The partial reduction tiled op generated.
SmallVector< LoopLikeOpInterface > loops
The loop operations that iterate over the tiles.
SmallVector< Operation * > mergeOps
The final reduction operation merging all the partial reductions.
Control function to check whether a slice needs to be fused or not. The control function receives 1) the slice along which fusion is to be done.
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.