MLIR  20.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1 //===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/TypeSwitch.h"
32 #include "llvm/Support/Debug.h"
33 #include <optional>
34 
35 #define DEBUG_TYPE "tile-using-interface"
36 
37 using namespace mlir;
38 
41  assert(!tileSizeComputationFunction && "tile sizes already set");
42  auto tileSizes = llvm::to_vector(ts);
43  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
44  return tileSizes;
45  };
46  return *this;
47 }
48 
51  assert(!numThreadsComputationFunction && "num tiles already set");
52  auto numThreads = llvm::to_vector(nt);
53  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
54  return numThreads;
55  };
56  return *this;
57 }
58 
59 /// Helper method to adjust the interchange vector to match the iteration
60 /// domain.
63  size_t iterationDomainSize) {
64  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
65  if (filledVector.size() < iterationDomainSize) {
66  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
67  filledVector.append(range.begin(), range.end());
68  }
69  if (filledVector.size() > iterationDomainSize)
70  filledVector.resize(iterationDomainSize);
71  return filledVector;
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // tileUsingSCF implementation.
76 //===----------------------------------------------------------------------===//
77 
78 /// Verify the tile size options are set in a consistent manner.
79 static LogicalResult
82  // Specifying number of threads is only supported on `scf.forall` op.
83  if (options.numThreadsComputationFunction &&
85  return rewriter.notifyMatchFailure(
86  loc, "number of threads can only by specified when loop type is "
87  "set to use `scf.forall`");
88  }
89 
90  // If specified, check that the interchange vector is a permutation.
91  if (!options.interchangeVector.empty()) {
92  if (!isPermutationVector(options.interchangeVector)) {
93  return rewriter.notifyMatchFailure(
94  loc, "invalid interchange vector, not a permutation of the entire "
95  "iteration space");
96  }
97  }
98  return success();
99 }
100 
101 /// Method to instantiate the tile sizes and/or number of threads specified
102 /// by the user.
103 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
104 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
105  ArrayRef<Range> iterationDomain,
107  OpFoldResult zero = rewriter.getIndexAttr(0);
108  SmallVector<OpFoldResult> tileSizes, numThreads;
109  size_t numLoops = iterationDomain.size();
110 
111  // Check whether the number of tiles to use is specified.
112  if (options.numThreadsComputationFunction) {
113  numThreads = options.numThreadsComputationFunction(rewriter, op);
114  numThreads.resize(numLoops, zero);
115 
116  // If the tile sizes are also specified, use them.
117  if (options.tileSizeComputationFunction) {
118  tileSizes = options.tileSizeComputationFunction(rewriter, op);
119  tileSizes.resize(numLoops, zero);
120  return {tileSizes, numThreads};
121  }
122 
123  // Compute the tile sizes from the iteration domain and number
124  // of tiles as follows
125  // - niters = ceilDiv(ub - lb, step)
126  // - tileSize = ceilDiv(niters, numThreads)
127  AffineExpr s0, s1, s2;
128  bindSymbols(rewriter.getContext(), s0, s1, s2);
129  // TODO: The step here is assumed to be 1.
130  AffineExpr numItersExpr = (s1 - s0);
131  AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
132  tileSizes.resize(numLoops, zero);
133  for (auto [index, range, nt] :
134  llvm::enumerate(iterationDomain, numThreads)) {
135  if (isConstantIntValue(nt, 0))
136  continue;
137 
138  tileSizes[index] = affine::makeComposedFoldedAffineApply(
139  rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
140  }
141  tileSizes.resize(numLoops, zero);
142  return {tileSizes, numThreads};
143  }
144 
145  // Enforce the convention that "tiling by zero"
146  // skips tiling a particular dimension. This convention is significantly
147  // simpler to handle instead of adjusting affine maps to account for missing
148  // dimensions.
149  assert(options.tileSizeComputationFunction &&
150  "expected tile sizes to be specified");
151  tileSizes = options.tileSizeComputationFunction(rewriter, op);
152  tileSizes.resize(numLoops, zero);
153 
154  return {tileSizes, numThreads};
155 }
156 
157 /// Checks if any of the tiled loops are not parallel.
158 static void checkSafeToTileToForall(TilingInterface op,
159  ArrayRef<OpFoldResult> tileSizes,
160  ArrayRef<OpFoldResult> numThreads) {
161  auto iterators = op.getLoopIteratorTypes();
162  assert(iterators.size() == tileSizes.size() &&
163  "expected as many tile size values as number of loops");
164  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
165  "when specified, expected number of threads to use for each loop");
166 
167  for (auto [index, iterator, tileSize] :
168  llvm::enumerate(iterators, tileSizes)) {
169  // If num threads is specified, check that it is greater than one only for
170  // parallel dimensions.
171  if (!numThreads.empty()) {
172  if (std::optional<int64_t> constNumThreads =
173  getConstantIntValue(numThreads[index])) {
174  if (constNumThreads.value() > 1 &&
175  iterator != utils::IteratorType::parallel) {
176  op.emitWarning() << "tiling is not thread safe at axis #" << index;
177  }
178  }
179  continue;
180  }
181 
182  if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
183  if (constTileSize.value() > 0 &&
184  iterator != utils::IteratorType::parallel) {
185  op.emitWarning() << "tiling is not thread safe at axis #" << index;
186  }
187  }
188  }
189 }
190 
191 /// Check if `stride` evenly divides the trip count `size - offset`.
192 static bool tileDividesIterationDomain(Range loopRange) {
193  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
194  if (!offsetAsInt)
195  return false;
196  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
197  if (!sizeAsInt)
198  return false;
199  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
200  if (!strideAsInt)
201  return false;
202  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
203 }
204 
205 /// Returns the bounded tile size given the current `offset`, `loopRange` and
206 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
208  Range loopRange, OpFoldResult offset,
209  OpFoldResult tileSize) {
210  std::optional<int64_t> ts = getConstantIntValue(tileSize);
211  if (ts && ts.value() == 1)
212  return tileSize;
213 
215  Range{loopRange.offset, loopRange.size, tileSize}))
216  return tileSize;
217 
218  // The tile size to use (to avoid out of bounds access) is minimum of
219  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
220  // loop.
221  AffineExpr s0, s1, d0;
222  bindDims(b.getContext(), d0);
223  bindSymbols(b.getContext(), s0, s1);
224  AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
225  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
227  b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
228 }
229 
230 /// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
231 /// statically known to be less than `iterationSize`.
233  OpFoldResult numThreads,
234  OpFoldResult iterationSize) {
235  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
236  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
237  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
238  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
239  return false;
240  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
241 }
242 
243 /// Compute the `OpFoldResult`s that represents the multi-dimensional
244 /// `offset`s and `size`s of the tile of the iteration space that the
245 /// innermost loop body of the generated tiled loops corresponds to.
246 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
248  ArrayRef<Range> iterationDomain,
249  ArrayRef<OpFoldResult> tileSizes,
250  ArrayRef<OpFoldResult> numThreads) {
251  SmallVector<OpFoldResult> offsets, sizes;
252  int materializedLoopNum = 0;
253 
254  if (!numThreads.empty()) {
255  AffineExpr d0, d1, s0, s1;
256  AffineExpr offsetExpr, residualTileSizeExpr;
257  bindDims(rewriter.getContext(), d0, d1);
258  bindSymbols(rewriter.getContext(), s0, s1);
259  offsetExpr = d0 + d1 * s0;
260  residualTileSizeExpr = s1 - (d0 + d1 * s0);
261 
262  for (auto [nt, tileSize, loopRange] :
263  llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
264 
265  // Non-tiled cases, set the offset and size to the
266  // `loopRange.offset/size`.
267  if (isConstantIntValue(nt, 0)) {
268  offsets.push_back(loopRange.offset);
269  sizes.push_back(loopRange.size);
270  continue;
271  }
272 
273  Value iv = ivs[materializedLoopNum++];
275  rewriter, loc, offsetExpr,
276  ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
278  rewriter, loc, residualTileSizeExpr,
279  {loopRange.offset, nt, tileSize, loopRange.size});
280 
281  OpFoldResult size = tileSize;
282  if (!isConstantIntValue(residualTileSize, 0)) {
283  OpFoldResult sizeMinusOffsetPerThread =
284  affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
285  {offset, loopRange.size});
287  rewriter, loc,
289  {sizeMinusOffsetPerThread, tileSize});
290  }
291 
292  // Consider the case where the original loop was `[0, 10)`.
293  // If number of threads are `7`, the tile size would be computed as
294  // `ceilDiv(10, 7) = 2`. For the last thread (thread_id = 6)
295  // - `offset = 0 + 6 * 2 = 12`
296  // - `tileSize = min(2, 10 - 12) = -2`
297  // To avoid negative tile sizes, we need to do a further
298  // `nonNegativeTileSize = affine.max(0, tileSize)`.
299  // This `max` can be avoided if
300  // `offset + tileSize * (numThreads - 1) < (ub - lb)`
301  if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
302  AffineMap maxMap =
305  rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
306  }
307 
308  offsets.push_back(offset);
309  sizes.push_back(size);
310  }
311  return {offsets, sizes};
312  } else {
313  for (auto [tileSize, loopRange] :
314  llvm::zip_equal(tileSizes, iterationDomain)) {
315 
316  // Non-tiled cases, set the offset and size to the
317  // `loopRange.offset/size`.
318  if (isConstantIntValue(tileSize, 0)) {
319  offsets.push_back(loopRange.offset);
320  sizes.push_back(loopRange.size);
321  continue;
322  }
323 
324  Value iv = ivs[materializedLoopNum++];
325  OpFoldResult offset = getAsOpFoldResult(iv);
326  offsets.push_back(offset);
327  OpFoldResult size =
328  getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
329  sizes.push_back(size);
330  }
331  return {offsets, sizes};
332  }
333 }
334 
335 /// Function to return the bounds of the loops to be generated.
336 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
339  ArrayRef<OpFoldResult> tileSizes) {
340  SmallVector<OpFoldResult> lbs, ubs, steps;
341  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
342  // No loop if the tile size is 0.
343  if (isConstantIntValue(tileSize, 0))
344  continue;
345  lbs.push_back(loopRange.offset);
346  ubs.push_back(loopRange.size);
347  steps.push_back(tileSize);
348  }
349  return {lbs, ubs, steps};
350 }
351 
352 /// A function that allows returning additional yielded values during
353 /// `yieldTiledValuesAndReplaceLoop`.
354 /// - `ivs` induction variables of the loop(s).
355 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
356 /// - `tiledValues` the tiled values to return. Must be of same size as
357 /// `newBbArgs`, each element of this array is inserted into the corresponding
358 /// element in `newBbArgs`.
359 /// - `resultOffsets` is of the same size as `tiledValues` and represents
360 /// the offsets to use when inserting corresponding element from `tiledValues`
361 /// into the element from `newBbArgs`.
362 /// - `resultSizes` is of the same size as `tiledValues` and represents
363 /// the size of the corresponding element from `tiledValues` inserted into
364 /// the element from `newBbArgs`.
365 /// In case the method needs to return `failure()` the method is expected
366 /// to clean up any inserted operations.
367 using YieldTiledValuesFn = std::function<LogicalResult(
368  RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
369  SmallVector<Value> &tiledValues,
370  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
371  SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
372 
373 /// Clones the operation and updates the destination if the operation
374 /// implements the `DestinationStyleOpInterface`.
376  Operation *op,
377  ValueRange newDestArgs) {
378  Operation *clonedOp = rewriter.clone(*op);
379  if (newDestArgs.empty())
380  return clonedOp;
381  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
382  destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
383  return clonedOp;
384 }
385 
386 /// Generate the tile-loop nest using `scf.for` operation.
387 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
388 /// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
389 /// - `destinationTensors` are the init values to use for the outermost loop.
390 /// - `yieldTiledValuesFn` is called to generate the loop body of the
391 /// innermost loop.
393 /// - `loops` is an in-out parameter into which the generated loops are
394 /// populated.
395 static LogicalResult generateLoopNestUsingForOp(
396  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
397  ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
398  YieldTiledValuesFn yieldTiledValuesFn,
400  assert(!loopRanges.empty() && "unexpected empty loop ranges");
401  assert(loopRanges.size() == tileSizes.size() &&
402  "expected as many tile sizes as loop ranges");
403  OpBuilder::InsertionGuard guard(rewriter);
404 
405  SmallVector<OpFoldResult> lbs, ubs, steps;
406  std::tie(lbs, ubs, steps) =
407  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
408  SmallVector<Value> lbVals =
409  getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
410  SmallVector<Value> ubVals =
411  getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
412  SmallVector<Value> stepVals =
413  getValueOrCreateConstantIndexOp(rewriter, loc, steps);
414 
415  SmallVector<Value> ivs;
416  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
417  auto loop =
418  rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
419  [](OpBuilder &bodyBuilder, Location bodyLoc,
420  Value iv, ValueRange /*iterArgs*/) {});
421  loops.push_back(loop);
422  ivs.push_back(loop.getInductionVar());
423  rewriter.setInsertionPointToEnd(loop.getBody());
424  destinationTensors = loop.getRegionIterArgs();
425  }
426 
427  SmallVector<Value> tiledResults;
428  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
429  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
430  tiledResults, resultOffsets, resultSizes))) {
431  return rewriter.notifyMatchFailure(
432  loc, "failed to generate inner tile loop body");
433  }
434  if (loops.empty())
435  return success();
436 
437  assert(tiledResults.size() == destinationTensors.size() &&
438  "Number of results of body should be equal to number of iter args");
439 
440  // 6. Yield all the results of the tiled operation.
441  SmallVector<Value> yieldedValues;
442  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
443  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
444  resultSizes)) {
445  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
446  rewriter.getIndexAttr(1));
447  auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
448  loc, tiledValue, destinationTensor, resultOffset, resultSize,
449  resultStride);
450  yieldedValues.push_back(insertSlice);
451  }
452  rewriter.create<scf::YieldOp>(loc, yieldedValues);
453 
454  // Add the scf.yield operations for all the outer loops.
455  for (auto [outerLoop, innerLoop] :
456  llvm::zip_equal(MutableArrayRef(loops).drop_back(),
457  MutableArrayRef(loops).drop_front())) {
458  rewriter.setInsertionPointToEnd(
459  cast<scf::ForOp>(outerLoop.getOperation()).getBody());
460  rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
461  }
462  return success();
463 }
464 
465 /// Generate the tile-loop nest using `scf.forall` operation.
466 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
467 /// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
468 /// - `numThreads`, when non-empty, is used to build the `scf.forall` from its
468 /// non-zero entries instead of from bounds derived from `tileSizes`.
468 /// - `destinationTensors` are the init values to use for the outermost loop.
469 /// - `mappingVector` is the mapping attributes to use for loop construction.
470 /// Can be empty.
471 /// - `tiledBodyFn` is called to generate the loop body of the
472 /// innermost loop.
474 /// - `loops` is an in-out parameter into which the generated loops are
475 /// populated.
476 static LogicalResult generateLoopNestUsingForallOp(
477  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
478  ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
479  ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
481  assert(!loopRanges.empty() && "unexpected empty loop ranges");
482  assert(loopRanges.size() == tileSizes.size() &&
483  "expected as many tile sizes as loop ranges");
484  OpBuilder::InsertionGuard guard(rewriter);
485  SmallVector<OpFoldResult> offsets(loopRanges.size()),
486  sizes(loopRanges.size());
487 
488  std::optional<ArrayAttr> mappingAttr;
489  if (!mappingVector.empty())
490  mappingAttr = rewriter.getArrayAttr(mappingVector);
491 
492  scf::ForallOp forallOp;
493  bool useNumThreads = !numThreads.empty();
494 
495  if (useNumThreads) {
496  // Prune the zero numthreads.
497  SmallVector<OpFoldResult> nonZeroNumThreads;
498  for (auto nt : numThreads) {
499  if (isConstantIntValue(nt, 0))
500  continue;
501  nonZeroNumThreads.push_back(nt);
502  }
503  forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
504  destinationTensors, mappingAttr);
505  } else {
506  SmallVector<OpFoldResult> lbs, ubs, steps;
507  std::tie(lbs, ubs, steps) =
508  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
509  forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
510  destinationTensors, mappingAttr);
511  }
512  loops.push_back(forallOp);
513 
514  rewriter.setInsertionPoint(forallOp.getTerminator());
515  destinationTensors = forallOp.getRegionOutArgs();
516 
517  SmallVector<Value> tiledResults;
518  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
519  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
520  destinationTensors, tiledResults, resultOffsets,
521  resultSizes)))
522  return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
523 
524  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
525  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
526  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
527  resultSizes)) {
528  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
529  rewriter.getIndexAttr(1));
530 
531  rewriter.create<tensor::ParallelInsertSliceOp>(
532  loc, tiledValue, destinationTensor, resultOffset, resultSize,
533  resultStride);
534  }
535  return success();
536 }
537 
538 /// Generate the tile-loop nest using the loop construct specified in `options`.
539 /// - `options`: Tiling options specified.
540 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
541 /// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
542 /// - `destinationTensors` are the init values to use for the outermost loop.
543 /// - `tiledBodyFn` is called to generate the loop body of the
544 /// innermost loop.
546 /// - `loops` is an in-out parameter into which the generated loops are
547 /// populated.
548 static LogicalResult generateLoopNest(
549  RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
550  ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
551  ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
553  // If the tile sizes are all zero, no loops are generated. Just call the
554  // callback function to handle untiled case.
555  if (llvm::all_of(tileSizes, isZeroIndex)) {
556  SmallVector<Value> tiledResults;
557  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
558  return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
559  tiledResults, resultOffsets, resultSizes);
560  }
562  return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
563  destinationTensors, tiledBodyFn, loops);
564  }
567  rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
568  destinationTensors, tiledBodyFn, loops);
569  }
570  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
571 }
572 
573 static FailureOr<SmallVector<Value>>
574 createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
575  ArrayRef<OpFoldResult> tileSizes,
577  SmallVector<Value> initTensors;
578  Location loc = op->getLoc();
579  switch (options.reductionStrategy) {
581  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
582  return failure();
583  return initTensors;
586  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
587  if (!redOp) {
588  return rewriter.notifyMatchFailure(
589  op, "PartialReductionOuterReduction tiling strategy is only supported"
590  "for operations implementing PartialReductionOpInterface");
591  }
592  // Get reduction dimensions.
593  // TODO: PartialReductionOpInterface should really query TilingInterface
594  // itself and find reduction dimensions.
595  SmallVector<int> reductionDims;
596  for (auto [idx, iteratorType] :
597  llvm::enumerate(op.getLoopIteratorTypes())) {
598  if (iteratorType == utils::IteratorType::reduction)
599  reductionDims.push_back(idx);
600  }
601  return redOp.generateInitialTensorForPartialReduction(
602  rewriter, loc, tileSizes, reductionDims);
603  }
604  default:
605  return rewriter.notifyMatchFailure(op,
606  "unhandled reduction tiling strategy");
607  }
608 }
609 
610 static FailureOr<TilingResult>
611 getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
612  ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
615  switch (options.reductionStrategy) {
617  return op.getTiledImplementation(rewriter, offsets, sizes);
620  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
621  if (!redOp) {
622  return rewriter.notifyMatchFailure(
623  op, "PartialReductionOuterReduction tiling strategy is only "
624  "supported for operations "
625  "implementing PartialReductionOpInterface");
626  }
627  // Get reduction dimensions.
628  // TODO: PartialReductionOpInterface should really query TilingInterface
629  // itself and find reduction dimensions.
630  SmallVector<int> reductionDims;
631  for (auto [idx, iteratorType] :
632  llvm::enumerate(op.getLoopIteratorTypes())) {
633  if (iteratorType == utils::IteratorType::reduction)
634  reductionDims.push_back(idx);
635  }
636  return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
637  offsets, sizes, reductionDims);
638  }
639  default:
640  return rewriter.notifyMatchFailure(op,
641  "unhandled reduction tiling strategy");
642  }
643 }
644 
645 static LogicalResult
646 getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
647  TilingInterface op, ArrayRef<OpFoldResult> offsets,
649  SmallVector<OpFoldResult> &resultOffset,
650  SmallVector<OpFoldResult> &resultSize,
652 
653  switch (options.reductionStrategy) {
655  return op.getResultTilePosition(rewriter, index, offsets, sizes,
656  resultOffset, resultSize);
659  // TODO: This does not work for non identity accesses to the result tile.
660  // The proper fix is to add a getPartialResultTilePosition method to
661  // PartialReductionOpInterface.
662  resultOffset =
663  SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
664  for (size_t i = 0; i < offsets.size(); i++) {
665  resultSize.push_back(
666  tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
667  }
668  return success();
669  default:
670  return rewriter.notifyMatchFailure(op,
671  "unhandled reduction tiling strategy");
672  }
673  }
674 }
675 
676 static FailureOr<MergeResult>
677 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
678  ValueRange partialResults,
680  switch (options.reductionStrategy) {
682  // No need to merge results for reduction tiling strategy.
683  return MergeResult{{}, partialResults};
686  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
687  if (!redOp) {
688  return rewriter.notifyMatchFailure(
689  op, "PartialReductionOuterReduction tiling strategy is only "
690  "supported for operations "
691  "implementing PartialReductionOpInterface");
692  }
693  // Get reduction dimensions.
694  // TODO: PartialReductionOpInterface should really query TilingInterface
695  // itself and find reduction dimensions.
696  SmallVector<int> reductionDims;
697  for (auto [idx, iteratorType] :
698  llvm::enumerate(op.getLoopIteratorTypes())) {
699  if (iteratorType == utils::IteratorType::reduction)
700  reductionDims.push_back(idx);
701  }
702  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
703  reductionDims);
704  }
705  default:
706  return rewriter.notifyMatchFailure(op,
707  "unhandled reduction tiling strategy");
708  }
709 }
710 
711 /// Append the specified additional `newInitOperands` operands to the
712 /// loops existing `init` operands (or similar), and replace `loopOp` with
713 /// the new loop that has the additional init operands. The loop body of
714 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
715 /// is called to get the new tiled values returned, and the offset
716 /// and sizes at which the tiled value is inserted into the
717 /// new region iter_args that correspond to the newly added init operands.
718 template <typename LoopType>
719 FailureOr<LoopLikeOpInterface>
720 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
721  ValueRange newInitOperands,
722  YieldTiledValuesFn yieldTiledValuesFn) {
723  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
724 }
725 
726 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
727 template <>
728 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
729  scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
730  YieldTiledValuesFn yieldTiledValuesFn) {
731  OpBuilder::InsertionGuard g(rewriter);
732  Location loc = loopOp.getLoc();
733  rewriter.setInsertionPoint(loopOp);
734 
735  auto inits = llvm::to_vector(loopOp.getInitArgs());
736  inits.append(newInitOperands.begin(), newInitOperands.end());
737  auto newLoop = rewriter.create<scf::ForOp>(
738  loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
739  inits, [](OpBuilder &, Location, Value, ValueRange) {});
740 
741  // Move the loop body to the new op.
742  Block *loopBody = loopOp.getBody();
743  Block *newLoopBody = newLoop.getBody();
744  rewriter.mergeBlocks(
745  loopBody, newLoopBody,
746  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
747 
748  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
749  rewriter.setInsertionPoint(yieldOp);
750 
751  SmallVector<Value> tiledValues;
752  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
753  ValueRange newRegionIterArgs =
754  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
755  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
756  newRegionIterArgs, tiledValues, resultOffsets,
757  resultSizes))) {
758  rewriter.eraseOp(newLoop);
759  return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
760  }
761 
762  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
763  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
764  llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
765  resultSizes)) {
766  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
767  rewriter.getIndexAttr(1));
768  Value insert = rewriter.create<tensor::InsertSliceOp>(
769  yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
770  resultStride);
771  newYieldValues.push_back(insert);
772  }
773 
774  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
775  rewriter.replaceOp(loopOp,
776  newLoop->getResults().take_front(loopOp.getNumResults()));
777  return cast<LoopLikeOpInterface>(newLoop.getOperation());
778 }
779 
780 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
781 template <>
782 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
783  scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
784  YieldTiledValuesFn yieldTiledValuesFn) {
785  OpBuilder::InsertionGuard g(rewriter);
786  Location loc = loopOp.getLoc();
787  rewriter.setInsertionPoint(loopOp);
788  auto inits = llvm::to_vector(loopOp.getOutputs());
789  inits.append(newInitOperands.begin(), newInitOperands.end());
790  auto newLoop = rewriter.create<scf::ForallOp>(
791  loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
792  loopOp.getMixedStep(), inits, loopOp.getMapping(),
793  [](OpBuilder &, Location, ValueRange) {});
794 
795  // Move the region of the current block to the newly created op.
796  Block *loopBody = loopOp.getBody();
797  Block *newLoopBody = newLoop.getBody();
798  rewriter.mergeBlocks(
799  loopBody, newLoopBody,
800  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
801 
802  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
803  rewriter.setInsertionPoint(terminator);
804  SmallVector<Value> tiledValues;
805  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
806  ValueRange regionIterArgs =
807  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
808  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
809  regionIterArgs, tiledValues, resultOffsets,
810  resultSizes))) {
811  rewriter.eraseOp(newLoop);
812  return rewriter.notifyMatchFailure(loopOp,
813  "failed to get yielded tiled values");
814  }
815 
816  // Update the terminator.
817  rewriter.setInsertionPointToEnd(terminator.getBody());
818 
819  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
820  tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
821  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
822  rewriter.getIndexAttr(1));
823  rewriter.create<tensor::ParallelInsertSliceOp>(
824  terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
825  resultStride);
826  }
827 
828  rewriter.replaceOp(loopOp,
829  newLoop->getResults().take_front(loopOp.getNumResults()));
830  return cast<LoopLikeOpInterface>(newLoop.getOperation());
831 }
832 
833 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
834 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
835 /// supported loop type.
836 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
837  LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
838  ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
840  loopLikeOp.getOperation())
841  .Case<scf::ForOp, scf::ForallOp>(
842  [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
844  loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
845  })
846  .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
847  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
848  });
849 }
850 
851 /// Method to add new init values to a loop nest. Updates `loops` in-place
852 /// with new loops that use the `newInitValues`. The outer-loops are updated
853 /// to yield the new result values of the inner loop. For the innermost loop,
854 /// the call back `getNewYields` is invoked to get the additional values to
855 /// yield form the innermost loop.
856 static LogicalResult addInitOperandsToLoopNest(
858  ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
859  SmallVector<scf::ForOp> newLoops;
860  if (loops.empty())
861  return success();
862  OpBuilder::InsertionGuard g(rewriter);
863  rewriter.setInsertionPoint(loops.front());
864 
865  SmallVector<Value> ivs;
866  for (auto &loop : loops.drop_back()) {
867  rewriter.setInsertionPoint(loop);
868 
869  // if loops.size() > 1 we assume that scf.for is used for the loops.
870  auto forLoop = cast<scf::ForOp>(loop.getOperation());
871 
872  // Create a new loop with the new init values for this loop.
873  SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
874  newInits.append(newInitValues.begin(), newInitValues.end());
875  auto newLoop = rewriter.create<scf::ForOp>(
876  forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
877  forLoop.getStep(), newInits,
878  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
879 
880  // Merge the body of the new loop with the body of the old loops.
881  SmallVector<Value> sourceBlockArgs;
882  sourceBlockArgs.push_back(newLoop.getInductionVar());
883  auto newRegionIterArgs = newLoop.getRegionIterArgs();
884  sourceBlockArgs.append(
885  newRegionIterArgs.begin(),
886  std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
887  rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
888  rewriter.replaceOp(
889  forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
890  loop = newLoop;
891  ivs.push_back(newLoop.getInductionVar());
892  newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
893  }
894 
895  // Update the loop body of the innermost loop to get new yield values.
896  LoopLikeOpInterface innerMostLoop = loops.back();
897  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
898  yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
899  getNewTiledYieldsFn);
900 
901  if (failed(newInnerMostLoop))
902  return innerMostLoop.emitOpError("failed to return additional yields");
903  loops.back() = newInnerMostLoop.value();
904 
905  // Make all other loops except the innermost loops yield the values returned
906  // by the inner loop.
907  for (auto [outerLoop, innerLoop] :
908  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
909  // Again assume that all the outer loops are scf.for operations.
910  auto outerForLoop = cast<scf::ForOp>(outerLoop);
911  auto outerLoopYield =
912  cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
913  SmallVector<Value> newYields =
914  llvm::to_vector(outerLoopYield.getOperands());
915  ValueRange additionalYields =
916  innerLoop->getResults().take_back(newInitValues.size());
917  newYields.append(additionalYields.begin(), additionalYields.end());
918  rewriter.setInsertionPoint(outerLoopYield);
919  rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
920  }
921  return success();
922 }
923 
924 /// Implementation of tiling transformation of `op` that implements the
925 /// `TilingInterface` using `scf.for` to iterate over the tiles.
926 FailureOr<scf::SCFTilingResult>
927 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
929  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
930  return failure();
931  }
932 
933  OpBuilder::InsertionGuard guard(rewriter);
934  rewriter.setInsertionPointAfter(op);
935 
936  // 1. Get the range of the loops that are represented by the operation.
937  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
938 
939  // 2. Materialize the tile sizes and/or number of threads;
940  SmallVector<OpFoldResult> tileSizes, numThreads;
941  std::tie(tileSizes, numThreads) =
942  getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
943 
944  // Check if it is safe to tile. This is hold over from previous iterations
945  // of tile to for-all. Consider dropping it.
947  checkSafeToTileToForall(op, tileSizes, numThreads);
948  }
949 
950  // 3. If there is an interchange specified, permute the iteration domain and
951  // the tile sizes.
952  SmallVector<int64_t> interchangeVector;
953  if (!options.interchangeVector.empty()) {
954  interchangeVector = fillInterchangeVector(options.interchangeVector,
955  iterationDomain.size());
956  assert(isPermutationVector(interchangeVector) &&
957  "expected interchange vector to be a permutation");
958 
959  applyPermutationToVector(iterationDomain, interchangeVector);
960  applyPermutationToVector(tileSizes, interchangeVector);
961  if (!numThreads.empty())
962  applyPermutationToVector(numThreads, interchangeVector);
963  }
964 
965  FailureOr<TilingResult> tilingResult;
966  // 4. Define the lambda function used later to generate the body of the
967  // innermost tiled loop.
968  YieldTiledValuesFn innerYieldTiledValuesFn =
969  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
970  ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
973  -> LogicalResult {
974  // 4a. Compute the `offsets` and `sizes` to use for tiling.
975  SmallVector<OpFoldResult> offsets, sizes;
976  std::tie(offsets, sizes) = getTileOffsetAndSizes(
977  rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
978 
979  // 4b. If interchange was provided, apply inverse of the interchange
980  // to get back the offsets/sizes in the order to be specified.
981  if (!interchangeVector.empty()) {
982  auto inversePermutation = invertPermutationVector(interchangeVector);
985  }
986 
987  // 5. Generate the tiled implementation within the inner most loop.
988 
989  // 5a. Clone the operation within the loop body.
990  auto clonedOp = cast<TilingInterface>(
991  cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
992 
993  // 5b. Early return cloned op if tiling is not happening. We can not
994  // return the original op because it could lead to `rewriter.replaceOp(op,
995  // op->getResults())` and users would get crash.
996  if (llvm::all_of(tileSizes, isZeroIndex)) {
997  tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
998  tilingResult =
999  TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
1000  /*generatedSlices=*/{}};
1001  return success();
1002  }
1003 
1004  // 5c. Tile the cloned operation.
1005  tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
1006  offsets, sizes, options);
1007  if (failed(tilingResult)) {
1008  rewriter.eraseOp(clonedOp);
1009  return op.emitOpError("faild to tile operation");
1010  }
1011 
1012  // 5d. Delete the cloned operation.
1013  rewriter.eraseOp(clonedOp);
1014 
1015  // 5e. Compute the offsets at which the result values are to be inserted
1016  // back into its destinations.
1017  for (auto [index, tiledValue] :
1018  llvm::enumerate(tilingResult->tiledValues)) {
1019  tiledResults.push_back(tiledValue);
1020  SmallVector<OpFoldResult> resultOffset, resultSize;
1021  if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
1022  sizes, resultOffset, resultSize,
1023  options))) {
1024  for (auto op : tilingResult->tiledOps) {
1025  rewriter.eraseOp(op);
1026  }
1027  return rewriter.notifyMatchFailure(
1028  op, "failed to get slice of result produced");
1029  }
1030  resultOffsets.emplace_back(std::move(resultOffset));
1031  resultSizes.emplace_back(std::move(resultSize));
1032  }
1033 
1034  return success();
1035  };
1036 
1037  // 6. Find the destination tensors to use for the operation.
1038  FailureOr<SmallVector<Value>> maybeInits =
1039  createInitialTensorsForTiling(rewriter, op, tileSizes, options);
1040  if (failed(maybeInits)) {
1041  return rewriter.notifyMatchFailure(
1042  op, "unable to create initial tensors for tiling");
1043  }
1044  SmallVector<Value> &initTensors = maybeInits.value();
1045 
1046  // 7. Generate the tiled loops nest using the callback defined above.
1048  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
1049  tileSizes, numThreads, initTensors,
1050  innerYieldTiledValuesFn, loops)))
1051  return op.emitOpError("failed to generate tiling loops");
1052  assert(succeeded(tilingResult) &&
1053  "expected tiling result to be computed after loop generation");
1054 
1055  SmallVector<Value> partialResults;
1056  if (loops.empty()) {
1057  // If loops are empty, the tiled op is used as the replacement for the
1058  // untiled op.
1059  partialResults = tilingResult->tiledValues;
1060  } else {
1061  partialResults = llvm::map_to_vector(loops.front()->getResults(),
1062  [](OpResult r) -> Value { return r; });
1063  }
1064 
1065  FailureOr<MergeResult> mergeResult =
1066  mergeTilingResults(rewriter, op, partialResults, options);
1067  if (failed(mergeResult)) {
1068  return rewriter.notifyMatchFailure(
1069  op, "Failed to merge partial results from tiling");
1070  }
1071 
1072  return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
1073  mergeResult.value(),
1074  tilingResult->generatedSlices};
1075 }
1076 
1077 FailureOr<scf::SCFTilingResult>
1079  PartialReductionOpInterface op,
1080  ArrayRef<OpFoldResult> tileSizes) {
1083  options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
1084  PartialReductionOuterReduction);
1085  options.setTileSizes(tileSizes);
1086 
1087  TilingInterface tilingInterfaceOp =
1088  dyn_cast<TilingInterface>(op.getOperation());
1089  if (!tilingInterfaceOp) {
1090  return b.notifyMatchFailure(
1091  op,
1092  "Operation implementing PartialReductionOpInterface should implement "
1093  "TilingInterface");
1094  }
1095 
1096  return tileUsingSCF(b, tilingInterfaceOp, options);
1097 }
1098 
1099 //===----------------------------------------------------------------------===//
1100 // tileConsumerAndFuseProducersUsingSCF implementation.
1101 //===----------------------------------------------------------------------===//
1102 
1103 /// Return the untiled producer whose slice is used in a tiled consumer. The
1104 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1105 /// `iter_args` of the outer most that is encountered. Traversing the
1106 /// iter_args indicates that this is a destination operand of the consumer. If
1107 /// there was no loop traversal needed, the second value of the returned tuple
1108 /// is empty.
1109 static std::tuple<OpResult, std::optional<OpOperand *>>
1112  std::optional<OpOperand *> destinationIterArg;
1113  auto loopIt = loops.rbegin();
1114  while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
1115  auto loop = *loopIt;
1116  if (iterArg.getOwner()->getParentOp() != loop)
1117  break;
1118  source = loop.getTiedLoopInit(iterArg);
1119  loopIt++;
1120  }
1121  if (loopIt == loops.rend())
1122  destinationIterArg = source;
1123  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1124 }
1125 
1126 /// Implementation of fusing producer of a single slice by computing the
1127 /// slice of the producer in-place.
1128 std::optional<scf::SCFFuseProducerOfSliceResult>
1130  RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1132  // 1. Get the producer of the source (potentially walking through
1133  // `iter_args` of nested `scf.for`)
1134  auto [fusableProducer, destinationInitArg] =
1135  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1136  loops);
1137  if (!fusableProducer)
1138  return std::nullopt;
1139  unsigned resultNumber = fusableProducer.getResultNumber();
1140 
1141  OpBuilder::InsertionGuard g(rewriter);
1142  rewriter.setInsertionPoint(candidateSliceOp);
1143 
1144  // 2. Clone the fused producer
1145  // 2a. Compute the destination operands to use for the cloned operation.
1146  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1147  Operation *fusableProducerOp = fusableProducer.getOwner();
1148  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1150  rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1151  origDestinationTensors)))
1152  return std::nullopt;
1153 
1154  clonedOpDestinationTensors = origDestinationTensors;
1155  if (destinationInitArg &&
1156  isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1157  // 2b. If the producer is also destination style, then to maintain the
1158  // destination passing style, update the destination of the producer to be
1159  // the source of the slice.
1160  clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1161  }
1162  // 2c. Clone the fused producer.
1163  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1164  rewriter, fusableProducerOp, clonedOpDestinationTensors);
1165  // 2d. Update the source of the candidateSlice to be the cloned producer.
1166  // Easier to just clone the slice with different source since
1167  // replacements and DCE of cloned ops becomes easier
1168  SmallVector<Value> candidateSliceOpOperands =
1169  llvm::to_vector(candidateSliceOp->getOperands());
1170  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1171  tensor::ExtractSliceOp clonedCandidateSliceOp =
1172  mlir::clone(rewriter, candidateSliceOp,
1173  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1174 
1175  // 3. Generate the tiled implementation of the producer of the source
1176  FailureOr<TilingResult> tileAndFuseResult =
1178  rewriter, clonedCandidateSliceOp,
1179  clonedProducerOp->getResult(resultNumber));
1180  if (failed(tileAndFuseResult))
1181  return std::nullopt;
1182  // Note: Do not delete the candidateSliceOp, since its passed in from the
1183  // caller.
1184  rewriter.replaceAllUsesWith(candidateSliceOp,
1185  tileAndFuseResult->tiledValues[0]);
1186  rewriter.eraseOp(clonedCandidateSliceOp);
1187  rewriter.eraseOp(clonedProducerOp);
1188 
1189  // 3. If the slice is for a destination operand, for example,
1190  //
1191  // ```mlir
1192  // %0 = linalg.init
1193  // %1 = linalg.fill .. outs(%0 : )
1194  // %2 = scf.for .. iter_args(%arg0 = %1) {
1195  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1196  // %4 = tensor.extract_slice %arg1 [..]
1197  // .. = linalg.matmul .. outs(%4 : )
1198  // }
1199  // }
1200  // ```
1201  //
1202  // the IR is currently
1203  //
1204  // ```
1205  // %0 = linalg.init
1206  // %1 = linalg.fill
1207  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1208  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1209  // %4 = tensor.extract_slice %arg1[..]
1210  // %5 = linalg.fill .. outs(%4 : )
1211  // .. = linalg.matmul .. outs(%5 : )
1212  // }
1213  // }
1214  // ```
1215  //
1216  // The untiled `linalg.fill` is still used as the `init_value` since it
1217  // was originally a destination operand of the untiled `linalg.matmul`.
1218  // When fusing an operand that is a destination operand, the iter_arg of
1219  // the outer most loop should be changed to use the destination of the
1220  // fused operation. With this the IR will be.
1221  //
1222  // ```
1223  // %0 = linalg.init
1224  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1225  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
1226  // %3 = tensor.extract_slice %arg1[..]
1227  // %4 = linalg.fill .. outs(%3 : )
1228  // .. = linalg.matmul .. outs(%4 : )
1229  // }
1230  // }
1231  // ```
1232  if (destinationInitArg &&
1233  isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1234  loops.front()
1235  ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1236  .set(origDestinationTensors[resultNumber]);
1237  }
1239  fusableProducer, tileAndFuseResult->tiledValues[0],
1240  tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1241 }
1242 
1243 /// Reconstruct the fused producer from within the tiled-and-fused code.
1244 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1245  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1246  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1248  ArrayRef<unsigned> yieldResultNumber) {
1249  if (loops.empty())
1250  return success();
1251 
1252  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1253  *tiledOwner = fusedProducerInfo.tiledOps[0];
1254 
1255  Location loc = originalOwner->getLoc();
1256  // a. collect all init Value to be appended
1257  SmallVector<unsigned> initNumberList =
1258  yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1259  0, originalOwner->getNumResults()))
1260  : llvm::to_vector(yieldResultNumber);
1261  SmallVector<Value> initValueList;
1262  for (const auto &resultNumber : initNumberList) {
1263  FailureOr<Value> initValue = tensor::getOrCreateDestination(
1264  rewriter, loc, originalOwner->getResult(resultNumber));
1265  if (succeeded(initValue)) {
1266  initValueList.push_back(initValue.value());
1267  } else {
1268  return failure();
1269  }
1270  }
1271 
1272  SmallVector<Operation *> generatedSlices;
1273  YieldTiledValuesFn newYieldValuesFn =
1274  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1275  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1277  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1278  OpBuilder::InsertionGuard g(innerRewriter);
1279 
1280  // get sliceOp tile information
1281  SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1282  sliceSizes = sliceOp.getMixedSizes();
1283 
1284  // expect all strides of sliceOp being 1
1285  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1286  return !isConstantIntValue(ofr, 1);
1287  }))
1288  return failure();
1289 
1290  unsigned sliceResultNumber =
1291  fusedProducerInfo.origProducer.getResultNumber();
1292 
1293  auto tilableOp = cast<TilingInterface>(originalOwner);
1294  // b. get iterDomain Offset and Sizes based on sliceOp tile
1295  SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1296  // skip tensor.pack/unpack/pad, which expects single opResult
1297  if (tilableOp->getNumResults() > 1 &&
1298  failed(tilableOp.getIterationDomainTileFromResultTile(
1299  rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1300  iterDomainOffset, iterDomainSizes))) {
1301  // In theory, it is unnecessary to raise an error here. Actually
1302  // although it fails to reconstruct the result tensor, it should not
1303  // broke current fusion anyway. The reason why we must return failure
1304  // currently is that the callback function `newYieldValuesFn` will be
1305  // called after new init operand(s) has already been appended. It will
1306  // take more refactoring to make sure the init operands are added
1307  // consistently in the future. For more details, please refer to:
1308  // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1309  return failure();
1310  }
1311 
1312  // c. calculate offsets and sizes info of all OpResults respectively based
1313  // on iteration Domain Tile
1314  SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1315  for (const auto &resultNumber : initNumberList) {
1316  if (resultNumber == sliceResultNumber) {
1317  offsetList.push_back(sliceOffset);
1318  sizesList.push_back(sliceSizes);
1319  } else {
1320  assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1321  // infer result tile according to the iteration domain tile
1322  SmallVector<OpFoldResult> offset, sizes;
1323  if (failed(tilableOp.getResultTilePosition(
1324  rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1325  offset, sizes))) {
1326  return failure();
1327  }
1328  offsetList.push_back(offset);
1329  sizesList.push_back(sizes);
1330  }
1331  }
1332 
1333  // d. create `extract_slice` for `iter_args` for DPS operation if
1334  // necessary
1335  if (auto tiledDestStyleOp =
1336  dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1337  rewriter.setInsertionPoint(tiledDestStyleOp);
1338  for (const auto &&[index, newRegionArg] :
1339  llvm::enumerate(newRegionIterArgs)) {
1340  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1341  loc, newRegionArg, offsetList[index], sizesList[index],
1342  SmallVector<OpFoldResult>(offsetList[index].size(),
1343  rewriter.getIndexAttr(1)));
1344  generatedSlices.push_back(destSlice);
1345  unsigned resultNumber = initNumberList[index];
1346  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1347  tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1348  });
1349  }
1350  }
1351 
1352  // e. prepare tiled offset and sizes for later `insert_slice` creation by
1353  // caller
1354  Block *block = rewriter.getInsertionPoint()->getBlock();
1355  rewriter.setInsertionPoint(block->getTerminator());
1356  for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1357  tiledResult.push_back(tiledOwner->getResult(resultNumber));
1358  tiledOffset.emplace_back(offsetList[index]);
1359  tiledSizes.emplace_back(sizesList[index]);
1360  }
1361  return success();
1362  };
1363 
1364  if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1365  newYieldValuesFn))) {
1366  return failure();
1367  }
1368  return generatedSlices;
1369 }
1370 
1371 namespace {
1372 
1373 //===----------------------------------------------------------------------===//
1374 // SliceTrackingListener
1375 //===----------------------------------------------------------------------===//
1376 
1377 /// This class is a listener for tracking the insertion and removal of
1378 /// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1379 /// fusion algorithm to apply cleanup patterns in between fusion steps.
1380 class SliceTrackingListener : public RewriterBase::Listener {
1381 public:
1382  explicit SliceTrackingListener(
1383  std::optional<FrozenRewritePatternSet> patterns);
1384  SliceTrackingListener() = default;
1385 
1386  /// Adds the given list of operations to the worklist, and if present,
1387  /// applies the list of `patterns` to the newly added operations. This only
1388  /// processes the given operations and any newly inserted ones by the
1389  /// pattern set.
1390  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1391 
1392  /// Add to the new operation worklist if it is an extract_slice.
1393  void notifyOperationInserted(Operation *op,
1394  OpBuilder::InsertPoint previous) override;
1395 
1396  /// Shared helper for operation removal from the worklist.
1397  void removeOp(Operation *op);
1398 
1399  /// Remove the operation from the worklist.
1400  void notifyOperationErased(Operation *op) override;
1401 
1402  /// Remove the operation from the worklist.
1403  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1404 
1405  /// The worklist for this transformation keeps track of the slices to visit
1406  /// next for fusion.
1407  std::deque<tensor::ExtractSliceOp> worklist;
1408 
1409 private:
1410  /// Optional pattern set to apply when adding new operations to the
1411  /// worklist.
1412  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1413 };
1414 
1415 SliceTrackingListener::SliceTrackingListener(
1416  std::optional<FrozenRewritePatternSet> p) {
1417  patterns = std::move(p);
1418 }
1419 
1420 LogicalResult
1421 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1422  for (Operation *op : ops) {
1423  if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1424  worklist.push_back(slice);
1425  }
1426 
1427  if (!patterns)
1428  return success();
1429 
1431  config.listener = this;
1433  return applyOpPatternsGreedily(ops, patterns.value(), config);
1434 }
1435 
1436 void SliceTrackingListener::notifyOperationInserted(
1437  Operation *op, OpBuilder::InsertPoint previous) {
1438  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1439  if (!slice)
1440  return;
1441  worklist.push_back(slice);
1442 }
1443 
1444 // Scan the worklist for the given op and remove it if present. The
1445 // expectation is for the worklist to be small and for removal to be
1446 // relatively rare.
1447 void SliceTrackingListener::removeOp(Operation *op) {
1448  if (!isa<tensor::ExtractSliceOp>(op))
1449  return;
1450  auto iter = worklist.begin();
1451  while (iter != worklist.end()) {
1452  if (*iter == op)
1453  break;
1454  iter++;
1455  }
1456  if (iter == worklist.end())
1457  return;
1458 
1459  worklist.erase(iter);
1460 }
1461 
1462 void SliceTrackingListener::notifyOperationErased(Operation *op) {
1463  removeOp(op);
1464 }
1465 
1466 void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1467  ValueRange replacement) {
1468  removeOp(op);
1469 }
1470 } // namespace
1471 
1472 /// Implementation of tile consumer and fuse producer greedily.
1473 FailureOr<scf::SCFTileAndFuseResult>
1475  RewriterBase &rewriter, TilingInterface consumer,
1477  // This transformation is only valid for ops that return values (i.e. not
1478  // valid to use with operations that have memref operands).
1479  if (!consumer->getNumResults()) {
1480  return rewriter.notifyMatchFailure(
1481  consumer, "invalid pattern for op with no results");
1482  }
1483 
1484  // 1. First tile the consumer.
1485  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1486  llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1487 
1488  FailureOr<scf::SCFTilingResult> tilingResult =
1489  tileUsingSCF(rewriter, consumer, options.tilingOptions);
1490 
1491  if (failed(tilingResult))
1492  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1493  for (auto *tiledOp : tilingResult->tiledOps)
1494  tiledAndFusedOps.insert(tiledOp);
1495 
1496  // If there are no loops generated, fusion is immaterial.
1497  auto &loops = tilingResult->loops;
1498  if (loops.empty()) {
1499  DenseMap<Value, Value> replacements;
1500  for (auto [origVal, replacement] : llvm::zip_equal(
1501  consumer->getResults(), tilingResult->mergeResult.replacements)) {
1502  replacements[origVal] = replacement;
1503  }
1504  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1505  replacements};
1506  }
1507 
1508  // To keep track of replacements for now just record the map from the
1509  // original untiled value to the result number of the for loop. Since the
1510  // loop gets potentially replaced during fusion, keeping the value directly
1511  // wont work.
1512  DenseMap<Value, size_t> origValToResultNumber;
1513  for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1514  origValToResultNumber[result] = index;
1515  }
1516 
1517  // 2. Typically, the operands of the tiled operation are slices of the
1518  // operands of the untiled operation. These are expressed in IR using
1519  // `tensor.extract_slice` operations with source being the operands of
1520  // the untiled operation. Create a worklist of these
1521  // `tensor.extract_slice` operations. If the producers of the source of
1522  // the `tensor.extract_slice` can be tiled such that the tiled value is
1523  // generated in-place, that effectively tiles + fuses the operations.
1524  struct WorklistItem {
1525  tensor::ExtractSliceOp candidateSlice;
1527  };
1528 
1529  SliceTrackingListener sliceTracker =
1530  SliceTrackingListener(options.cleanupPatterns);
1531 
1532  if (failed(
1533  sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1534  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1535  }
1536  OpBuilder::InsertionGuard g(rewriter);
1537  while (!sliceTracker.worklist.empty()) {
1538  auto candidateSlice = sliceTracker.worklist.front();
1539  sliceTracker.worklist.pop_front();
1540 
1541  auto [fusableProducer, destinationInitArg] =
1542  getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1543  loops);
1544  if (!fusableProducer)
1545  continue;
1546 
1547  std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1548  options.fusionControlFn(candidateSlice, fusableProducer,
1549  destinationInitArg.has_value());
1550  if (!controlFnResult)
1551  continue;
1552 
1553  WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1554 
1555  // The operands of the fused producer might themselved be slices of
1556  // values produced by operations that implement the `TilingInterface`.
1557  // Add these operations to the worklist.
1558  std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1559  tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1560  loops);
1561  if (!fusedResult)
1562  continue;
1563 
1564  SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1565 
1566  if (worklistItem.controlFnResult.yieldProducerReplacement) {
1567  // Reconstruct and yield all opResult of fusableProducerOp by default.
1568  // The caller can specific which one to yield by designating optional
1569  // argument named `yieldResultNumber` of
1570  // `yieldReplacementForFusedProducer`.
1571  Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1572  FailureOr<SmallVector<Operation *>> newSlices =
1574  worklistItem.candidateSlice,
1575  fusedResult.value(), loops);
1576  if (failed(newSlices)) {
1577  return rewriter.notifyMatchFailure(
1578  fusableProducerOp, "failed to replacement value for this "
1579  "operation from within the tiled loop");
1580  }
1581  worklistCandidates.append(newSlices.value());
1582  for (auto [index, result] :
1583  llvm::enumerate(fusableProducerOp->getResults())) {
1584  origValToResultNumber[result] = loops.front()->getNumResults() -
1585  fusableProducerOp->getNumResults() +
1586  index;
1587  }
1588  }
1589  if (Operation *tiledAndFusedOp =
1590  fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1591  fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1592  tiledAndFusedOps.insert(tiledAndFusedOp);
1593  }
1594 
1595  if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1596  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1597  }
1598  }
1599 
1600  DenseMap<Value, Value> replacements;
1601  for (auto [origVal, resultNumber] : origValToResultNumber) {
1602  replacements[origVal] = loops.front()->getResult(resultNumber);
1603  }
1604 
1605  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1606  replacements};
1607 }
1608 
1609 //===----------------------------------------------------------------------===//
1610 // tileAndFuseConsumerUsingSCF implementation.
1611 //===----------------------------------------------------------------------===//
1612 
1613 /// A utility function that checks whether the only use of the result of a
1614 /// tensor.insert_slice op is in a scf.yield op.
1615 static LogicalResult
1616 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1617  Value result = candidateSliceOp.getResult();
1618  Value::use_range uses = result.getUses();
1619  if (!llvm::hasSingleElement(uses)) {
1620  LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1621  return failure();
1622  }
1623  OpOperand &operandUse = (*uses.begin());
1624  Operation *userOp = operandUse.getOwner();
1625  if (!isa<scf::YieldOp>(userOp)) {
1626  LLVM_DEBUG(llvm::dbgs()
1627  << "Expected scf.yield to be the only user, but got -> "
1628  << (*userOp));
1629  return failure();
1630  }
1631  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1632  LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1633  "be in the same block\n");
1634  return failure();
1635  }
1636  return success();
1637 }
1638 
1639 /// An utility to get the first user of the given loopOp. If any of user stay
1640 /// in different block of loopOp, return failure.
1641 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1642  if (!isa<LoopLikeOpInterface>(loopOp))
1643  return failure();
1644  Operation *firstUserOfLoop = nullptr;
1645  for (Operation *userOp : loopOp->getUsers()) {
1646  // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1647  // block with any other types of operation. Thus, just redirecting to its
1648  // parent `InParallelOp`. E.g.
1649  //
1650  // ```
1651  // %1 = scf.for {
1652  // ...
1653  // }
1654  // %2 = consumerOp ins(%1, ...)
1655  // scf.forall.in_parallel {
1656  // tensor.parallel_insert_slice %1
1657  // }
1658  // ```
1659  // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
1660  // same block with `consumerOp`.
1661  if (isa<tensor::ParallelInsertSliceOp>(userOp))
1662  userOp = userOp->getParentOfType<scf::InParallelOp>();
1663 
1664  if (loopOp->getBlock() != userOp->getBlock())
1665  return failure();
1666 
1667  if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1668  firstUserOfLoop = userOp;
1669  }
1670  return firstUserOfLoop;
1671 }
1672 
1673 /// This utility currently checks whether the first userOp of loop is NOT
1674 /// before the last defineOp of consumer operand. Because that we need to move
1675 /// the whole loop structure right before the `firstUserOfLoop`. This utility
1676 /// thus helps ensuring that no invalid IR is formed, i.e. no backward slice
1677 /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
1678 ///
1679 /// ```
1680 /// %0 = scf.for() {
1681 /// ...
1682 /// }
1683 /// ...
1684 /// %1 = firstUserOfLoop(%0)
1685 /// ...
1686 /// %2 = lastDefOfConsumerOperand
1687 /// ...
1688 /// %3 = consumerOp(%2)
1689 /// ```
1690 ///
1691 /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
1692 /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
1693 /// a.k.a. use-def chain violation:
1694 ///
1695 /// ```
1696 /// %0:2 = scf.for() {
1697 /// // use before define error
1698 /// %3 = tiledConsumerOp(%2)
1699 /// }
1700 /// %1 = firstUserOfLoop(%0)
1701 /// ...
1702 /// %2 = lastDefOfConsumerOperand
1703 /// ```
1704 ///
1705 /// @param loopOp: loop operation
1706 /// @param consumerOp: consumer operation
1707 /// @param reorderOperations: the flag controls whether to reorder the
1708 /// backward slice w.r.t. the defineOp of `consumerOp` operands.
1709 /// @return: computed backward slice of consumerOp, but excluding those
1710 /// already dominates `firstUserOfLoop`.
1711 static FailureOr<llvm::SetVector<Operation *>>
1713  bool reorderOperations) {
1714  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1715  if (failed(firstUserOfLoop))
1716  return failure();
1717 
1719  DominanceInfo dominanceInfo;
1720  options.inclusive = true;
1721  options.omitBlockArguments = true;
1722  bool includeLoopOp = false;
1723  options.filter = [&](Operation *op) {
1724  if (op == loopOp) {
1725  includeLoopOp = true;
1726  return false;
1727  }
1728  // Cut off the slice to not include any operation that already dominates
1729  // firstUserOfLoop.
1730  return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1731  };
1733  for (auto operand : consumerOp->getOperands()) {
1734  getBackwardSlice(operand, &slice, options);
1735  }
1736 
1737  if (!slice.empty()) {
1738  // If consumerOp has one producer, which is also the user of loopOp.
1739  // E.g.
1740  // ```
1741  // %0 = %loopOp
1742  // %1 = consumerOp1 ins(%0)
1743  // %2 = consumerOp2 ins(%0, %1)
1744  // ```
1745  // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1746  // consumerOp1 has already been fused into loopOp before.
1747  if (includeLoopOp || !reorderOperations)
1748  return failure();
1749  }
1750 
1751  return slice;
1752 }
1753 
1754 /// Fetches the OpOperand of the first valid user (and use) of the value `val`
1755 /// which implements `TilingInterface` and `DestinationStyleOpInterface`.
1756 /// Returns failure otherwise.
1757 static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1758  Operation *loopOp,
1759  unsigned resultNumber) {
1760  if (!isa<LoopLikeOpInterface>(loopOp))
1761  return failure();
1762  Value val = loopOp->getResult(resultNumber);
1763  Block *loopBlock = loopOp->getBlock();
1764  for (OpOperand &opOperand : val.getUses()) {
1765  Operation *consumerOp = opOperand.getOwner();
1766  // Step 1. Check if the user is tilable.
1767  if (!isa<TilingInterface>(consumerOp) ||
1768  !isa<DestinationStyleOpInterface>(consumerOp)) {
1769  // TODO: We have to init result of consumer before scf.for, use
1770  // DestinationStyleOpInterface to get result shape from init for now.
1771  // Add support for other op such as op has InferTypeOpInterface.
1772  continue;
1773  }
1774  // Step 2. Check if user stay in the same block.
1775  if (loopBlock != consumerOp->getBlock())
1776  continue;
1777  // Step 3. Check if user has succeeding user. Otherwise, it usually
1778  // represents already tiled.
1779  if (consumerOp->use_empty())
1780  continue;
1781  // Step 4. Check assumption for loop with `reorderOperations` enabled.
1782  FailureOr<llvm::SetVector<Operation *>> slice =
1783  checkAssumptionForLoop(loopOp, consumerOp, true);
1784  if (failed(slice))
1785  continue;
1786  // Step 5. If backward sice is not empty, move them before
1787  // firstUserOfLoop.
1788  if (!slice->empty()) {
1789  mlir::topologicalSort(*slice);
1790  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1791  assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1792  for (auto op : *slice) {
1793  rewriter.moveOpBefore(op, *firstUserOfLoop);
1794  }
1795  }
1796  return &opOperand;
1797  }
1798  return failure();
1799 }
1800 
1801 /// Find the perfectly nested loops outside of given loop(included) sorted
1802 /// from outer to inner.
1803 ///
1804 /// E.g.
1805 ///
1806 /// ```
1807 /// %0 = scf.for()
1808 /// %1 = scf.for()
1809 /// %2 = scf.for()
1810 /// %3 = ...
1811 /// yield %3
1812 /// yield %2
1813 /// yield %1
1814 /// ```
1815 ///
1816 /// This function will return three perfectly nested loops: %0 + %1 + %2, when
1817 /// target inner loop is %2.
1820  SmallVector<scf::ForOp> nestLoops = {loop};
1821  auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1822 
1823  // Check if it is the ForOp that yield the result of inner loop.
1824  auto isForOpYieldResultOfInnerLoop =
1825  [](scf::ForOp outerLoop) -> LogicalResult {
1826  Block *body = outerLoop.getBody();
1827  if (!llvm::hasSingleElement(body->without_terminator()))
1828  return failure();
1829  auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1830  auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1831  if (!innerForOp)
1832  return failure();
1833  // All of innerForOp results should be yielded.
1834  return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1835  };
1836 
1837  while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1838  nestLoops.push_back(outerLoop);
1839  outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1840  }
1841  // sorted from outer to inner
1842  return {nestLoops.rbegin(), nestLoops.rend()};
1843 }
1844 
1845 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
1846 /// tensor.insert_slice. This function makes the following assumptions :
1847 /// 1. tensor.insert_slice has scf.yield as its only user.
1848 /// 2. scf.for's corresponding result has only one use.
1849 static FailureOr<OpOperand *>
1851  tensor::InsertSliceOp candidateSliceOp) {
1852  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1853  return failure();
1854  Value sliceResult = candidateSliceOp.getResult();
1855  // Step 1. Fetch the corresponding output.
1856  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1857  unsigned resultNumber = yieldOpOperand.getOperandNumber();
1858  // Step 2. Check containing op is scf.for.
1859  Operation *containingOp = candidateSliceOp->getParentOp();
1860  auto forOp = dyn_cast<scf::ForOp>(containingOp);
1861  if (!forOp)
1862  return failure();
1863  scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
1864 
1865  return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
1866 }
1867 
1868 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1869 /// by a tensor.parallel_insert_slice.
1870 static FailureOr<OpOperand *>
1872  tensor::ParallelInsertSliceOp candidateSliceOp) {
1873  // Step 1. Fetch the corresponding output
1874  Value sliceDest = candidateSliceOp.getDest();
1875  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1876  if (!iterArg)
1877  return failure();
1878  Operation *containingOp = iterArg.getOwner()->getParentOp();
1879  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1880  return failure();
1881  // Step 2. Check that the containing op is scf.forall.
1882  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1883  if (!forallOp)
1884  return failure();
1885  unsigned resultNumber =
1886  forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1887  .getResultNumber();
1888 
1889  return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
1890 }
1891 
1892 /// A utility to fetch an untiled consumer of
1893 /// tensor.insert_slice/tensor.parallel_insert_slice.
1894 static FailureOr<OpOperand *>
1896  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1897  return getUntiledConsumerFromSlice(rewriter, insertSlice);
1898  } else if (auto parallelInsertSlice =
1899  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1900  return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
1901  } else {
1902  return failure();
1903  }
1904 }
1905 
1906 /// Implementation of fusing consumer of a single slice by computing the
1907 /// slice of the consumer in-place for scf loop.
1908 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1910  Operation *candidateSliceOp) {
1911  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1912  candidateSliceOp))
1913  return failure();
1914 
1915  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1916 
1917  // 1. Get the consumer of scf.for for the result yielded by
1918  // tensor.insert_slice/parallel_insert_slice.
1919  FailureOr<OpOperand *> maybeConsumerOpOperand =
1920  getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
1921  if (failed(maybeConsumerOpOperand)) {
1922  return rewriter.notifyMatchFailure(candidateSliceOp,
1923  "could not fetch consumer to fuse");
1924  }
1925  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1926  Operation *consumerOp = consumerOpOperand->getOwner();
1927  unsigned operandNumber = consumerOpOperand->getOperandNumber();
1928  unsigned resultNumber = 0;
1929  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1930  resultNumber = producerResult.getResultNumber();
1931  } else {
1932  return rewriter.notifyMatchFailure(
1933  consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1934  }
1935 
1936  // There are two possible cases regarding `oldLoopOp` here:
1937  // 1. single `scf.forall` or `scf.for`.
1938  // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1939  // top-level loop is the outer-most one of these nested loops.
1940  LoopLikeOpInterface innerMostLoop =
1941  candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1943  if (isInsertSliceOp) {
1944  nestedLoops = llvm::map_to_vector(
1946  cast<scf::ForOp>(innerMostLoop.getOperation())),
1947  [](scf::ForOp forOp) {
1948  return cast<LoopLikeOpInterface>(forOp.getOperation());
1949  });
1950  } else {
1951  nestedLoops = {innerMostLoop};
1952  }
1953 
1954  LoopLikeOpInterface outerMostLoop = nestedLoops.front();
1955 
1956  // Check assumption for loop with `reorderOperations` disabled.
1957  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
1958  return rewriter.notifyMatchFailure(
1959  outerMostLoop, "the first user of loop should not dominate any define "
1960  "of consumer operand(s)");
1961  }
1962 
1963  OpBuilder::InsertionGuard g(rewriter);
1964 
1965  // 2. Check consumer is not using scf loop's output as init.
1966  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
1967  if (!dstOp)
1968  return rewriter.notifyMatchFailure(consumerOp,
1969  "consumer op is not DPS operation");
1970  SmallVector<Value> dpsInits =
1971  llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1972  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
1973  return rewriter.notifyMatchFailure(
1974  consumerOp,
1975  "consumer op taking the result of scf.for as init is not supported");
1976  }
1977  SmallVector<Value> newInits = dpsInits;
1978 
1979  Location loc = outerMostLoop->getLoc();
1980 
1981  // 3. Move the whole loop structure right before firstUserOfLoop, the
1982  // dominance should be already ensured by `checkAssumptionForLoop`.
1983  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
1984  if (failed(firstUserOfLoop)) {
1985  return rewriter.notifyMatchFailure(
1986  outerMostLoop, "could not find the first user of outer most loop");
1987  }
1988  rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
1989 
1990  // 4. Set insertion point before terminator op of the loop and create a new
1991  // tensor.insert_slice. In the scf.for case this is a clone of the
1992  // candidateSliceOp whereas in the scf.forall case this is created from the
1993  // operands of tensor.parallel_insert_slice.
1994  tensor::InsertSliceOp clonedInsertSliceOp;
1995  if (auto sliceOp =
1996  dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1997  auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
1998  rewriter.setInsertionPoint(newForallOp.getTerminator());
1999  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
2000  loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2001  sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2002  } else {
2003  rewriter.setInsertionPoint(candidateSliceOp);
2004  clonedInsertSliceOp =
2005  cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
2006  }
2007 
2008  // 5.a. Clone consumer op.
2009  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2010 
2011  // 5.b. Replace all uses of the loop result with the result of the cloned
2012  // tensor.insert_slice.
2013  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2014  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
2015  operandToReplace.set(clonedInsertSliceOp.getResult());
2016  });
2017 
2018  // 6. Perform tiling of the cloned consumer and replace the operand at
2019  // `operandNumber` with the source of the cloned tensor.insert_slice op.
2020  auto ossSliceOp =
2021  cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
2022  FailureOr<TilingResult> tileAndFuseResult =
2024  rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2025  if (failed(tileAndFuseResult)) {
2026  return failure();
2027  }
2028  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2029  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
2030  clonedInsertSliceOp.getSource());
2031 
2032  // 7. Reconstruct [nested] loop with new inits.
2033  YieldTiledValuesFn newYieldValuesFn =
2034  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2035  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2037  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2038  OpBuilder::InsertionGuard g(innerRewriter);
2039  // 8. Set inner insertPoint right before tiled consumer op.
2040  innerRewriter.setInsertionPoint(tiledConsumerOp);
2041 
2042  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
2043  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
2044  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
2045 
2046  // 9. Check all insert stride is 1.
2047  if (llvm::any_of(strides, [](OpFoldResult stride) {
2048  return !isConstantIntValue(stride, 1);
2049  })) {
2050  return rewriter.notifyMatchFailure(
2051  candidateSliceOp, "containingOp's result yield with stride");
2052  }
2053 
2054  // 10. Try to get iter domain position from input position. Use
2055  // clonedConsumerOp instead of tiledConsumerOp, because the iteration
2056  // domain may require index computation based on the result size. The
2057  // sizes and offsets should be the same either way, but using
2058  // tiledConsumerOp could lead to some chained unnecessary extra index
2059  // computation.
2060  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2061  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2062  rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2063  iterDomainSizes))) {
2064  return rewriter.notifyMatchFailure(
2065  clonedConsumerOp,
2066  "can't get iter domain position from input position");
2067  }
2068 
2069  // 11. Try to fetch the offset and size for all results of the cloned
2070  // consumer. This would then be used to form the corresponding
2071  // tensor.insert_slice/parallel_insert_slice later.
2072  unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2074  totalNumResultsOfConsumer);
2076  totalNumResultsOfConsumer);
2077  for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2078  if (failed(tiledConsumerOp.getResultTilePosition(
2079  rewriter, idx, iterDomainOffsets, iterDomainSizes,
2080  resultOffsets[idx], resultSizes[idx]))) {
2081  return rewriter.notifyMatchFailure(
2082  tiledConsumerOp,
2083  "can't get result domain position from iter domain position");
2084  }
2085  }
2086 
2087  // 12. Create `extract_slice` for `iter_args` for DPS operation if
2088  // necessary.
2089  if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2090  tiledConsumerOp.getOperation())) {
2091  rewriter.setInsertionPoint(tiledDestStyleOp);
2092  for (const auto &&[index, newRegionArg] :
2093  llvm::enumerate(newRegionIterArgs)) {
2094  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2095  loc, newRegionArg, resultOffsets[index], resultSizes[index],
2096  SmallVector<OpFoldResult>(resultOffsets[index].size(),
2097  rewriter.getIndexAttr(1)));
2098  // Make a copy of index to avoid a capturing structured binding, which
2099  // is a C++20 extension.
2100  auto dstNumber = index;
2101  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2102  tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2103  });
2104  }
2105  }
2106 
2107  // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2108  // caller.
2109  Block *block = rewriter.getInsertionPoint()->getBlock();
2110  rewriter.setInsertionPoint(block->getTerminator());
2111  for (const auto &&[index, result] :
2112  llvm::enumerate(tiledConsumerOp->getResults())) {
2113  tiledResult.push_back(result);
2114  tiledOffset.emplace_back(resultOffsets[index]);
2115  tiledSizes.emplace_back(resultSizes[index]);
2116  }
2117  return success();
2118  };
2119  // 14. Add new inits to [nested] loops.
2120  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
2121  newYieldValuesFn))) {
2122  return rewriter.notifyMatchFailure(tiledConsumerOp,
2123  "unable to add new inits to nest loop");
2124  }
2125 
2126  // 15. Replace the result of scf loop and consumer op with new loop's
2127  // results.
2128 
2129  for (auto &&[oldResult, newResult] : llvm::zip(
2130  consumerOp->getResults(),
2131  nestedLoops.front()->getResults().take_back(newInits.size()))) {
2132  rewriter.replaceAllUsesWith(oldResult, newResult);
2133  }
2134 
2135  // 16. Need to erase the old scf loop and the cloned consumer op.
2136  rewriter.eraseOp(clonedConsumerOp);
2137 
2139  consumerOpOperand,
2140  &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2141  tileAndFuseResult->tiledOps};
2142 }
2143 
2144 //===----------------------------------------------------------------------===//
2145 // lowerToLoopsUsingSCFForOp implementation.
2146 //===----------------------------------------------------------------------===//
2147 
2148 FailureOr<SmallVector<scf::ForOp>>
2150  TilingInterface op) {
2151  // TODO: Handle cases where the op has results if needed.
2152  if (op->getNumResults() > 0) {
2153  return rewriter.notifyMatchFailure(
2154  op, "unable to lower to loops operations with return values");
2155  }
2156 
2157  SmallVector<Range> domain = op.getIterationDomain(rewriter);
2158  SmallVector<Value> ivs;
2160  Location loc = op.getLoc();
2161  for (auto loopRange : domain) {
2162  Value offsetVal =
2163  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2164  Value sizeVal =
2165  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2166  Value strideVal =
2167  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2168  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2169  strideVal, ValueRange{});
2170  loops.push_back(loop);
2171  ivs.push_back(loop.getInductionVar());
2172  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2173  }
2174  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2175  return failure();
2176  }
2177  return loops;
2178 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp)
Fetch the untiled consumer of a scf.for's result which is yielded by a tensor.insert_slice.
static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specified in options.
static SmallVector< scf::ForOp > getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop)
Find the perfectly nested loops outside of given loop(included) sorted from outer to inner.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ValueRange partialResults, const scf::SCFTilingOptions &options)
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, const scf::SCFTilingOptions &options)
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:964
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.h:153
This class allows control over how the GreedyPatternRewriteDriver works.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class represents a saved insertion point.
Definition: Builders.h:336
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:454
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:445
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:421
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:853
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1307
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1300
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1194
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:56
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:75
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:110
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterf...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute with value 0.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
const FrozenRewritePatternSet GreedyRewriteConfig config
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Container for the result of merge operation of tiling.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Container for result values of tiling.
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Control function to check if a slice needs to be fused or not. The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
ReductionTilingStrategy
Specify how reduction dimensions should be tiled.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.