MLIR  21.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1 //===- TileUsingInterface.cpp - Tiling using TilingInterface -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
34 #include <optional>
35 
36 #define DEBUG_TYPE "tile-using-interface"
37 
38 using namespace mlir;
39 
42  assert(!tileSizeComputationFunction && "tile sizes already set");
43  auto tileSizes = llvm::to_vector(ts);
44  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
45  return tileSizes;
46  };
47  return *this;
48 }
49 
52  assert(!numThreadsComputationFunction && "num tiles already set");
53  auto numThreads = llvm::to_vector(nt);
54  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
55  return numThreads;
56  };
57  return *this;
58 }
59 
60 /// Helper method to adjust the interchange vector to match the iteration
61 /// domain.
64  size_t iterationDomainSize) {
65  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
66  if (filledVector.size() < iterationDomainSize) {
67  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
68  filledVector.append(range.begin(), range.end());
69  }
70  if (filledVector.size() > iterationDomainSize)
71  filledVector.resize(iterationDomainSize);
72  return filledVector;
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // tileUsingSCF implementation.
77 //===----------------------------------------------------------------------===//
78 
79 /// Verify the tile size options are set in a consistent manner.
80 static LogicalResult
83  // Specifying number of threads is only supported on `scf.forall` op.
84  if (options.numThreadsComputationFunction &&
86  return rewriter.notifyMatchFailure(
87  loc, "number of threads can only by specified when loop type is "
88  "set to use `scf.forall`");
89  }
90 
91  // If specified, check that the interchange vector is a permutation.
92  if (!options.interchangeVector.empty()) {
93  if (!isPermutationVector(options.interchangeVector)) {
94  return rewriter.notifyMatchFailure(
95  loc, "invalid interchange vector, not a permutation of the entire "
96  "iteration space");
97  }
98  }
99  return success();
100 }
101 
102 /// Method to instantiate the tile sizes and/or number of threads specified
103 /// by the user.
104 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
105 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
106  ArrayRef<Range> iterationDomain,
108  OpFoldResult zero = rewriter.getIndexAttr(0);
109  SmallVector<OpFoldResult> tileSizes, numThreads;
110  size_t numLoops = iterationDomain.size();
111 
112  // Check whether the number of tiles to use is specified.
113  if (options.numThreadsComputationFunction) {
114  numThreads = options.numThreadsComputationFunction(rewriter, op);
115  numThreads.resize(numLoops, zero);
116 
117  // If the number of tiles is also specified, use that.
118  if (options.tileSizeComputationFunction) {
119  tileSizes = options.tileSizeComputationFunction(rewriter, op);
120  tileSizes.resize(numLoops, zero);
121  return {tileSizes, numThreads};
122  }
123 
124  // Compute the tile sizes from the iteration domain and number
125  // of tiles as follows
126  // - niters = ceilDiv(ub - lb, step)
127  // - tileSize = ceilDiv(niters, numThreads)
128  AffineExpr s0, s1, s2;
129  bindSymbols(rewriter.getContext(), s0, s1, s2);
130  // TODO: The step here is assumed to be 1.
131  AffineExpr numItersExpr = (s1 - s0);
132  AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
133  tileSizes.resize(numLoops, zero);
134  for (auto [index, range, nt] :
135  llvm::enumerate(iterationDomain, numThreads)) {
136  if (isZeroInteger(nt))
137  continue;
138 
139  tileSizes[index] = affine::makeComposedFoldedAffineApply(
140  rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
141  }
142  tileSizes.resize(numLoops, zero);
143  return {tileSizes, numThreads};
144  }
145 
146  // Enforce the convention that "tiling by zero"
147  // skips tiling a particular dimension. This convention is significantly
148  // simpler to handle instead of adjusting affine maps to account for missing
149  // dimensions.
150  assert(options.tileSizeComputationFunction &&
151  "expected tile sizes to be specified");
152  tileSizes = options.tileSizeComputationFunction(rewriter, op);
153  tileSizes.resize(numLoops, zero);
154 
155  return {tileSizes, numThreads};
156 }
157 
158 /// Checks if any of the tiled loops are not parallel.
159 static void checkSafeToTileToForall(TilingInterface op,
160  ArrayRef<OpFoldResult> tileSizes,
161  ArrayRef<OpFoldResult> numThreads) {
162  auto iterators = op.getLoopIteratorTypes();
163  assert(iterators.size() == tileSizes.size() &&
164  "expected as many tile size values as number of loops");
165  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
166  "when specified, expected number of threads to use for each loop");
167 
168  for (auto [index, iterator, tileSize] :
169  llvm::enumerate(iterators, tileSizes)) {
170  // If num threads is specified, check that it is greater than one only for
171  // parallel dimensions.
172  if (!numThreads.empty()) {
173  if (std::optional<int64_t> constNumThreads =
174  getConstantIntValue(numThreads[index])) {
175  if (constNumThreads.value() > 1 &&
176  iterator != utils::IteratorType::parallel) {
177  op.emitWarning() << "tiling is not thread safe at axis #" << index;
178  }
179  }
180  continue;
181  }
182 
183  if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
184  if (constTileSize.value() > 0 &&
185  iterator != utils::IteratorType::parallel) {
186  op.emitWarning() << "tiling is not thread safe at axis #" << index;
187  }
188  }
189  }
190 }
191 
192 /// Check if `stride` evenly divides the trip count `size - offset`.
193 static bool tileDividesIterationDomain(Range loopRange) {
194  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
195  if (!offsetAsInt)
196  return false;
197  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
198  if (!sizeAsInt)
199  return false;
200  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
201  if (!strideAsInt)
202  return false;
203  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
204 }
205 
206 /// Returns the bounded tile size given the current `offset`, `loopRange` and
207 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
209  Range loopRange, OpFoldResult offset,
210  OpFoldResult tileSize) {
211  std::optional<int64_t> ts = getConstantIntValue(tileSize);
212  if (ts && ts.value() == 1)
213  return tileSize;
214 
216  Range{loopRange.offset, loopRange.size, tileSize}))
217  return tileSize;
218 
219  // The tile size to use (to avoid out of bounds access) is minimum of
220  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
221  // loop.
222  AffineExpr s0, s1, d0;
223  bindDims(b.getContext(), d0);
224  bindSymbols(b.getContext(), s0, s1);
225  AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
226  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
228  b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
229 }
230 
231 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
232 /// than `iterationSize`.
234  OpFoldResult numThreads,
235  OpFoldResult iterationSize) {
236  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
237  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
238  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
239  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
240  return false;
241  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
242 }
243 
244 /// Compute the `OpFoldResult`s that represents the multi-dimensional
245 /// `offset`s and `size`s of the tile of the iteration space that the
246 /// innermost loop body of the generated tiled loops corresponds to.
247 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
249  ArrayRef<Range> iterationDomain,
250  ArrayRef<OpFoldResult> tileSizes,
251  ArrayRef<OpFoldResult> numThreads) {
252  SmallVector<OpFoldResult> offsets, sizes;
253  int materializedLoopNum = 0;
254 
255  if (!numThreads.empty()) {
256  AffineExpr d0, d1, s0, s1;
257  AffineExpr offsetExpr, residualTileSizeExpr;
258  bindDims(rewriter.getContext(), d0, d1);
259  bindSymbols(rewriter.getContext(), s0, s1);
260  offsetExpr = d0 + d1 * s0;
261  residualTileSizeExpr = s1 - (d0 + d1 * s0);
262 
263  for (auto [nt, tileSize, loopRange] :
264  llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
265 
266  // Non-tiled cases, set the offset and size to the
267  // `loopRange.offset/size`.
268  if (isZeroInteger(nt)) {
269  offsets.push_back(loopRange.offset);
270  sizes.push_back(loopRange.size);
271  continue;
272  }
273 
274  Value iv = ivs[materializedLoopNum++];
276  rewriter, loc, offsetExpr,
277  ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
279  rewriter, loc, residualTileSizeExpr,
280  {loopRange.offset, nt, tileSize, loopRange.size});
281 
282  OpFoldResult size = tileSize;
283  if (!isZeroInteger(residualTileSize)) {
284  OpFoldResult sizeMinusOffsetPerThread =
285  affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
286  {offset, loopRange.size});
288  rewriter, loc,
290  {sizeMinusOffsetPerThread, tileSize});
291  }
292 
293  // Consider the case where the original loop was `[0, 100)`.
294  // If number of threads are `7`, the tile size would be computed as
295  // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
296  // - `offset = 0 + 6 * 15 = 105`
297  // - `tileSize = min(15, 100 - 105) = -5`
298  // To avoid negative tile sizes, we need to do a further
299  // `nonNegativeTileSize = affine.max(0, tileSize)`.
300  // This `max` can be avoided if
301  // `offset + tileSize * (numThreads - 1) < (ub - lb)`
302  if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
303  AffineMap maxMap =
306  rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
307  }
308 
309  offsets.push_back(offset);
310  sizes.push_back(size);
311  }
312  return {offsets, sizes};
313  } else {
314  for (auto [tileSize, loopRange] :
315  llvm::zip_equal(tileSizes, iterationDomain)) {
316 
317  // Non-tiled cases, set the offset and size to the
318  // `loopRange.offset/size`.
319  if (isZeroInteger(tileSize)) {
320  offsets.push_back(loopRange.offset);
321  sizes.push_back(loopRange.size);
322  continue;
323  }
324 
325  Value iv = ivs[materializedLoopNum++];
326  OpFoldResult offset = getAsOpFoldResult(iv);
327  offsets.push_back(offset);
328  OpFoldResult size =
329  getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
330  sizes.push_back(size);
331  }
332  return {offsets, sizes};
333  }
334 }
335 
336 /// Function to return the bounds of the loops to be generated.
337 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
340  ArrayRef<OpFoldResult> tileSizes) {
341  SmallVector<OpFoldResult> lbs, ubs, steps;
342  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
343  // No loop if the tile size is 0.
344  if (isZeroInteger(tileSize))
345  continue;
346  lbs.push_back(loopRange.offset);
347  ubs.push_back(loopRange.size);
348  steps.push_back(tileSize);
349  }
350  return {lbs, ubs, steps};
351 }
352 
353 /// A function that allows returning additional yielded values during
354 /// `yieldTiledValuesAndReplace`.
355 /// - `ivs` induction variable for the loop.
356 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
357 /// - `tiledValues` the tiled values to return. Must be of same size as
358 /// `newbbArgs`, each element of this array is inserted into the corresponding
359 /// element in `newbbArgs`.
360 /// - `resultOffsets` is of the same size as `tiledValues` and represents
361 /// the offsets to use when inserting corresponding element from `tiledValues`
362 /// into the element from `newBbArgs`.
363 /// - `resultSizes` is of the same size as `tiledValues` and represents
364 /// the size of the corresponding element from `tiledValues` inserted into
365 /// the element from `newBbArgs`.
366 /// In case the method needs to return `failure()` the method is expected
367 /// to clean up any inserted operations.
368 using YieldTiledValuesFn = std::function<LogicalResult(
369  RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
370  SmallVector<Value> &tiledValues,
371  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
372  SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
373 
374 /// Clones the operation and updates the destination if the operation
375 /// implements the `DestinationStyleOpInterface`.
377  Operation *op,
378  ValueRange newDestArgs) {
379  Operation *clonedOp = rewriter.clone(*op);
380  if (newDestArgs.empty())
381  return clonedOp;
382  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
383  destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
384  return clonedOp;
385 }
386 
387 /// Generate the tile-loop nest using `scf.for` operation.
388 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
389 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
390 /// - `destinationTensors` are the init values to use for the outer most loop.
391 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
392 /// most
393 /// loop.
394 /// - `loops` is an in-out parameter into which the generated loops are
395 /// populated.
396 static LogicalResult generateLoopNestUsingForOp(
397  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
398  ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
399  YieldTiledValuesFn yieldTiledValuesFn,
401  assert(!loopRanges.empty() && "unexpected empty loop ranges");
402  assert(loopRanges.size() == tileSizes.size() &&
403  "expected as many tile sizes as loop ranges");
404  OpBuilder::InsertionGuard guard(rewriter);
405 
406  SmallVector<OpFoldResult> lbs, ubs, steps;
407  std::tie(lbs, ubs, steps) =
408  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
409  SmallVector<Value> lbVals =
410  getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
411  SmallVector<Value> ubVals =
412  getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
413  SmallVector<Value> stepVals =
414  getValueOrCreateConstantIndexOp(rewriter, loc, steps);
415 
416  SmallVector<Value> ivs;
417  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
418  auto loop =
419  rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
420  [](OpBuilder &bodyBuilder, Location bodyLoc,
421  Value iv, ValueRange /*iterArgs*/) {});
422  loops.push_back(loop);
423  ivs.push_back(loop.getInductionVar());
424  rewriter.setInsertionPointToEnd(loop.getBody());
425  destinationTensors = loop.getRegionIterArgs();
426  }
427 
428  SmallVector<Value> tiledResults;
429  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
430  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
431  tiledResults, resultOffsets, resultSizes))) {
432  return rewriter.notifyMatchFailure(
433  loc, "failed to generate inner tile loop body");
434  }
435  if (loops.empty())
436  return success();
437 
438  assert(tiledResults.size() == destinationTensors.size() &&
439  "Number of results of body should be equal to number of iter args");
440 
441  // 6. Yield all the results of the tiled operation.
442  SmallVector<Value> yieldedValues;
443  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
444  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
445  resultSizes)) {
446  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
447  rewriter.getIndexAttr(1));
448  auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
449  loc, tiledValue, destinationTensor, resultOffset, resultSize,
450  resultStride);
451  yieldedValues.push_back(insertSlice);
452  }
453  rewriter.create<scf::YieldOp>(loc, yieldedValues);
454 
455  // Add the scf.yield operations for all the outer loops.
456  for (auto [outerLoop, innerLoop] :
457  llvm::zip_equal(MutableArrayRef(loops).drop_back(),
458  MutableArrayRef(loops).drop_front())) {
459  rewriter.setInsertionPointToEnd(
460  cast<scf::ForOp>(outerLoop.getOperation()).getBody());
461  rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
462  }
463  return success();
464 }
465 
466 /// Generate the tile-loop nest using `scf.forall` operation.
467 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
468 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
469 /// - `destinationTensors` are the init values to use for the outer most loop.
470 /// - `mappingVector` is the mapping attributes to use for loop construction.
471 /// Can be empty.
472 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
473 /// most
474 /// loop.
475 /// - `loops` is an in-out parameter into which the generated loops are
476 /// populated.
477 static LogicalResult generateLoopNestUsingForallOp(
478  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
479  ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
480  ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
482  assert(!loopRanges.empty() && "unexpected empty loop ranges");
483  assert(loopRanges.size() == tileSizes.size() &&
484  "expected as many tile sizes as loop ranges");
485  OpBuilder::InsertionGuard guard(rewriter);
486 
487  std::optional<ArrayAttr> mappingAttr;
488  if (!mappingVector.empty())
489  mappingAttr = rewriter.getArrayAttr(mappingVector);
490 
491  scf::ForallOp forallOp;
492  bool useNumThreads = !numThreads.empty();
493 
494  if (useNumThreads) {
495  // Prune the zero numthreads.
496  SmallVector<OpFoldResult> nonZeroNumThreads;
497  for (auto nt : numThreads) {
498  if (isZeroInteger(nt))
499  continue;
500  nonZeroNumThreads.push_back(nt);
501  }
502  forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
503  destinationTensors, mappingAttr);
504  } else {
505  SmallVector<OpFoldResult> lbs, ubs, steps;
506  std::tie(lbs, ubs, steps) =
507  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
508  forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
509  destinationTensors, mappingAttr);
510  }
511  loops.push_back(forallOp);
512 
513  rewriter.setInsertionPoint(forallOp.getTerminator());
514  destinationTensors = forallOp.getRegionOutArgs();
515 
516  SmallVector<Value> tiledResults;
517  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
518  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
519  destinationTensors, tiledResults, resultOffsets,
520  resultSizes)))
521  return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
522 
523  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
524  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
525  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
526  resultSizes)) {
527  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
528  rewriter.getIndexAttr(1));
529 
530  rewriter.create<tensor::ParallelInsertSliceOp>(
531  loc, tiledValue, destinationTensor, resultOffset, resultSize,
532  resultStride);
533  }
534  return success();
535 }
536 
537 /// Generate the tile-loop nest using the loop construct specifed in `options`.
538 /// - `options`: Tiling options specified.
539 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
540 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
541 /// - `destinationTensors` are the init values to use for the outer most loop.
542 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
543 /// most
544 /// loop.
545 /// - `loops` is an in-out parameter into which the generated loops are
546 /// populated.
547 static LogicalResult generateLoopNest(
548  RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
549  ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
550  ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
552  // If the tile sizes are all zero, no loops are generated. Just call the
553  // callback function to handle untiled case.
554  if (llvm::all_of(tileSizes, isZeroInteger)) {
555  SmallVector<Value> tiledResults;
556  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
557  return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
558  tiledResults, resultOffsets, resultSizes);
559  }
561  return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
562  destinationTensors, tiledBodyFn, loops);
563  }
566  rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
567  destinationTensors, tiledBodyFn, loops);
568  }
569  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
570 }
571 
572 static FailureOr<SmallVector<Value>>
573 createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
574  ArrayRef<OpFoldResult> tileSizes,
576  SmallVector<Value> initTensors;
577  Location loc = op->getLoc();
578  switch (options.reductionStrategy) {
580  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
581  return failure();
582  return initTensors;
585  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
586  if (!redOp) {
587  return rewriter.notifyMatchFailure(
588  op, "PartialReductionOuterReduction tiling strategy is only supported"
589  "for operations implementing PartialReductionOpInterface");
590  }
591  // Get reduction dimensions.
592  // TODO: PartialReductionOpInterface should really query TilingInterface
593  // itself and find reduction dimensions.
594  SmallVector<int> reductionDims;
595  for (auto [idx, iteratorType] :
596  llvm::enumerate(op.getLoopIteratorTypes())) {
597  if (iteratorType == utils::IteratorType::reduction)
598  reductionDims.push_back(idx);
599  }
600  return redOp.generateInitialTensorForPartialReduction(
601  rewriter, loc, tileSizes, reductionDims);
602  }
603  default:
604  return rewriter.notifyMatchFailure(op,
605  "unhandled reduction tiling strategy");
606  }
607 }
608 
609 static FailureOr<TilingResult>
610 getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
611  ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
614  switch (options.reductionStrategy) {
616  return op.getTiledImplementation(rewriter, offsets, sizes);
619  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
620  if (!redOp) {
621  return rewriter.notifyMatchFailure(
622  op, "PartialReductionOuterReduction tiling strategy is only "
623  "supported for operations "
624  "implementing PartialReductionOpInterface");
625  }
626  // Get reduction dimensions.
627  // TODO: PartialReductionOpInterface should really query TilingInterface
628  // itself and find reduction dimensions.
629  SmallVector<int> reductionDims;
630  for (auto [idx, iteratorType] :
631  llvm::enumerate(op.getLoopIteratorTypes())) {
632  if (iteratorType == utils::IteratorType::reduction)
633  reductionDims.push_back(idx);
634  }
635  return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
636  offsets, sizes, reductionDims);
637  }
638  default:
639  return rewriter.notifyMatchFailure(op,
640  "unhandled reduction tiling strategy");
641  }
642 }
643 
644 static LogicalResult
645 getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
646  TilingInterface op, ArrayRef<OpFoldResult> offsets,
648  SmallVector<OpFoldResult> &resultOffset,
649  SmallVector<OpFoldResult> &resultSize,
651 
652  switch (options.reductionStrategy) {
654  return op.getResultTilePosition(rewriter, index, offsets, sizes,
655  resultOffset, resultSize);
658  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
659  if (!redOp) {
660  return rewriter.notifyMatchFailure(
661  op, "PartialReductionOuterReduction tiling strategy is only supported"
662  "for operations implementing PartialReductionOpInterface");
663  }
664  // Get reduction dimensions.
665  // TODO: PartialReductionOpInterface should really query TilingInterface
666  // itself and find reduction dimensions.
667  SmallVector<int> reductionDims;
668  for (auto [idx, iteratorType] :
669  llvm::enumerate(op.getLoopIteratorTypes())) {
670  if (iteratorType == utils::IteratorType::reduction)
671  reductionDims.push_back(idx);
672  }
673  return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
674  resultOffset, resultSize,
675  reductionDims);
676  }
677  default:
678  return rewriter.notifyMatchFailure(op,
679  "unhandled reduction tiling strategy");
680  }
681 }
682 
683 static FailureOr<MergeResult>
684 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
685  ValueRange partialResults,
687  switch (options.reductionStrategy) {
689  // No need to merge results for reduction tiling strategy.
690  return MergeResult{{}, partialResults};
693  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
694  if (!redOp) {
695  return rewriter.notifyMatchFailure(
696  op, "PartialReductionOuterReduction tiling strategy is only "
697  "supported for operations "
698  "implementing PartialReductionOpInterface");
699  }
700  // Get reduction dimensions.
701  // TODO: PartialReductionOpInterface should really query TilingInterface
702  // itself and find reduction dimensions.
703  SmallVector<int> reductionDims;
704  for (auto [idx, iteratorType] :
705  llvm::enumerate(op.getLoopIteratorTypes())) {
706  if (iteratorType == utils::IteratorType::reduction)
707  reductionDims.push_back(idx);
708  }
709  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
710  reductionDims);
711  }
712  default:
713  return rewriter.notifyMatchFailure(op,
714  "unhandled reduction tiling strategy");
715  }
716 }
717 
/// Append the specified additional `newInitOperands` operands to the
/// loops existing `init` operands (or similar), and replace `loopOp` with
/// the new loop that has the additional init operands. The loop body of
/// this loop is moved over to the new loop. `yieldTiledValuesFn`
/// is called to get the new tiled values returned, and the offset
/// and sizes at which the tiled value is inserted into the
/// new region iter_args that correspond to the newly added init operands.
template <typename LoopType>
FailureOr<LoopLikeOpInterface>
yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
                               ValueRange newInitOperands,
                               YieldTiledValuesFn yieldTiledValuesFn) {
  // Fallback primary template: only the explicit specializations below
  // (for `scf.for` and `scf.forall`) are supported; any other loop type is
  // reported as a match failure.
  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
}
732 
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
    scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);

  // Build a replacement scf.for whose inits are the original inits followed
  // by `newInitOperands`. The body builder is empty; the original body is
  // merged in below.
  auto inits = llvm::to_vector(loopOp.getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForOp>(
      loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
      inits, [](OpBuilder &, Location, Value, ValueRange) {});

  // Move the loop body to the new op.
  // The old block arguments (iv + original iter_args) map onto the leading
  // arguments of the new block; the trailing new iter_args stay untouched.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(yieldOp);

  // Invoke the callback to get the tiled values to yield for the newly added
  // iter_args, along with the offsets/sizes at which each tiled value is
  // inserted into its corresponding new region iter_arg.
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  ValueRange newRegionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
                                newRegionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // On failure, erase the half-built replacement loop. NOTE(review): the
    // original body has already been merged into `newLoop` at this point —
    // presumably the callback contract requires it to clean up such that
    // erasing `newLoop` is safe; confirm against callers.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
  }

  // Extend the existing yield: each tiled value is inserted (with unit
  // strides) into its new region iter_arg and the result is yielded.
  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
       llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    Value insert = rewriter.create<tensor::InsertSliceOp>(
        yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
        resultStride);
    newYieldValues.push_back(insert);
  }

  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
  // Replace only the original loop's results; the trailing results that
  // correspond to `newInitOperands` are left for the caller to consume.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
786 
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
    scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);
  // Build a replacement scf.forall whose shared outs are the original
  // outputs followed by `newInitOperands`.
  auto inits = llvm::to_vector(loopOp.getOutputs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForallOp>(
      loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
      loopOp.getMixedStep(), inits, loopOp.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});

  // Move the region of the current block to the newly created op.
  // The old block arguments (ivs + original outs) map onto the leading
  // arguments of the new block; the trailing new outs stay untouched.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(terminator);
  // Invoke the callback to get the tiled values and the offsets/sizes at
  // which each tiled value is inserted into its new shared-out argument.
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  ValueRange regionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
                                regionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // On failure, erase the half-built replacement loop.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp,
                                       "failed to get yielded tiled values");
  }

  // Update the terminator.
  rewriter.setInsertionPointToEnd(terminator.getBody());

  // scf.forall "yields" via tensor.parallel_insert_slice ops inside the
  // in_parallel terminator; insert each tiled value with unit strides.
  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
           tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    rewriter.create<tensor::ParallelInsertSliceOp>(
        terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
        resultStride);
  }

  // Replace only the original loop's results; the trailing results that
  // correspond to `newInitOperands` are left for the caller to consume.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
839 
840 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
841 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
842 /// supported loop type.
843 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
844  LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
845  ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
847  loopLikeOp.getOperation())
848  .Case<scf::ForOp, scf::ForallOp>(
849  [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
851  loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
852  })
853  .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
854  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
855  });
856 }
857 
858 /// Method to add new init values to a loop nest. Updates `loops` in-place
859 /// with new loops that use the `newInitValues`. The outer-loops are updated
860 /// to yield the new result values of the inner loop. For the innermost loop,
861 /// the call back `getNewYields` is invoked to get the additional values to
862 /// yield form the innermost loop.
863 static LogicalResult addInitOperandsToLoopNest(
865  ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
866  if (loops.empty())
867  return success();
868  OpBuilder::InsertionGuard g(rewriter);
869  rewriter.setInsertionPoint(loops.front());
870 
871  SmallVector<Value> ivs;
872  for (auto &loop : loops.drop_back()) {
873  rewriter.setInsertionPoint(loop);
874 
875  // if loops.size() > 1 we assume that scf.for is used for the loops.
876  auto forLoop = cast<scf::ForOp>(loop.getOperation());
877 
878  // Create a new loop with the new init values for this loop.
879  SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
880  newInits.append(newInitValues.begin(), newInitValues.end());
881  auto newLoop = rewriter.create<scf::ForOp>(
882  forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
883  forLoop.getStep(), newInits,
884  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
885 
886  // Merge the body of the new loop with the body of the old loops.
887  SmallVector<Value> sourceBlockArgs;
888  sourceBlockArgs.push_back(newLoop.getInductionVar());
889  auto newRegionIterArgs = newLoop.getRegionIterArgs();
890  sourceBlockArgs.append(
891  newRegionIterArgs.begin(),
892  std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
893  rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
894  rewriter.replaceOp(
895  forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
896  loop = newLoop;
897  ivs.push_back(newLoop.getInductionVar());
898  newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
899  }
900 
901  // Update the loop body of the innermost loop to get new yield values.
902  LoopLikeOpInterface innerMostLoop = loops.back();
903  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
904  yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
905  getNewTiledYieldsFn);
906 
907  if (failed(newInnerMostLoop))
908  return innerMostLoop.emitOpError("failed to return additional yields");
909  loops.back() = newInnerMostLoop.value();
910 
911  // Make all other loops except the innermost loops yield the values returned
912  // by the inner loop.
913  for (auto [outerLoop, innerLoop] :
914  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
915  // Again assume that all the outer loops are scf.for operations.
916  auto outerForLoop = cast<scf::ForOp>(outerLoop);
917  auto outerLoopYield =
918  cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
919  SmallVector<Value> newYields =
920  llvm::to_vector(outerLoopYield.getOperands());
921  ValueRange additionalYields =
922  innerLoop->getResults().take_back(newInitValues.size());
923  newYields.append(additionalYields.begin(), additionalYields.end());
924  rewriter.setInsertionPoint(outerLoopYield);
925  rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
926  }
927  return success();
928 }
929 
930 /// Implementation of tiling transformation of `op` that implements the
931 /// `TilingInterface` using `scf.for` to iterate over the tiles.
932 FailureOr<scf::SCFTilingResult>
933 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
935  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
936  return failure();
937  }
938 
939  OpBuilder::InsertionGuard guard(rewriter);
940  rewriter.setInsertionPointAfter(op);
941 
942  // 1. Get the range of the loops that are represented by the operation.
943  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
944 
945  // 2. Materialize the tile sizes and/or number of threads;
946  SmallVector<OpFoldResult> tileSizes, numThreads;
947  std::tie(tileSizes, numThreads) =
948  getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
949 
950  // Check if it is safe to tile. This is hold over from previous iterations
951  // of tile to for-all. Consider dropping it.
953  checkSafeToTileToForall(op, tileSizes, numThreads);
954  }
955 
956  // 3. If there is an interchange specified, permute the iteration domain and
957  // the tile sizes.
958  SmallVector<int64_t> interchangeVector;
959  if (!options.interchangeVector.empty()) {
960  interchangeVector = fillInterchangeVector(options.interchangeVector,
961  iterationDomain.size());
962  assert(isPermutationVector(interchangeVector) &&
963  "expected interchange vector to be a permutation");
964 
965  applyPermutationToVector(iterationDomain, interchangeVector);
966  applyPermutationToVector(tileSizes, interchangeVector);
967  if (!numThreads.empty())
968  applyPermutationToVector(numThreads, interchangeVector);
969  }
970 
971  FailureOr<TilingResult> tilingResult;
972  // 4. Define the lambda function used later to generate the body of the
973  // innermost tiled loop.
974  YieldTiledValuesFn innerYieldTiledValuesFn =
975  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
976  ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
979  -> LogicalResult {
980  // 4a. Compute the `offsets` and `sizes` to use for tiling.
981  SmallVector<OpFoldResult> offsets, sizes;
982  std::tie(offsets, sizes) = getTileOffsetAndSizes(
983  rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
984 
985  // 4b. If interchange was provided, apply inverse of the interchange
986  // to get back the offsets/sizes in the order to be specified.
987  if (!interchangeVector.empty()) {
988  auto inversePermutation = invertPermutationVector(interchangeVector);
991  }
992 
993  // 5. Generate the tiled implementation within the inner most loop.
994 
995  // 5a. Clone the operation within the loop body.
996  auto clonedOp = cast<TilingInterface>(
997  cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
998 
999  // 5b. Early return cloned op if tiling is not happening. We can not
1000  // return the original op because it could lead to `rewriter.replaceOp(op,
1001  // op->getResults())` and users would get crash.
1002  if (llvm::all_of(tileSizes, isZeroInteger)) {
1003  tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1004  tilingResult =
1005  TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
1006  /*generatedSlices=*/{}};
1007  return success();
1008  }
1009 
1010  // 5c. Tile the cloned operation.
1011  tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
1012  offsets, sizes, options);
1013  if (failed(tilingResult)) {
1014  rewriter.eraseOp(clonedOp);
1015  return op.emitOpError("faild to tile operation");
1016  }
1017 
1018  // 5d. Delete the cloned operation.
1019  rewriter.eraseOp(clonedOp);
1020 
1021  // 5e. Compute the offsets at which the result values are to be inserted
1022  // back into its destinations.
1023  for (auto [index, tiledValue] :
1024  llvm::enumerate(tilingResult->tiledValues)) {
1025  tiledResults.push_back(tiledValue);
1026  SmallVector<OpFoldResult> resultOffset, resultSize;
1027  if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
1028  sizes, resultOffset, resultSize,
1029  options))) {
1030  for (auto op : tilingResult->tiledOps) {
1031  rewriter.eraseOp(op);
1032  }
1033  return rewriter.notifyMatchFailure(
1034  op, "failed to get slice of result produced");
1035  }
1036  resultOffsets.emplace_back(std::move(resultOffset));
1037  resultSizes.emplace_back(std::move(resultSize));
1038  }
1039 
1040  return success();
1041  };
1042 
1043  // 6. Find the destination tensors to use for the operation.
1044  FailureOr<SmallVector<Value>> maybeInits =
1045  createInitialTensorsForTiling(rewriter, op, tileSizes, options);
1046  if (failed(maybeInits)) {
1047  return rewriter.notifyMatchFailure(
1048  op, "unable to create initial tensors for tiling");
1049  }
1050  SmallVector<Value> &initTensors = maybeInits.value();
1051 
1052  // 7. Generate the tiled loops nest using the callback defined above.
1054  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
1055  tileSizes, numThreads, initTensors,
1056  innerYieldTiledValuesFn, loops)))
1057  return op.emitOpError("failed to generate tiling loops");
1058  assert(succeeded(tilingResult) &&
1059  "expected tiling result to be computed after loop generation");
1060 
1061  if (loops.empty()) {
1062  // If loops are empty, the tiled op is used as the replacement for the
1063  // untiled op.
1064  return scf::SCFTilingResult{tilingResult->tiledOps,
1065  initTensors,
1066  loops,
1067  tilingResult->tiledValues,
1068  tilingResult->generatedSlices,
1069  {}};
1070  }
1071 
1072  auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1073  [](OpResult r) -> Value { return r; });
1074 
1075  // For the full reduction case, there is nothing more to do.
1076  if (options.reductionStrategy ==
1078  return scf::SCFTilingResult{
1079  tilingResult->tiledOps, initTensors, loops, loopResults,
1080  tilingResult->generatedSlices, {}};
1081  }
1082 
1083  // The results of the loop needs to be merged.
1084  FailureOr<MergeResult> mergeResult =
1085  mergeTilingResults(rewriter, op, loopResults, options);
1086  if (failed(mergeResult)) {
1087  return rewriter.notifyMatchFailure(
1088  op, "Failed to merge partial results from tiling");
1089  }
1090  return scf::SCFTilingResult{tilingResult->tiledOps,
1091  initTensors,
1092  loops,
1093  mergeResult->replacements,
1094  tilingResult->generatedSlices,
1095  mergeResult->mergeOps};
1096 }
1097 
1098 FailureOr<scf::SCFTilingResult>
1100  PartialReductionOpInterface op,
1101  ArrayRef<OpFoldResult> tileSize) {
1104  options.setReductionTilingStrategy(
1106  PartialReductionOuterReduction);
1107  options.setTileSizes(tileSize);
1108  return tileUsingSCF(b, op, options);
1109 }
1110 
1111 //===----------------------------------------------------------------------===//
1112 // tileConsumerAndFuseProducersUsingSCF implementation.
1113 //===----------------------------------------------------------------------===//
1114 
1115 /// Return the untiled producer whose slice is used in a tiled consumer. The
1116 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1117 /// `iter_args` of the outer most that is encountered. Traversing the
1118 /// iter_args indicates that this is a destination operand of the consumer. If
1119 /// there was no loop traversal needed, the second value of the returned tuple
1120 /// is empty.
1121 static std::tuple<OpResult, std::optional<OpOperand *>>
1124  std::optional<OpOperand *> destinationIterArg;
1125  assert(!loops.empty() && "expected non empty loops container");
1126  auto loopIt = loops.rbegin();
1127  while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
1128  auto iterArg = cast<BlockArgument>(source->get());
1129  auto loop = *loopIt;
1130  if (iterArg.getOwner()->getParentOp() != loop)
1131  break;
1132  source = loop.getTiedLoopInit(iterArg);
1133  loopIt++;
1134  }
1135  if (loopIt == loops.rend())
1136  destinationIterArg = source;
1137  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1138 }
1139 
1140 /// Implementation of fusing producer of a single slice by computing the
1141 /// slice of the producer in-place.
1142 std::optional<scf::SCFFuseProducerOfSliceResult>
1144  RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1146  // 1. Get the producer of the source (potentially walking through
1147  // `iter_args` of nested `scf.for`)
1148  auto [fusableProducer, destinationInitArg] =
1149  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1150  loops);
1151  if (!fusableProducer)
1152  return std::nullopt;
1153  unsigned resultNumber = fusableProducer.getResultNumber();
1154 
1155  OpBuilder::InsertionGuard g(rewriter);
1156  rewriter.setInsertionPoint(candidateSliceOp);
1157 
1158  // 2. Clone the fused producer
1159  // 2a. Compute the destination operands to use for the cloned operation.
1160  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1161  Operation *fusableProducerOp = fusableProducer.getOwner();
1162  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1164  rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1165  origDestinationTensors)))
1166  return std::nullopt;
1167 
1168  clonedOpDestinationTensors = origDestinationTensors;
1169  if (destinationInitArg &&
1170  isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1171  // 2b. If the producer is also destination style, then to maintain the
1172  // destination passing style, update the destination of the producer to be
1173  // the source of the slice.
1174  clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1175  }
1176  // 2c. Clone the fused producer.
1177  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1178  rewriter, fusableProducerOp, clonedOpDestinationTensors);
1179  // 2d. Update the source of the candidateSlice to be the cloned producer.
1180  // Easier to just clone the slice with different source since
1181  // replacements and DCE of cloned ops becomes easier
1182  SmallVector<Value> candidateSliceOpOperands =
1183  llvm::to_vector(candidateSliceOp->getOperands());
1184  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1185  tensor::ExtractSliceOp clonedCandidateSliceOp =
1186  mlir::clone(rewriter, candidateSliceOp,
1187  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1188 
1189  // 3. Generate the tiled implementation of the producer of the source
1190  FailureOr<TilingResult> tileAndFuseResult =
1192  rewriter, clonedCandidateSliceOp,
1193  clonedProducerOp->getResult(resultNumber));
1194  if (failed(tileAndFuseResult))
1195  return std::nullopt;
1196  // Note: Do not delete the candidateSliceOp, since its passed in from the
1197  // caller.
1198  rewriter.replaceAllUsesWith(candidateSliceOp,
1199  tileAndFuseResult->tiledValues[0]);
1200  rewriter.eraseOp(clonedCandidateSliceOp);
1201  rewriter.eraseOp(clonedProducerOp);
1202 
1203  // 3. If the slice is for a destination operand, for example,
1204  //
1205  // ```mlir
1206  // %0 = linalg.init
1207  // %1 = linalg.fill .. outs(%0 : )
1208  // %2 = scf.for .. iter_args(%arg0 = %1) {
1209  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1210  // %4 = tensor.extract_slice %arg1 [..]
1211  // .. = linalg.matmul .. outs(%4 : )
1212  // }
1213  // }
1214  // ```
1215  //
1216  // the IR is currently
1217  //
1218  // ```
1219  // %0 = linalg.init
1220  // %1 = linalg.fill
1221  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1222  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1223  // %4 = tensor.extract_slice %arg1[..]
1224  // %5 = linalg.fill .. outs(%4 : )
1225  // .. = linalg.matmul .. outs(%5 : )
1226  // }
1227  // }
1228  // ```
1229  //
1230  // The untiled `linalg.fill` is still used as the `init_value` since it
1231  // was originally a destination operand of the untiled `linalg.matmul`.
1232  // When fusing an operand that is a destination operand, the iter_arg of
1233  // the outer most loop should be changed to use the destination of the
1234  // fused operation. With this the IR will be.
1235  //
1236  // ```
1237  // %0 = linalg.init
1238  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1239  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
1240  // %3 = tensor.extract_slice %arg1[..]
1241  // %4 = linalg.fill .. outs(%3 : )
1242  // .. = linalg.matmul .. outs(%4 : )
1243  // }
1244  // }
1245  // ```
1246  if (destinationInitArg &&
1247  isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1248  loops.front()
1249  ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1250  .set(origDestinationTensors[resultNumber]);
1251  }
1253  fusableProducer, tileAndFuseResult->tiledValues[0],
1254  tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1255 }
1256 
1257 /// Reconstruct the fused producer from within the tiled-and-fused code.
1258 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1259  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1260  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1262  ArrayRef<unsigned> yieldResultNumber) {
1263  if (loops.empty())
1264  return success();
1265 
1266  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1267  *tiledOwner = fusedProducerInfo.tiledOps[0];
1268 
1269  Location loc = originalOwner->getLoc();
1270  // a. collect all init Value to be appended
1271  SmallVector<unsigned> initNumberList =
1272  yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1273  0, originalOwner->getNumResults()))
1274  : llvm::to_vector(yieldResultNumber);
1275  SmallVector<Value> initValueList;
1276  for (const auto &resultNumber : initNumberList) {
1277  FailureOr<Value> initValue = tensor::getOrCreateDestination(
1278  rewriter, loc, originalOwner->getResult(resultNumber));
1279  if (succeeded(initValue)) {
1280  initValueList.push_back(initValue.value());
1281  } else {
1282  return failure();
1283  }
1284  }
1285 
1286  SmallVector<Operation *> generatedSlices;
1287  YieldTiledValuesFn newYieldValuesFn =
1288  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1289  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1291  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1292  OpBuilder::InsertionGuard g(innerRewriter);
1293 
1294  // get sliceOp tile information
1295  SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1296  sliceSizes = sliceOp.getMixedSizes();
1297 
1298  // expect all strides of sliceOp being 1
1299  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
1300  return failure();
1301 
1302  unsigned sliceResultNumber =
1303  fusedProducerInfo.origProducer.getResultNumber();
1304 
1305  auto tilableOp = cast<TilingInterface>(originalOwner);
1306  // b. get iterDomain Offset and Sizes based on sliceOp tile
1307  SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1308  // skip tensor.pack/unpack/pad, which expects single opResult
1309  if (tilableOp->getNumResults() > 1 &&
1310  failed(tilableOp.getIterationDomainTileFromResultTile(
1311  rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1312  iterDomainOffset, iterDomainSizes))) {
1313  // In theory, it is unnecessary to raise an error here. Actually
1314  // although it fails to reconstruct the result tensor, it should not
1315  // broke current fusion anyway. The reason why we must return failure
1316  // currently is that the callback function `newYieldValuesFn` will be
1317  // called after new init operand(s) has already been appended. It will
1318  // take more refactoring to make sure the init operands are added
1319  // consistently in the future. For more details, please refer to:
1320  // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1321  return failure();
1322  }
1323 
1324  // c. calculate offsets and sizes info of all OpResults respectively based
1325  // on iteration Domain Tile
1326  SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1327  for (const auto &resultNumber : initNumberList) {
1328  if (resultNumber == sliceResultNumber) {
1329  offsetList.push_back(sliceOffset);
1330  sizesList.push_back(sliceSizes);
1331  } else {
1332  assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1333  // infer result tile according to the iteration domain tile
1334  SmallVector<OpFoldResult> offset, sizes;
1335  if (failed(tilableOp.getResultTilePosition(
1336  rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1337  offset, sizes))) {
1338  return failure();
1339  }
1340  offsetList.push_back(offset);
1341  sizesList.push_back(sizes);
1342  }
1343  }
1344 
1345  // d. create `extract_slice` for `iter_args` for DPS operation if
1346  // necessary
1347  if (auto tiledDestStyleOp =
1348  dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1349  rewriter.setInsertionPoint(tiledDestStyleOp);
1350  for (const auto &&[index, newRegionArg] :
1351  llvm::enumerate(newRegionIterArgs)) {
1352  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1353  loc, newRegionArg, offsetList[index], sizesList[index],
1354  SmallVector<OpFoldResult>(offsetList[index].size(),
1355  rewriter.getIndexAttr(1)));
1356  generatedSlices.push_back(destSlice);
1357  unsigned resultNumber = initNumberList[index];
1358  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1359  tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1360  });
1361  }
1362  }
1363 
1364  // e. prepare tiled offset and sizes for later `insert_slice` creation by
1365  // caller
1366  Block *block = rewriter.getInsertionPoint()->getBlock();
1367  rewriter.setInsertionPoint(block->getTerminator());
1368  for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1369  tiledResult.push_back(tiledOwner->getResult(resultNumber));
1370  tiledOffset.emplace_back(offsetList[index]);
1371  tiledSizes.emplace_back(sizesList[index]);
1372  }
1373  return success();
1374  };
1375 
1376  if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1377  newYieldValuesFn))) {
1378  return failure();
1379  }
1380  return generatedSlices;
1381 }
1382 
1383 namespace {
1384 
1385 //===----------------------------------------------------------------------===//
1386 // SliceTrackingListener
1387 //===----------------------------------------------------------------------===//
1388 
/// This class is a listener for tracking the insertion and removal of
/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
/// fusion algorithm to apply cleanup patterns in between fusion steps.
class SliceTrackingListener : public RewriterBase::Listener {
public:
  explicit SliceTrackingListener(
      std::optional<FrozenRewritePatternSet> patterns);
  SliceTrackingListener() = default;

  /// Adds the given list of operations to the worklist, and if present,
  /// applies the list of `patterns` to the newly added operations. This only
  /// processes the given operations and any newly inserted ones by the
  /// pattern set.
  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);

  /// Add to the new operation worklist if it is an extract_slice.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Shared helper for operation removal from the worklist.
  void removeOp(Operation *op);

  /// Remove the operation from the worklist.
  void notifyOperationErased(Operation *op) override;

  /// Remove the operation from the worklist.
  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;

  /// The worklist for this transformation keeps track of the slices to visit
  /// next for fusion. A deque is used so candidates are processed in FIFO
  /// order while still supporting removal from the middle.
  std::deque<tensor::ExtractSliceOp> worklist;

private:
  /// Optional pattern set to apply when adding new operations to the
  /// worklist.
  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
};
1426 
1427 SliceTrackingListener::SliceTrackingListener(
1428  std::optional<FrozenRewritePatternSet> p) {
1429  patterns = std::move(p);
1430 }
1431 
1432 LogicalResult
1433 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1434  for (Operation *op : ops) {
1435  if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1436  worklist.push_back(slice);
1437  }
1438 
1439  if (!patterns)
1440  return success();
1441 
1442  return applyOpPatternsGreedily(
1443  ops, patterns.value(),
1444  GreedyRewriteConfig().setListener(this).setStrictness(
1446 }
1447 
1448 void SliceTrackingListener::notifyOperationInserted(
1449  Operation *op, OpBuilder::InsertPoint previous) {
1450  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1451  if (!slice)
1452  return;
1453  worklist.push_back(slice);
1454 }
1455 
1456 // Scan the worklist for the given op and remove it if present. The
1457 // expectation is for the worklist to be small and for removal to be
1458 // relatively rare.
1459 void SliceTrackingListener::removeOp(Operation *op) {
1460  if (!isa<tensor::ExtractSliceOp>(op))
1461  return;
1462  auto iter = worklist.begin();
1463  while (iter != worklist.end()) {
1464  if (*iter == op)
1465  break;
1466  iter++;
1467  }
1468  if (iter == worklist.end())
1469  return;
1470 
1471  worklist.erase(iter);
1472 }
1473 
/// Erased ops must not be revisited: drop them from the worklist.
void SliceTrackingListener::notifyOperationErased(Operation *op) {
  removeOp(op);
}
1477 
/// A replaced op is about to disappear: drop it from the worklist. The
/// replacement values (if they are new slices) are tracked separately via
/// notifyOperationInserted.
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
                                                    ValueRange replacement) {
  removeOp(op);
}
1482 
1483 //===----------------------------------------------------------------------===//
1484 // ReplacementListener
1485 //===----------------------------------------------------------------------===//
1486 
1487 /// Listener that tracks updates replacements for values which can be mutated.
1488 /// This listener runs on top of the existing listener for the rewriter,
1489 /// to make sure external users can still run listeners.
1490 class ReplacementListener : public RewriterBase::ForwardingListener {
1491 public:
1492  ReplacementListener(DenseMap<Value, Value> &replacements,
1493  OpBuilder::Listener *listener)
1494  : ForwardingListener(listener), replacements(replacements) {}
1495 
1496  void updateReplacementValues(ValueRange origValues,
1497  ValueRange replaceValues) {
1498  // This can probably be written better, but just iterates over the map
1499  // and the new replacements for now.
1500  for (auto &[key, val] : replacements) {
1501  for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1502  if (val == orig) {
1503  val = replace;
1504  }
1505  }
1506  }
1507  }
1508 
1509  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1510  ForwardingListener::notifyOperationReplaced(op, newOp);
1511  updateReplacementValues(op->getResults(), newOp->getResults());
1512  }
1513 
1514  void notifyOperationReplaced(Operation *op, ValueRange values) override {
1515  ForwardingListener::notifyOperationReplaced(op, values);
1516  updateReplacementValues(op->getResults(), values);
1517  }
1518 
1519 private:
1520  DenseMap<Value, Value> &replacements;
1521 };
1522 
1523 } // namespace
1524 
1525 /// Implementation of tile consumer and fuse producer greedily.
1526 FailureOr<scf::SCFTileAndFuseResult>
1528  RewriterBase &rewriter, TilingInterface consumer,
1530  // This transformation is only valid for ops that return values (i.e. not
1531  // valid to use with operations that have memref operands).
1532  if (!consumer->getNumResults()) {
1533  return rewriter.notifyMatchFailure(
1534  consumer, "invalid pattern for op with no results");
1535  }
1536 
1537  // 1. First tile the consumer.
1538  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1539 
1540  FailureOr<scf::SCFTilingResult> tilingResult =
1541  tileUsingSCF(rewriter, consumer, options.tilingOptions);
1542 
1543  if (failed(tilingResult))
1544  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1545  tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1546 
1547  DenseMap<Value, Value> replacements;
1548  for (auto [origVal, replacement] :
1549  llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1550  replacements[origVal] = replacement;
1551  }
1552 
1553  // If there are no loops generated, fusion is immaterial.
1554  auto &loops = tilingResult->loops;
1555  if (loops.empty()) {
1556  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1557  replacements};
1558  }
1559 
1560  // Since the loop gets potentially replaced during fusion, we need to track
1561  // the mutation of replacement values. To do this, we attach a listener to
1562  // update the replacements as they happen.
1563  OpBuilder::Listener *previousListener = rewriter.getListener();
1564  auto resetListener =
1565  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1566  ReplacementListener replaceListener(replacements, previousListener);
1567  rewriter.setListener(&replaceListener);
1568 
1569  // 2. Typically, the operands of the tiled operation are slices of the
1570  // operands of the untiled operation. These are expressed in IR using
1571  // `tensor.extract_slice` operations with source being the operands of
1572  // the untiled operation. Create a worklist of these
1573  // `tensor.extract_slice` operations. If the producers of the source of
1574  // the `tensor.extract_slice` can be tiled such that the tiled value is
1575  // generated in-place, that effectively tiles + fuses the operations.
1576  struct WorklistItem {
1577  tensor::ExtractSliceOp candidateSlice;
1579  };
1580 
1581  SliceTrackingListener sliceTracker =
1582  SliceTrackingListener(options.cleanupPatterns);
1583 
1584  if (failed(
1585  sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1586  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1587  }
1588  OpBuilder::InsertionGuard g(rewriter);
1589  while (!sliceTracker.worklist.empty()) {
1590  auto candidateSlice = sliceTracker.worklist.front();
1591  sliceTracker.worklist.pop_front();
1592 
1593  auto [fusableProducer, destinationInitArg] =
1594  getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1595  loops);
1596  if (!fusableProducer)
1597  continue;
1598 
1599  std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1600  options.fusionControlFn(candidateSlice, fusableProducer,
1601  destinationInitArg.has_value());
1602  if (!controlFnResult)
1603  continue;
1604 
1605  WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1606 
1607  // The operands of the fused producer might themselved be slices of
1608  // values produced by operations that implement the `TilingInterface`.
1609  // Add these operations to the worklist.
1610  std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1611  tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1612  loops);
1613  if (!fusedResult)
1614  continue;
1615 
1616  SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1617 
1618  if (worklistItem.controlFnResult.yieldProducerReplacement) {
1619  // Reconstruct and yield all opResult of fusableProducerOp by default.
1620  // The caller can specify which one to yield by designating optional
1621  // argument named `yieldResultNumber` of
1622  // `yieldReplacementForFusedProducer`.
1623  Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1624  FailureOr<SmallVector<Operation *>> newSlices =
1626  worklistItem.candidateSlice,
1627  fusedResult.value(), loops);
1628  if (failed(newSlices)) {
1629  return rewriter.notifyMatchFailure(
1630  fusableProducerOp, "failed to replacement value for this "
1631  "operation from within the tiled loop");
1632  }
1633  worklistCandidates.append(newSlices.value());
1634  for (auto [index, result] :
1635  llvm::enumerate(fusableProducerOp->getResults())) {
1636  replacements[result] = loops.front()->getResult(
1637  loops.front()->getNumResults() -
1638  fusableProducerOp->getNumResults() + index);
1639  }
1640  }
1641  if (Operation *tiledAndFusedOp =
1642  fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1643  fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1644  tiledAndFusedOps.insert(tiledAndFusedOp);
1645  }
1646 
1647  if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1648  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1649  }
1650  }
1651 
1652  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1653  replacements};
1654 }
1655 
1656 //===----------------------------------------------------------------------===//
1657 // tileAndFuseConsumerUsingSCF implementation.
1658 //===----------------------------------------------------------------------===//
1659 
1660 /// A utility function that checks whether the only use of the result of a
1661 /// tensor.insert_slice op is in a scf.yield op.
1662 static LogicalResult
1663 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1664  Value result = candidateSliceOp.getResult();
1665  Value::use_range uses = result.getUses();
1666  if (!llvm::hasSingleElement(uses)) {
1667  LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1668  return failure();
1669  }
1670  OpOperand &operandUse = (*uses.begin());
1671  Operation *userOp = operandUse.getOwner();
1672  if (!isa<scf::YieldOp>(userOp)) {
1673  LLVM_DEBUG(llvm::dbgs()
1674  << "Expected scf.yield to be the only user, but got -> "
1675  << (*userOp));
1676  return failure();
1677  }
1678  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1679  LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1680  "be in the same block\n");
1681  return failure();
1682  }
1683  return success();
1684 }
1685 
1686 /// An utility to get the first user of the given loopOp. If any of user stay
1687 /// in different block of loopOp, return failure.
1688 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1689  if (!isa<LoopLikeOpInterface>(loopOp))
1690  return failure();
1691  Operation *firstUserOfLoop = nullptr;
1692  for (Operation *userOp : loopOp->getUsers()) {
1693  // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1694  // block with any other types of operation. Thus, just redirecting to its
1695  // parent `InParallelOp`. E.g.
1696  //
1697  // ```
1698  // %1 = scf.for {
1699  // ...
1700  // }
1701  // %2 = consumerOp ins(%1, ...)
1702  // scf.forall.in_parallel {
1703  // tensor.parallel_insert_slice %1
1704  // }
1705  // ```
1706  // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
1707  // same block with `consumerOp`.
1708  if (isa<tensor::ParallelInsertSliceOp>(userOp))
1709  userOp = userOp->getParentOfType<scf::InParallelOp>();
1710 
1711  if (loopOp->getBlock() != userOp->getBlock())
1712  return failure();
1713 
1714  if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1715  firstUserOfLoop = userOp;
1716  }
1717  return firstUserOfLoop;
1718 }
1719 
1720 /// This utility currently checks whether the first userOp of loop is NOT
1721 /// before the last defineOp of consumer operand, because we need to move
1722 /// the whole loop structure right before the `firstUserOfLoop`. This utility
1723 /// thus helps ensure that no invalid IR is formed, i.e. no backward slice
1724 /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
1725 ///
1726 /// ```
1727 /// %0 = scf.for() {
1728 /// ...
1729 /// }
1730 /// ...
1731 /// %1 = firstUserOfLoop(%0)
1732 /// ...
1733 /// %2 = lastDefOfConsumerOperand
1734 /// ...
1735 /// %3 = consumerOp(%2)
1736 /// ```
1737 ///
1738 /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
1739 /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
1740 /// a.k.a. use-def chain violation:
1741 ///
1742 /// ```
1743 /// %0:2 = scf.for() {
1744 /// // use before define error
1745 /// %3 = tiledConsumerOp(%2)
1746 /// }
1747 /// %1 = firstUserOfLoop(%0)
1748 /// ...
1749 /// %2 = lastDefOfConsumerOperand
1750 /// ```
1751 ///
1752 /// @param loopOp: loop operation
1753 /// @param consumerOp: consumer operation
1754 /// @param reorderOperations: the flag controls whether to reorder the
1755 /// backward slice w.r.t. the defineOp of `consumerOp` operands.
1756 /// @return: the computed backward slice of consumerOp, excluding ops that
1757 /// already dominate `firstUserOfLoop`.
1758 static FailureOr<llvm::SetVector<Operation *>>
1760  bool reorderOperations) {
1761  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1762  if (failed(firstUserOfLoop))
1763  return failure();
1764 
 // Configure the backward-slice walk: include each starting op itself,
 // skip block arguments, and cut the slice at any op that already dominates
 // `firstUserOfLoop` (such ops need not be moved).
1766  DominanceInfo dominanceInfo;
1767  options.inclusive = true;
1768  options.omitBlockArguments = true;
1769  bool includeLoopOp = false;
1770  options.filter = [&](Operation *op) {
1771  if (op == loopOp) {
1772  includeLoopOp = true;
1773  return false;
1774  }
1775  // Cut off the slice to not include any operation that already dominates
1776  // firstUserOfLoop.
1777  return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1778  };
 // Accumulate the backward slice of every operand of the consumer.
1780  for (auto operand : consumerOp->getOperands()) {
1781  LogicalResult result = getBackwardSlice(operand, &slice, options);
1782  assert(result.succeeded() && "expected a backward slice");
1783  (void)result;
1784  }
1785 
1786  if (!slice.empty()) {
1787  // If consumerOp has one producer, which is also the user of loopOp.
1788  // E.g.
1789  // ```
1790  // %0 = %loopOp
1791  // %1 = consumerOp1 ins(%0)
1792  // %2 = consumerOp2 ins(%0, %1)
1793  // ```
1794  // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1795  // consumerOp1 has already been fused into loopOp before.
1796  if (includeLoopOp || !reorderOperations)
1797  return failure();
1798  }
1799 
1800  return slice;
1801 }
1802 
1803 /// Fetches the OpOperand of the first valid user (and use) of the value `val`
1804 /// which implements `TilingInterface` and `DestinationStyleOpInterface`.
1805 /// Returns failure otherwise.
1806 static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1807  Operation *loopOp,
1808  unsigned resultNumber) {
1809  if (!isa<LoopLikeOpInterface>(loopOp))
1810  return failure();
1811  Value val = loopOp->getResult(resultNumber);
1812  Block *loopBlock = loopOp->getBlock();
1813  for (OpOperand &opOperand : val.getUses()) {
1814  Operation *consumerOp = opOperand.getOwner();
1815  // Step 1. Check if the user is tilable.
1816  if (!isa<TilingInterface>(consumerOp) ||
1817  !isa<DestinationStyleOpInterface>(consumerOp)) {
1818  // TODO: We have to init result of consumer before scf.for, use
1819  // DestinationStyleOpInterface to get result shape from init for now.
1820  // Add support for other op such as op has InferTypeOpInterface.
1821  continue;
1822  }
1823  // Step 2. Check if user stay in the same block.
1824  if (loopBlock != consumerOp->getBlock())
1825  continue;
1826  // Step 3. Check if user has succeeding user. Otherwise, it usually
1827  // represents already tiled.
1828  if (consumerOp->use_empty())
1829  continue;
1830  // Step 4. Check assumption for loop with `reorderOperations` enabled.
1831  FailureOr<llvm::SetVector<Operation *>> slice =
1832  checkAssumptionForLoop(loopOp, consumerOp, true);
1833  if (failed(slice))
1834  continue;
1835  // Step 5. If backward sice is not empty, move them before
1836  // firstUserOfLoop.
1837  if (!slice->empty()) {
1838  mlir::topologicalSort(*slice);
1839  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1840  assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1841  for (auto op : *slice) {
1842  rewriter.moveOpBefore(op, *firstUserOfLoop);
1843  }
1844  }
1845  return &opOperand;
1846  }
1847  return failure();
1848 }
1849 
1850 /// Check that the loop is perfectly nested.
1851 /// The loops are expected to be ordered from outer most to inner most.
1852 /// For example:
1853 /// ```
1854 /// %0 = scf.for()
1855 /// %1 = scf.for()
1856 /// %2 = scf.for()
1857 /// %3 = ...
1858 /// yield %3
1859 /// yield %2
1860 /// yield %1
1861 /// ```
1862 /// Here loops should be [%0, %1].
1863 static bool
 // A single loop is trivially "perfectly nested" as long as it is an scf.for.
1865  assert(!loops.empty() && "unexpected empty loop nest");
1866  if (loops.size() == 1) {
1867  return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1868  }
 // Walk adjacent (outer, inner) pairs and verify the nesting invariants.
1869  for (auto [outerLoop, innerLoop] :
1870  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1871  auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1872  auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1873  if (!outerFor || !innerFor) {
1874  return false;
1875  }
1876  auto outerBBArgs = outerFor.getRegionIterArgs();
1877  auto innerIterArgs = innerFor.getInitArgs();
1878  if (outerBBArgs.size() != innerIterArgs.size()) {
1879  return false;
1880  }
1881 
 // Each outer iter_arg must be forwarded unchanged as the inner loop's
 // init value and must have no other use.
1882  for (auto [outerBBArg, innerIterArg] :
1883  llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1884  if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1885  innerIterArg != outerBBArg) {
1886  return false;
1887  }
1888  }
1889 
 // Each inner-loop result must be yielded directly by the outer loop and
 // have no other use.
1890  ValueRange outerYields =
1891  cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1892  ValueRange innerResults = innerFor.getResults();
1893  if (outerYields.size() != innerResults.size()) {
1894  return false;
1895  }
1896  for (auto [outerYield, innerResult] :
1897  llvm::zip_equal(outerYields, innerResults)) {
1898  if (!llvm::hasSingleElement(innerResult.getUses()) ||
1899  outerYield != innerResult) {
1900  return false;
1901  }
1902  }
1903  }
1904  return true;
1905 }
1906 
1907 /// Fetch the untiled consumer of the outermost scf.for's result which is
1908 /// yielded by a tensor.insert_slice from the innermost scf.for. This function
1909 /// makes the following assumptions :
1910 /// 1. tensor.insert_slice has scf.yield as its only user.
1911 /// 2. scf.for's corresponding result has only one use.
1912 /// 3. The `loops` passed in are perfectly nested `scf.for` operations.
1913 static FailureOr<OpOperand *>
1915  tensor::InsertSliceOp candidateSliceOp,
1917  assert(!loops.empty() && "unexpected loops to be empty");
1918  // 1. Expect slice to be part of the body of the inner most loop.
1919  Operation *containingOp = candidateSliceOp->getParentOp();
1920  if (containingOp != loops.back()) {
1921  return rewriter.notifyMatchFailure(
1922  candidateSliceOp,
1923  "expected slice to be within body of inner-most loop");
1924  }
1925 
1926  // 2. Check that the loop is perfectly nested.
1927  if (!isPerfectlyNestedForLoops(loops)) {
1928  return rewriter.notifyMatchFailure(
1929  candidateSliceOp, "expected passed loops to be perfectly nested.");
1930  }
1931 
1932  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1933  return failure();
1934  Value sliceResult = candidateSliceOp.getResult();
1935 
1936  // 3. Fetch the corresponding output.
 // The operand number of the slice's use in the scf.yield corresponds to
 // the result number of the enclosing loop.
1937  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1938  unsigned resultNumber = yieldOpOperand.getOperandNumber();
1939 
1940  scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
1941 
1942  return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
1943 }
1944 
1945 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1946 /// by a tensor.parallel_insert_slice.
1947 static FailureOr<OpOperand *>
1949  tensor::ParallelInsertSliceOp candidateSliceOp,
1951  assert(!loops.empty() && "unexpected loops to be empty");
1952  // 1. Check that the surrounding loop is a single scf.forall loop.
1953  if (loops.size() != 1) {
1954  return rewriter.notifyMatchFailure(
1955  candidateSliceOp, "expected single surrounding scf.forall");
1956  }
1957  auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
1958  if (!forallOp) {
1959  return rewriter.notifyMatchFailure(
1960  candidateSliceOp, "expected single surrounding scf.forall");
1961  }
1962 
1963  // 2. Fetch the corresponding output
 // The destination of the parallel_insert_slice must be a block argument
 // (shared output) of this very forall op.
1964  Value sliceDest = candidateSliceOp.getDest();
1965  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1966  if (!iterArg)
1967  return failure();
1968  if (iterArg.getOwner()->getParentOp() != forallOp)
1969  return failure();
1970 
 // Map the shared-output block argument to the forall result it produces.
1971  unsigned resultNumber =
1972  forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1973  .getResultNumber();
1974 
1975  return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
1976 }
1977 
1978 /// A utility to fetch an untiled consumer of
1979 /// tensor.insert_slice/tensor.parallel_insert_slice.
 // Dispatches to the scf.for or scf.forall overload based on the slice kind.
1980 static FailureOr<OpOperand *>
1983  assert(!loops.empty() && "unexpected empty loops");
1984  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1985  return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
1986  } else if (auto parallelInsertSlice =
1987  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1988  return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
1989  } else {
1990  return failure();
1991  }
1992 }
1993 
1994 /// Implementation of fusing consumer of a single slice by computing the
1995 /// slice of the consumer in-place for scf loop.
1996 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1998  RewriterBase &rewriter, Operation *candidateSliceOp,
2000  // Return if `loops` is empty, return an error for now. Caller is expected
2001  // to handle this case.
2002  if (loops.empty()) {
2003  return candidateSliceOp->emitOpError(
2004  "cannot call tile and fuse consumer with an empty loop nest");
2005  }
2006  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2007  candidateSliceOp))
2008  return failure();
2009 
2010  // 1. Get the consumer of scf.for for the result yielded by
2011  // tensor.insert_slice/parallel_insert_slice.
2012  FailureOr<OpOperand *> maybeConsumerOpOperand =
2013  getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
2014  if (failed(maybeConsumerOpOperand)) {
2015  return rewriter.notifyMatchFailure(candidateSliceOp,
2016  "could not fetch consumer to fuse");
2017  }
2018  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
2019  Operation *consumerOp = consumerOpOperand->getOwner();
2020  unsigned operandNumber = consumerOpOperand->getOperandNumber();
2021  unsigned resultNumber = 0;
 // The consumer operand must come from an op result (the loop result);
 // block arguments cannot be handled here.
2022  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
2023  resultNumber = producerResult.getResultNumber();
2024  } else {
2025  return rewriter.notifyMatchFailure(
2026  consumerOp, "consumer op's operand doesn't seem to be an OpResult");
2027  }
2028 
2029  LoopLikeOpInterface outerMostLoop = loops.front();
2030  LoopLikeOpInterface innerMostLoop = loops.back();
2031 
2032  // Check assumption for loop with `reorderOperations` disabled.
2033  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
2034  return rewriter.notifyMatchFailure(
2035  outerMostLoop, "the first user of loop should not dominate any define "
2036  "of consumer operand(s)");
2037  }
2038 
2039  OpBuilder::InsertionGuard g(rewriter);
2040 
2041  // 2. Check consumer is not using scf loop's output as init.
2042  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2043  if (!dstOp)
2044  return rewriter.notifyMatchFailure(consumerOp,
2045  "consumer op is not DPS operation");
2046  SmallVector<Value> dpsInits =
2047  llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
2048  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
2049  return rewriter.notifyMatchFailure(
2050  consumerOp,
2051  "consumer op taking the result of scf.for as init is not supported");
2052  }
2053  SmallVector<Value> newInits = dpsInits;
2054 
2055  Location loc = outerMostLoop->getLoc();
2056 
2057  // 3. Move the whole loop structure right before firstUserOfLoop, the
2058  // dominance should be already ensured by `checkAssumptionForLoop`.
2059  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
2060  if (failed(firstUserOfLoop)) {
2061  return rewriter.notifyMatchFailure(
2062  outerMostLoop, "could not find the first user of outer most loop");
2063  }
2064  rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2065 
2066  // 4. Set insertion point before terminator op of the loop and create a new
2067  // tensor.insert_slice. In the scf.for case this is a clone of the
2068  // candidateSliceOp whereas in the scf.forall case this is created from the
2069  // operands of tensor.parallel_insert_slice.
2070  tensor::InsertSliceOp clonedInsertSliceOp;
2071  if (auto sliceOp =
2072  dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2073  auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2074  rewriter.setInsertionPoint(newForallOp.getTerminator());
2075  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
2076  loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2077  sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2078  } else {
2079  rewriter.setInsertionPoint(candidateSliceOp);
2080  clonedInsertSliceOp =
2081  cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
2082  }
2083 
2084  // 5.a. Clone consumer op.
2085  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2086 
2087  // 5.b. Replace all uses of the loop result with the result of the cloned
2088  // tensor.insert_slice.
2089  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2090  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
2091  operandToReplace.set(clonedInsertSliceOp.getResult());
2092  });
2093 
2094  // 6. Perform tiling of the cloned consumer and replace the operand at
2095  // `operandNumber` with the source of the cloned tensor.insert_slice op.
2096  auto ossSliceOp =
2097  cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation())
2098  FailureOr<TilingResult> tileAndFuseResult =
2100  rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2101  if (failed(tileAndFuseResult)) {
2102  return failure();
2103  }
2104  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2105  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
2106  clonedInsertSliceOp.getSource());
2107 
2108  // 7. Reconstruct [nested] loop with new inits.
 // Callback invoked by `addInitOperandsToLoopNest` (step 14) to produce the
 // tiled values yielded from the innermost loop.
2109  YieldTiledValuesFn newYieldValuesFn =
2110  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2111  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2113  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2114  OpBuilder::InsertionGuard g(innerRewriter);
2115  // 8. Set inner insertPoint right before tiled consumer op.
2116  innerRewriter.setInsertionPoint(tiledConsumerOp);
2117 
2118  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
2119  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
2120  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
2121 
2122  // 9. Check all insert stride is 1.
2123  if (!llvm::all_of(strides, isOneInteger)) {
2124  return rewriter.notifyMatchFailure(
2125  candidateSliceOp, "containingOp's result yield with stride");
2126  }
2127 
2128  // 10. Try to get iter domain position from input position. Use
2129  // clonedConsumerOp instead of tiledConsumerOp, because the iteration
2130  // domain may require index computation based on the result size. The
2131  // sizes and offsets should be the same either way, but using
2132  // tiledConsumerOp could lead to some chained unnecessary extra index
2133  // computation.
2134  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2135  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2136  rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2137  iterDomainSizes))) {
2138  return rewriter.notifyMatchFailure(
2139  clonedConsumerOp,
2140  "can't get iter domain position from input position");
2141  }
2142 
2143  // 11. Try to fetch the offset and size for all results of the cloned
2144  // consumer. This would then be used to form the corresponding
2145  // tensor.insert_slice/parallel_insert_slice later.
2146  unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2148  totalNumResultsOfConsumer);
2150  totalNumResultsOfConsumer);
2151  for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2152  if (failed(tiledConsumerOp.getResultTilePosition(
2153  rewriter, idx, iterDomainOffsets, iterDomainSizes,
2154  resultOffsets[idx], resultSizes[idx]))) {
2155  return rewriter.notifyMatchFailure(
2156  tiledConsumerOp,
2157  "can't get result domain position from iter domain position");
2158  }
2159  }
2160 
2161  // 12. Create `extract_slice` for `iter_args` for DPS operation if
2162  // necessary.
2163  if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2164  tiledConsumerOp.getOperation())) {
2165  rewriter.setInsertionPoint(tiledDestStyleOp);
2166  for (const auto &&[index, newRegionArg] :
2167  llvm::enumerate(newRegionIterArgs)) {
2168  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2169  loc, newRegionArg, resultOffsets[index], resultSizes[index],
2170  SmallVector<OpFoldResult>(resultOffsets[index].size(),
2171  rewriter.getIndexAttr(1)));
2172  // Make a copy of index to avoid a capturing structured binding, which
2173  // is a C++20 extension.
2174  auto dstNumber = index;
2175  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2176  tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2177  });
2178  }
2179  }
2180 
2181  // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2182  // caller.
2183  Block *block = rewriter.getInsertionPoint()->getBlock();
2184  rewriter.setInsertionPoint(block->getTerminator());
2185  for (const auto &&[index, result] :
2186  llvm::enumerate(tiledConsumerOp->getResults())) {
2187  tiledResult.push_back(result);
2188  tiledOffset.emplace_back(resultOffsets[index]);
2189  tiledSizes.emplace_back(resultSizes[index]);
2190  }
2191  return success();
2192  };
2193  // 14. Add new inits to [nested] loops.
2194  if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
2195  newYieldValuesFn))) {
2196  return rewriter.notifyMatchFailure(tiledConsumerOp,
2197  "unable to add new inits to nest loop");
2198  }
2199 
2200  // 15. Replace the result of scf loop and consumer op with new loop's
2201  // results.
2202 
2203  for (auto &&[oldResult, newResult] :
2204  llvm::zip(consumerOp->getResults(),
2205  loops.front()->getResults().take_back(newInits.size()))) {
2206  rewriter.replaceAllUsesWith(oldResult, newResult);
2207  }
2208 
2209  // 16. Need to erase the old scf loop and the cloned consumer op.
2210  rewriter.eraseOp(clonedConsumerOp);
2211 
2213  consumerOpOperand,
2214  &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2215  tileAndFuseResult->tiledOps};
2216 }
2217 
2218 //===----------------------------------------------------------------------===//
2219 // lowerToLoopsUsingSCFForOp implementation.
2220 //===----------------------------------------------------------------------===//
2221 
2222 FailureOr<SmallVector<scf::ForOp>>
2224  TilingInterface op) {
2225  // TODO: Handle cases where the op has results if needed.
2226  if (op->getNumResults() > 0) {
2227  return rewriter.notifyMatchFailure(
2228  op, "unable to lower to loops operations with return values");
2229  }
2230 
 // Materialize one scf.for per iteration-domain dimension; moving the
 // insertion point to each new loop body's terminator nests the next loop.
2231  SmallVector<Range> domain = op.getIterationDomain(rewriter);
2232  SmallVector<Value> ivs;
2234  Location loc = op.getLoc();
2235  for (auto loopRange : domain) {
2236  Value offsetVal =
2237  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2238  Value sizeVal =
2239  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2240  Value strideVal =
2241  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2242  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2243  strideVal, ValueRange{});
2244  loops.push_back(loop);
2245  ivs.push_back(loop.getInductionVar());
2246  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2247  }
 // Emit the op's scalar body at the innermost insertion point, indexed by
 // the collected induction variables.
2248  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2249  return failure();
2250  }
2251  return loops;
2252 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
static bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check that the loop is perfectly nested.
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specified in options.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
An utility to get the first user of the given loopOp.
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ValueRange partialResults, const scf::SCFTilingOptions &options)
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, const scf::SCFTilingOptions &options)
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:968
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:106
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:264
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:324
This class allows control over how the GreedyPatternRewriteDriver works.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class represents a saved insertion point.
Definition: Builders.h:325
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:551
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:456
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:602
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1336
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1329
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1225
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:82
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:117
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterf...
Include the generated interface declarations.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:788
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Container for the result of merge operation of tiling.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:421
Container for result values of tiling.
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Control function to check if a slice needs to be fused or not, The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
ReductionTilingStrategy
Specify how reduction dimensions should be tiled.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.