MLIR  21.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1 //===- TileUsingInterface.cpp - Tiling using TilingInterface -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
34 #include <optional>
35 
36 #define DEBUG_TYPE "tile-using-interface"
37 
38 using namespace mlir;
39 
42  assert(!tileSizeComputationFunction && "tile sizes already set");
43  auto tileSizes = llvm::to_vector(ts);
44  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
45  return tileSizes;
46  };
47  return *this;
48 }
49 
52  assert(!numThreadsComputationFunction && "num tiles already set");
53  auto numThreads = llvm::to_vector(nt);
54  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
55  return numThreads;
56  };
57  return *this;
58 }
59 
60 /// Helper method to adjust the interchange vector to match the iteration
61 /// domain.
64  size_t iterationDomainSize) {
65  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
66  if (filledVector.size() < iterationDomainSize) {
67  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
68  filledVector.append(range.begin(), range.end());
69  }
70  if (filledVector.size() > iterationDomainSize)
71  filledVector.resize(iterationDomainSize);
72  return filledVector;
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // tileUsingSCF implementation.
77 //===----------------------------------------------------------------------===//
78 
79 /// Verify the tile size options are set in a consistent manner.
80 static LogicalResult
83  // Specifying number of threads is only supported on `scf.forall` op.
84  if (options.numThreadsComputationFunction &&
86  return rewriter.notifyMatchFailure(
87  loc, "number of threads can only by specified when loop type is "
88  "set to use `scf.forall`");
89  }
90 
91  // If specified, check that the interchange vector is a permutation.
92  if (!options.interchangeVector.empty()) {
93  if (!isPermutationVector(options.interchangeVector)) {
94  return rewriter.notifyMatchFailure(
95  loc, "invalid interchange vector, not a permutation of the entire "
96  "iteration space");
97  }
98  }
99  return success();
100 }
101 
102 /// Method to instantiate the tile sizes and/or number of threads specified
103 /// by the user.
104 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
105 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
106  ArrayRef<Range> iterationDomain,
108  OpFoldResult zero = rewriter.getIndexAttr(0);
109  SmallVector<OpFoldResult> tileSizes, numThreads;
110  size_t numLoops = iterationDomain.size();
111 
112  // Check whether the number of tiles to use is specified.
113  if (options.numThreadsComputationFunction) {
114  numThreads = options.numThreadsComputationFunction(rewriter, op);
115  numThreads.resize(numLoops, zero);
116 
117  // If the number of tiles is also specified, use that.
118  if (options.tileSizeComputationFunction) {
119  tileSizes = options.tileSizeComputationFunction(rewriter, op);
120  tileSizes.resize(numLoops, zero);
121  return {tileSizes, numThreads};
122  }
123 
124  // Compute the tile sizes from the iteration domain and number
125  // of tiles as follows
126  // - niters = ceilDiv(ub - lb, step)
127  // - tileSize = ceilDiv(niters, numThreads)
128  AffineExpr s0, s1, s2;
129  bindSymbols(rewriter.getContext(), s0, s1, s2);
130  // TODO: The step here is assumed to be 1.
131  AffineExpr numItersExpr = (s1 - s0);
132  AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
133  tileSizes.resize(numLoops, zero);
134  for (auto [index, range, nt] :
135  llvm::enumerate(iterationDomain, numThreads)) {
136  if (isConstantIntValue(nt, 0))
137  continue;
138 
139  tileSizes[index] = affine::makeComposedFoldedAffineApply(
140  rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
141  }
142  tileSizes.resize(numLoops, zero);
143  return {tileSizes, numThreads};
144  }
145 
146  // Enforce the convention that "tiling by zero"
147  // skips tiling a particular dimension. This convention is significantly
148  // simpler to handle instead of adjusting affine maps to account for missing
149  // dimensions.
150  assert(options.tileSizeComputationFunction &&
151  "expected tile sizes to be specified");
152  tileSizes = options.tileSizeComputationFunction(rewriter, op);
153  tileSizes.resize(numLoops, zero);
154 
155  return {tileSizes, numThreads};
156 }
157 
158 /// Checks if any of the tiled loops are not parallel.
159 static void checkSafeToTileToForall(TilingInterface op,
160  ArrayRef<OpFoldResult> tileSizes,
161  ArrayRef<OpFoldResult> numThreads) {
162  auto iterators = op.getLoopIteratorTypes();
163  assert(iterators.size() == tileSizes.size() &&
164  "expected as many tile size values as number of loops");
165  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
166  "when specified, expected number of threads to use for each loop");
167 
168  for (auto [index, iterator, tileSize] :
169  llvm::enumerate(iterators, tileSizes)) {
170  // If num threads is specified, check that it is greater than one only for
171  // parallel dimensions.
172  if (!numThreads.empty()) {
173  if (std::optional<int64_t> constNumThreads =
174  getConstantIntValue(numThreads[index])) {
175  if (constNumThreads.value() > 1 &&
176  iterator != utils::IteratorType::parallel) {
177  op.emitWarning() << "tiling is not thread safe at axis #" << index;
178  }
179  }
180  continue;
181  }
182 
183  if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
184  if (constTileSize.value() > 0 &&
185  iterator != utils::IteratorType::parallel) {
186  op.emitWarning() << "tiling is not thread safe at axis #" << index;
187  }
188  }
189  }
190 }
191 
192 /// Check if `stride` evenly divides the trip count `size - offset`.
193 static bool tileDividesIterationDomain(Range loopRange) {
194  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
195  if (!offsetAsInt)
196  return false;
197  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
198  if (!sizeAsInt)
199  return false;
200  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
201  if (!strideAsInt)
202  return false;
203  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
204 }
205 
206 /// Returns the bounded tile size given the current `offset`, `loopRange` and
207 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
209  Range loopRange, OpFoldResult offset,
210  OpFoldResult tileSize) {
211  std::optional<int64_t> ts = getConstantIntValue(tileSize);
212  if (ts && ts.value() == 1)
213  return tileSize;
214 
216  Range{loopRange.offset, loopRange.size, tileSize}))
217  return tileSize;
218 
219  // The tile size to use (to avoid out of bounds access) is minimum of
220  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
221  // loop.
222  AffineExpr s0, s1, d0;
223  bindDims(b.getContext(), d0);
224  bindSymbols(b.getContext(), s0, s1);
225  AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
226  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
228  b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
229 }
230 
231 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
232 /// than `iterationSize`.
234  OpFoldResult numThreads,
235  OpFoldResult iterationSize) {
236  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
237  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
238  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
239  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
240  return false;
241  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
242 }
243 
244 /// Compute the `OpFoldResult`s that represents the multi-dimensional
245 /// `offset`s and `size`s of the tile of the iteration space that the
246 /// innermost loop body of the generated tiled loops corresponds to.
247 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
249  ArrayRef<Range> iterationDomain,
250  ArrayRef<OpFoldResult> tileSizes,
251  ArrayRef<OpFoldResult> numThreads) {
252  SmallVector<OpFoldResult> offsets, sizes;
253  int materializedLoopNum = 0;
254 
255  if (!numThreads.empty()) {
256  AffineExpr d0, d1, s0, s1;
257  AffineExpr offsetExpr, residualTileSizeExpr;
258  bindDims(rewriter.getContext(), d0, d1);
259  bindSymbols(rewriter.getContext(), s0, s1);
260  offsetExpr = d0 + d1 * s0;
261  residualTileSizeExpr = s1 - (d0 + d1 * s0);
262 
263  for (auto [nt, tileSize, loopRange] :
264  llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
265 
266  // Non-tiled cases, set the offset and size to the
267  // `loopRange.offset/size`.
268  if (isConstantIntValue(nt, 0)) {
269  offsets.push_back(loopRange.offset);
270  sizes.push_back(loopRange.size);
271  continue;
272  }
273 
274  Value iv = ivs[materializedLoopNum++];
276  rewriter, loc, offsetExpr,
277  ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
279  rewriter, loc, residualTileSizeExpr,
280  {loopRange.offset, nt, tileSize, loopRange.size});
281 
282  OpFoldResult size = tileSize;
283  if (!isConstantIntValue(residualTileSize, 0)) {
284  OpFoldResult sizeMinusOffsetPerThread =
285  affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
286  {offset, loopRange.size});
288  rewriter, loc,
290  {sizeMinusOffsetPerThread, tileSize});
291  }
292 
293  // Consider the case where the original loop was `[0, 100)`.
294  // If number of threads are `7`, the tile size would be computed as
295  // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
296  // - `offset = 0 + 6 * 15 = 105`
297  // - `tileSize = min(15, 100 - 105) = -5`
298  // To avoid negative tile sizes, we need to do a further
299  // `nonNegativeTileSize = affine.max(0, tileSize)`.
300  // This `max` can be avoided if
301  // `offset + tileSize * (numThreads - 1) < (ub - lb)`
302  if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
303  AffineMap maxMap =
306  rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
307  }
308 
309  offsets.push_back(offset);
310  sizes.push_back(size);
311  }
312  return {offsets, sizes};
313  } else {
314  for (auto [tileSize, loopRange] :
315  llvm::zip_equal(tileSizes, iterationDomain)) {
316 
317  // Non-tiled cases, set the offset and size to the
318  // `loopRange.offset/size`.
319  if (isConstantIntValue(tileSize, 0)) {
320  offsets.push_back(loopRange.offset);
321  sizes.push_back(loopRange.size);
322  continue;
323  }
324 
325  Value iv = ivs[materializedLoopNum++];
326  OpFoldResult offset = getAsOpFoldResult(iv);
327  offsets.push_back(offset);
328  OpFoldResult size =
329  getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
330  sizes.push_back(size);
331  }
332  return {offsets, sizes};
333  }
334 }
335 
336 /// Function to return the bounds of the loops to be generated.
337 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
340  ArrayRef<OpFoldResult> tileSizes) {
341  SmallVector<OpFoldResult> lbs, ubs, steps;
342  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
343  // No loop if the tile size is 0.
344  if (isConstantIntValue(tileSize, 0))
345  continue;
346  lbs.push_back(loopRange.offset);
347  ubs.push_back(loopRange.size);
348  steps.push_back(tileSize);
349  }
350  return {lbs, ubs, steps};
351 }
352 
/// A function that allows returning additional yielded values during
/// `yieldTiledValuesAndReplace`.
/// - `ivs` induction variable for the loop.
/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
/// - `tiledValues` the tiled values to return. Must be of same size as
/// `newbbArgs`, each element of this array is inserted into the corresponding
/// element in `newbbArgs`.
/// - `resultOffsets` is of the same size as `tiledValues` and represents
/// the offsets to use when inserting corresponding element from `tiledValues`
/// into the element from `newBbArgs`.
/// - `resultSizes` is of the same size as `tiledValues` and represents
/// the size of the corresponding element from `tiledValues` inserted into
/// the element from `newBbArgs`.
/// In case the method needs to return `failure()` the method is expected
/// to clean up any inserted operations.
/// Note: `tiledValues`, `resultOffsets` and `resultSizes` are out-parameters
/// populated by the callback.
using YieldTiledValuesFn = std::function<LogicalResult(
    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
    SmallVector<Value> &tiledValues,
    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
373 
374 /// Clones the operation and updates the destination if the operation
375 /// implements the `DestinationStyleOpInterface`.
377  Operation *op,
378  ValueRange newDestArgs) {
379  Operation *clonedOp = rewriter.clone(*op);
380  if (newDestArgs.empty())
381  return clonedOp;
382  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
383  destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
384  return clonedOp;
385 }
386 
387 /// Generate the tile-loop nest using `scf.for` operation.
388 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
389 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
390 /// - `destinationTensors` are the init values to use for the outer most loop.
391 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
392 /// most
393 /// loop.
394 /// - `loops` is an in-out parameter into which the generated loops are
395 /// populated.
396 static LogicalResult generateLoopNestUsingForOp(
397  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
398  ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
399  YieldTiledValuesFn yieldTiledValuesFn,
401  assert(!loopRanges.empty() && "unexpected empty loop ranges");
402  assert(loopRanges.size() == tileSizes.size() &&
403  "expected as many tile sizes as loop ranges");
404  OpBuilder::InsertionGuard guard(rewriter);
405 
406  SmallVector<OpFoldResult> lbs, ubs, steps;
407  std::tie(lbs, ubs, steps) =
408  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
409  SmallVector<Value> lbVals =
410  getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
411  SmallVector<Value> ubVals =
412  getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
413  SmallVector<Value> stepVals =
414  getValueOrCreateConstantIndexOp(rewriter, loc, steps);
415 
416  SmallVector<Value> ivs;
417  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
418  auto loop =
419  rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
420  [](OpBuilder &bodyBuilder, Location bodyLoc,
421  Value iv, ValueRange /*iterArgs*/) {});
422  loops.push_back(loop);
423  ivs.push_back(loop.getInductionVar());
424  rewriter.setInsertionPointToEnd(loop.getBody());
425  destinationTensors = loop.getRegionIterArgs();
426  }
427 
428  SmallVector<Value> tiledResults;
429  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
430  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
431  tiledResults, resultOffsets, resultSizes))) {
432  return rewriter.notifyMatchFailure(
433  loc, "failed to generate inner tile loop body");
434  }
435  if (loops.empty())
436  return success();
437 
438  assert(tiledResults.size() == destinationTensors.size() &&
439  "Number of results of body should be equal to number of iter args");
440 
441  // 6. Yield all the results of the tiled operation.
442  SmallVector<Value> yieldedValues;
443  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
444  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
445  resultSizes)) {
446  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
447  rewriter.getIndexAttr(1));
448  auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
449  loc, tiledValue, destinationTensor, resultOffset, resultSize,
450  resultStride);
451  yieldedValues.push_back(insertSlice);
452  }
453  rewriter.create<scf::YieldOp>(loc, yieldedValues);
454 
455  // Add the scf.yield operations for all the outer loops.
456  for (auto [outerLoop, innerLoop] :
457  llvm::zip_equal(MutableArrayRef(loops).drop_back(),
458  MutableArrayRef(loops).drop_front())) {
459  rewriter.setInsertionPointToEnd(
460  cast<scf::ForOp>(outerLoop.getOperation()).getBody());
461  rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
462  }
463  return success();
464 }
465 
466 /// Generate the tile-loop nest using `scf.forall` operation.
467 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
468 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
469 /// - `destinationTensors` are the init values to use for the outer most loop.
470 /// - `mappingVector` is the mapping attributes to use for loop construction.
471 /// Can be empty.
472 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
473 /// most
474 /// loop.
475 /// - `loops` is an in-out parameter into which the generated loops are
476 /// populated.
477 static LogicalResult generateLoopNestUsingForallOp(
478  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
479  ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
480  ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
482  assert(!loopRanges.empty() && "unexpected empty loop ranges");
483  assert(loopRanges.size() == tileSizes.size() &&
484  "expected as many tile sizes as loop ranges");
485  OpBuilder::InsertionGuard guard(rewriter);
486 
487  std::optional<ArrayAttr> mappingAttr;
488  if (!mappingVector.empty())
489  mappingAttr = rewriter.getArrayAttr(mappingVector);
490 
491  scf::ForallOp forallOp;
492  bool useNumThreads = !numThreads.empty();
493 
494  if (useNumThreads) {
495  // Prune the zero numthreads.
496  SmallVector<OpFoldResult> nonZeroNumThreads;
497  for (auto nt : numThreads) {
498  if (isConstantIntValue(nt, 0))
499  continue;
500  nonZeroNumThreads.push_back(nt);
501  }
502  forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
503  destinationTensors, mappingAttr);
504  } else {
505  SmallVector<OpFoldResult> lbs, ubs, steps;
506  std::tie(lbs, ubs, steps) =
507  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
508  forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
509  destinationTensors, mappingAttr);
510  }
511  loops.push_back(forallOp);
512 
513  rewriter.setInsertionPoint(forallOp.getTerminator());
514  destinationTensors = forallOp.getRegionOutArgs();
515 
516  SmallVector<Value> tiledResults;
517  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
518  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
519  destinationTensors, tiledResults, resultOffsets,
520  resultSizes)))
521  return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
522 
523  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
524  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
525  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
526  resultSizes)) {
527  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
528  rewriter.getIndexAttr(1));
529 
530  rewriter.create<tensor::ParallelInsertSliceOp>(
531  loc, tiledValue, destinationTensor, resultOffset, resultSize,
532  resultStride);
533  }
534  return success();
535 }
536 
537 /// Generate the tile-loop nest using the loop construct specifed in `options`.
538 /// - `options`: Tiling options specified.
539 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
540 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
541 /// - `destinationTensors` are the init values to use for the outer most loop.
542 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
543 /// most
544 /// loop.
545 /// - `loops` is an in-out parameter into which the generated loops are
546 /// populated.
547 static LogicalResult generateLoopNest(
548  RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
549  ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
550  ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
552  // If the tile sizes are all zero, no loops are generated. Just call the
553  // callback function to handle untiled case.
554  if (llvm::all_of(tileSizes, isZeroIndex)) {
555  SmallVector<Value> tiledResults;
556  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
557  return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
558  tiledResults, resultOffsets, resultSizes);
559  }
561  return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
562  destinationTensors, tiledBodyFn, loops);
563  }
566  rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
567  destinationTensors, tiledBodyFn, loops);
568  }
569  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
570 }
571 
572 static FailureOr<SmallVector<Value>>
573 createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
574  ArrayRef<OpFoldResult> tileSizes,
576  SmallVector<Value> initTensors;
577  Location loc = op->getLoc();
578  switch (options.reductionStrategy) {
580  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
581  return failure();
582  return initTensors;
585  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
586  if (!redOp) {
587  return rewriter.notifyMatchFailure(
588  op, "PartialReductionOuterReduction tiling strategy is only supported"
589  "for operations implementing PartialReductionOpInterface");
590  }
591  // Get reduction dimensions.
592  // TODO: PartialReductionOpInterface should really query TilingInterface
593  // itself and find reduction dimensions.
594  SmallVector<int> reductionDims;
595  for (auto [idx, iteratorType] :
596  llvm::enumerate(op.getLoopIteratorTypes())) {
597  if (iteratorType == utils::IteratorType::reduction)
598  reductionDims.push_back(idx);
599  }
600  return redOp.generateInitialTensorForPartialReduction(
601  rewriter, loc, tileSizes, reductionDims);
602  }
603  default:
604  return rewriter.notifyMatchFailure(op,
605  "unhandled reduction tiling strategy");
606  }
607 }
608 
609 static FailureOr<TilingResult>
610 getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
611  ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
614  switch (options.reductionStrategy) {
616  return op.getTiledImplementation(rewriter, offsets, sizes);
619  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
620  if (!redOp) {
621  return rewriter.notifyMatchFailure(
622  op, "PartialReductionOuterReduction tiling strategy is only "
623  "supported for operations "
624  "implementing PartialReductionOpInterface");
625  }
626  // Get reduction dimensions.
627  // TODO: PartialReductionOpInterface should really query TilingInterface
628  // itself and find reduction dimensions.
629  SmallVector<int> reductionDims;
630  for (auto [idx, iteratorType] :
631  llvm::enumerate(op.getLoopIteratorTypes())) {
632  if (iteratorType == utils::IteratorType::reduction)
633  reductionDims.push_back(idx);
634  }
635  return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
636  offsets, sizes, reductionDims);
637  }
638  default:
639  return rewriter.notifyMatchFailure(op,
640  "unhandled reduction tiling strategy");
641  }
642 }
643 
644 static LogicalResult
645 getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
646  TilingInterface op, ArrayRef<OpFoldResult> offsets,
648  SmallVector<OpFoldResult> &resultOffset,
649  SmallVector<OpFoldResult> &resultSize,
651 
652  switch (options.reductionStrategy) {
654  return op.getResultTilePosition(rewriter, index, offsets, sizes,
655  resultOffset, resultSize);
658  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
659  if (!redOp) {
660  return rewriter.notifyMatchFailure(
661  op, "PartialReductionOuterReduction tiling strategy is only supported"
662  "for operations implementing PartialReductionOpInterface");
663  }
664  // Get reduction dimensions.
665  // TODO: PartialReductionOpInterface should really query TilingInterface
666  // itself and find reduction dimensions.
667  SmallVector<int> reductionDims;
668  for (auto [idx, iteratorType] :
669  llvm::enumerate(op.getLoopIteratorTypes())) {
670  if (iteratorType == utils::IteratorType::reduction)
671  reductionDims.push_back(idx);
672  }
673  return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
674  resultOffset, resultSize,
675  reductionDims);
676  }
677  default:
678  return rewriter.notifyMatchFailure(op,
679  "unhandled reduction tiling strategy");
680  }
681 }
682 
683 static FailureOr<MergeResult>
684 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
685  ValueRange partialResults,
687  switch (options.reductionStrategy) {
689  // No need to merge results for reduction tiling strategy.
690  return MergeResult{{}, partialResults};
693  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
694  if (!redOp) {
695  return rewriter.notifyMatchFailure(
696  op, "PartialReductionOuterReduction tiling strategy is only "
697  "supported for operations "
698  "implementing PartialReductionOpInterface");
699  }
700  // Get reduction dimensions.
701  // TODO: PartialReductionOpInterface should really query TilingInterface
702  // itself and find reduction dimensions.
703  SmallVector<int> reductionDims;
704  for (auto [idx, iteratorType] :
705  llvm::enumerate(op.getLoopIteratorTypes())) {
706  if (iteratorType == utils::IteratorType::reduction)
707  reductionDims.push_back(idx);
708  }
709  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
710  reductionDims);
711  }
712  default:
713  return rewriter.notifyMatchFailure(op,
714  "unhandled reduction tiling strategy");
715  }
716 }
717 
/// Append the specified additional `newInitOperands` operands to the
/// loops existing `init` operands (or similar), and replace `loopOp` with
/// the new loop that has the additional init operands. The loop body of
/// this loop is moved over to the new loop. `yieldTiledValuesFn`
/// is called to get the new tiled values returned, and the offset
/// and sizes at which the tiled value is inserted into the
/// new region iter_args that correspond to the newly added init operands.
/// This primary template is the fallback for unsupported loop types; the
/// `scf.for` and `scf.forall` cases are handled by the explicit
/// specializations below.
template <typename LoopType>
FailureOr<LoopLikeOpInterface>
yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
                               ValueRange newInitOperands,
                               YieldTiledValuesFn yieldTiledValuesFn) {
  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
}
732 
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
    scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);

  // Create a replacement loop whose inits are the original inits followed by
  // the newly requested init operands. The body is built empty and filled by
  // moving the original body over below.
  auto inits = llvm::to_vector(loopOp.getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForOp>(
      loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
      inits, [](OpBuilder &, Location, Value, ValueRange) {});

  // Move the loop body to the new op.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(yieldOp);

  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  // Only the trailing region iter_args (those added for `newInitOperands`)
  // are handed to the callback.
  ValueRange newRegionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
                                newRegionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // Clean up the partially-built replacement loop before failing.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
  }

  // Extend the yield with insert_slices of the new tiled values into the new
  // region iter_args, at the offsets/sizes reported by the callback.
  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
       llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
                       resultSizes)) {
    // Unit strides for every inserted dimension.
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    Value insert = rewriter.create<tensor::InsertSliceOp>(
        yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
        resultStride);
    newYieldValues.push_back(insert);
  }

  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
  // The original loop's results map to the leading results of the new loop;
  // trailing results correspond to the newly added inits.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
786 
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
    scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);
  // Create a replacement forall whose outputs are the original outputs
  // followed by the newly requested init operands; body built empty.
  auto inits = llvm::to_vector(loopOp.getOutputs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForallOp>(
      loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
      loopOp.getMixedStep(), inits, loopOp.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});

  // Move the region of the current block to the newly created op.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(terminator);
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  // Only the trailing region iter_args (those added for `newInitOperands`)
  // are handed to the callback.
  ValueRange regionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
                                regionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // Clean up the partially-built replacement loop before failing.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp,
                                       "failed to get yielded tiled values");
  }

  // Update the terminator.
  rewriter.setInsertionPointToEnd(terminator.getBody());

  // Results are communicated through parallel_insert_slice ops in the
  // in_parallel terminator region.
  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
           tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
    // Unit strides for every inserted dimension.
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    rewriter.create<tensor::ParallelInsertSliceOp>(
        terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
        resultStride);
  }

  // The original loop's results map to the leading results of the new loop;
  // trailing results correspond to the newly added inits.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
839 
840 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
841 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
842 /// supported loop type.
843 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
844  LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
845  ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
847  loopLikeOp.getOperation())
848  .Case<scf::ForOp, scf::ForallOp>(
849  [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
851  loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
852  })
853  .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
854  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
855  });
856 }
857 
858 /// Method to add new init values to a loop nest. Updates `loops` in-place
859 /// with new loops that use the `newInitValues`. The outer-loops are updated
860 /// to yield the new result values of the inner loop. For the innermost loop,
861 /// the call back `getNewYields` is invoked to get the additional values to
862 /// yield form the innermost loop.
863 static LogicalResult addInitOperandsToLoopNest(
865  ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
866  if (loops.empty())
867  return success();
868  OpBuilder::InsertionGuard g(rewriter);
869  rewriter.setInsertionPoint(loops.front());
870 
871  SmallVector<Value> ivs;
872  for (auto &loop : loops.drop_back()) {
873  rewriter.setInsertionPoint(loop);
874 
875  // if loops.size() > 1 we assume that scf.for is used for the loops.
876  auto forLoop = cast<scf::ForOp>(loop.getOperation());
877 
878  // Create a new loop with the new init values for this loop.
879  SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
880  newInits.append(newInitValues.begin(), newInitValues.end());
881  auto newLoop = rewriter.create<scf::ForOp>(
882  forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
883  forLoop.getStep(), newInits,
884  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
885 
886  // Merge the body of the new loop with the body of the old loops.
887  SmallVector<Value> sourceBlockArgs;
888  sourceBlockArgs.push_back(newLoop.getInductionVar());
889  auto newRegionIterArgs = newLoop.getRegionIterArgs();
890  sourceBlockArgs.append(
891  newRegionIterArgs.begin(),
892  std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
893  rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
894  rewriter.replaceOp(
895  forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
896  loop = newLoop;
897  ivs.push_back(newLoop.getInductionVar());
898  newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
899  }
900 
901  // Update the loop body of the innermost loop to get new yield values.
902  LoopLikeOpInterface innerMostLoop = loops.back();
903  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
904  yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
905  getNewTiledYieldsFn);
906 
907  if (failed(newInnerMostLoop))
908  return innerMostLoop.emitOpError("failed to return additional yields");
909  loops.back() = newInnerMostLoop.value();
910 
911  // Make all other loops except the innermost loops yield the values returned
912  // by the inner loop.
913  for (auto [outerLoop, innerLoop] :
914  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
915  // Again assume that all the outer loops are scf.for operations.
916  auto outerForLoop = cast<scf::ForOp>(outerLoop);
917  auto outerLoopYield =
918  cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
919  SmallVector<Value> newYields =
920  llvm::to_vector(outerLoopYield.getOperands());
921  ValueRange additionalYields =
922  innerLoop->getResults().take_back(newInitValues.size());
923  newYields.append(additionalYields.begin(), additionalYields.end());
924  rewriter.setInsertionPoint(outerLoopYield);
925  rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
926  }
927  return success();
928 }
929 
930 /// Implementation of tiling transformation of `op` that implements the
931 /// `TilingInterface` using `scf.for` to iterate over the tiles.
932 FailureOr<scf::SCFTilingResult>
933 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
935  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
936  return failure();
937  }
938 
939  OpBuilder::InsertionGuard guard(rewriter);
940  rewriter.setInsertionPointAfter(op);
941 
942  // 1. Get the range of the loops that are represented by the operation.
943  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
944 
945  // 2. Materialize the tile sizes and/or number of threads;
946  SmallVector<OpFoldResult> tileSizes, numThreads;
947  std::tie(tileSizes, numThreads) =
948  getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
949 
950  // Check if it is safe to tile. This is hold over from previous iterations
951  // of tile to for-all. Consider dropping it.
953  checkSafeToTileToForall(op, tileSizes, numThreads);
954  }
955 
956  // 3. If there is an interchange specified, permute the iteration domain and
957  // the tile sizes.
958  SmallVector<int64_t> interchangeVector;
959  if (!options.interchangeVector.empty()) {
960  interchangeVector = fillInterchangeVector(options.interchangeVector,
961  iterationDomain.size());
962  assert(isPermutationVector(interchangeVector) &&
963  "expected interchange vector to be a permutation");
964 
965  applyPermutationToVector(iterationDomain, interchangeVector);
966  applyPermutationToVector(tileSizes, interchangeVector);
967  if (!numThreads.empty())
968  applyPermutationToVector(numThreads, interchangeVector);
969  }
970 
971  FailureOr<TilingResult> tilingResult;
972  // 4. Define the lambda function used later to generate the body of the
973  // innermost tiled loop.
974  YieldTiledValuesFn innerYieldTiledValuesFn =
975  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
976  ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
979  -> LogicalResult {
980  // 4a. Compute the `offsets` and `sizes` to use for tiling.
981  SmallVector<OpFoldResult> offsets, sizes;
982  std::tie(offsets, sizes) = getTileOffsetAndSizes(
983  rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
984 
985  // 4b. If interchange was provided, apply inverse of the interchange
986  // to get back the offsets/sizes in the order to be specified.
987  if (!interchangeVector.empty()) {
988  auto inversePermutation = invertPermutationVector(interchangeVector);
991  }
992 
993  // 5. Generate the tiled implementation within the inner most loop.
994 
995  // 5a. Clone the operation within the loop body.
996  auto clonedOp = cast<TilingInterface>(
997  cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
998 
999  // 5b. Early return cloned op if tiling is not happening. We can not
1000  // return the original op because it could lead to `rewriter.replaceOp(op,
1001  // op->getResults())` and users would get crash.
1002  if (llvm::all_of(tileSizes, isZeroIndex)) {
1003  tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1004  tilingResult =
1005  TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
1006  /*generatedSlices=*/{}};
1007  return success();
1008  }
1009 
1010  // 5c. Tile the cloned operation.
1011  tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
1012  offsets, sizes, options);
1013  if (failed(tilingResult)) {
1014  rewriter.eraseOp(clonedOp);
1015  return op.emitOpError("faild to tile operation");
1016  }
1017 
1018  // 5d. Delete the cloned operation.
1019  rewriter.eraseOp(clonedOp);
1020 
1021  // 5e. Compute the offsets at which the result values are to be inserted
1022  // back into its destinations.
1023  for (auto [index, tiledValue] :
1024  llvm::enumerate(tilingResult->tiledValues)) {
1025  tiledResults.push_back(tiledValue);
1026  SmallVector<OpFoldResult> resultOffset, resultSize;
1027  if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
1028  sizes, resultOffset, resultSize,
1029  options))) {
1030  for (auto op : tilingResult->tiledOps) {
1031  rewriter.eraseOp(op);
1032  }
1033  return rewriter.notifyMatchFailure(
1034  op, "failed to get slice of result produced");
1035  }
1036  resultOffsets.emplace_back(std::move(resultOffset));
1037  resultSizes.emplace_back(std::move(resultSize));
1038  }
1039 
1040  return success();
1041  };
1042 
1043  // 6. Find the destination tensors to use for the operation.
1044  FailureOr<SmallVector<Value>> maybeInits =
1045  createInitialTensorsForTiling(rewriter, op, tileSizes, options);
1046  if (failed(maybeInits)) {
1047  return rewriter.notifyMatchFailure(
1048  op, "unable to create initial tensors for tiling");
1049  }
1050  SmallVector<Value> &initTensors = maybeInits.value();
1051 
1052  // 7. Generate the tiled loops nest using the callback defined above.
1054  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
1055  tileSizes, numThreads, initTensors,
1056  innerYieldTiledValuesFn, loops)))
1057  return op.emitOpError("failed to generate tiling loops");
1058  assert(succeeded(tilingResult) &&
1059  "expected tiling result to be computed after loop generation");
1060 
1061  SmallVector<Value> partialResults;
1062  if (loops.empty()) {
1063  // If loops are empty, the tiled op is used as the replacement for the
1064  // untiled op.
1065  partialResults = tilingResult->tiledValues;
1066  } else {
1067  partialResults = llvm::map_to_vector(loops.front()->getResults(),
1068  [](OpResult r) -> Value { return r; });
1069  }
1070 
1071  FailureOr<MergeResult> mergeResult =
1072  mergeTilingResults(rewriter, op, partialResults, options);
1073  if (failed(mergeResult)) {
1074  return rewriter.notifyMatchFailure(
1075  op, "Failed to merge partial results from tiling");
1076  }
1077 
1078  return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
1079  mergeResult.value(),
1080  tilingResult->generatedSlices};
1081 }
1082 
1083 FailureOr<scf::SCFTilingResult>
1085  PartialReductionOpInterface op,
1086  ArrayRef<OpFoldResult> tileSizes) {
1089  options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
1090  PartialReductionOuterReduction);
1091  options.setTileSizes(tileSizes);
1092 
1093  TilingInterface tilingInterfaceOp =
1094  dyn_cast<TilingInterface>(op.getOperation());
1095  if (!tilingInterfaceOp) {
1096  return b.notifyMatchFailure(
1097  op,
1098  "Operation implementing PartialReductionOpInterface should implement "
1099  "TilingInterface");
1100  }
1101 
1102  return tileUsingSCF(b, tilingInterfaceOp, options);
1103 }
1104 
1105 //===----------------------------------------------------------------------===//
1106 // tileConsumerAndFuseProducersUsingSCF implementation.
1107 //===----------------------------------------------------------------------===//
1108 
1109 /// Return the untiled producer whose slice is used in a tiled consumer. The
1110 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1111 /// `iter_args` of the outer most that is encountered. Traversing the
1112 /// iter_args indicates that this is a destination operand of the consumer. If
1113 /// there was no loop traversal needed, the second value of the returned tuple
1114 /// is empty.
1115 static std::tuple<OpResult, std::optional<OpOperand *>>
1118  std::optional<OpOperand *> destinationIterArg;
1119  assert(!loops.empty() && "expected non empty loops container");
1120  auto loopIt = loops.rbegin();
1121  while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
1122  auto iterArg = cast<BlockArgument>(source->get());
1123  auto loop = *loopIt;
1124  if (iterArg.getOwner()->getParentOp() != loop)
1125  break;
1126  source = loop.getTiedLoopInit(iterArg);
1127  loopIt++;
1128  }
1129  if (loopIt == loops.rend())
1130  destinationIterArg = source;
1131  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1132 }
1133 
1134 /// Implementation of fusing producer of a single slice by computing the
1135 /// slice of the producer in-place.
1136 std::optional<scf::SCFFuseProducerOfSliceResult>
1138  RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1140  // 1. Get the producer of the source (potentially walking through
1141  // `iter_args` of nested `scf.for`)
1142  auto [fusableProducer, destinationInitArg] =
1143  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1144  loops);
1145  if (!fusableProducer)
1146  return std::nullopt;
1147  unsigned resultNumber = fusableProducer.getResultNumber();
1148 
1149  OpBuilder::InsertionGuard g(rewriter);
1150  rewriter.setInsertionPoint(candidateSliceOp);
1151 
1152  // 2. Clone the fused producer
1153  // 2a. Compute the destination operands to use for the cloned operation.
1154  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1155  Operation *fusableProducerOp = fusableProducer.getOwner();
1156  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1158  rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1159  origDestinationTensors)))
1160  return std::nullopt;
1161 
1162  clonedOpDestinationTensors = origDestinationTensors;
1163  if (destinationInitArg &&
1164  isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1165  // 2b. If the producer is also destination style, then to maintain the
1166  // destination passing style, update the destination of the producer to be
1167  // the source of the slice.
1168  clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1169  }
1170  // 2c. Clone the fused producer.
1171  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1172  rewriter, fusableProducerOp, clonedOpDestinationTensors);
1173  // 2d. Update the source of the candidateSlice to be the cloned producer.
1174  // Easier to just clone the slice with different source since
1175  // replacements and DCE of cloned ops becomes easier
1176  SmallVector<Value> candidateSliceOpOperands =
1177  llvm::to_vector(candidateSliceOp->getOperands());
1178  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1179  tensor::ExtractSliceOp clonedCandidateSliceOp =
1180  mlir::clone(rewriter, candidateSliceOp,
1181  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1182 
1183  // 3. Generate the tiled implementation of the producer of the source
1184  FailureOr<TilingResult> tileAndFuseResult =
1186  rewriter, clonedCandidateSliceOp,
1187  clonedProducerOp->getResult(resultNumber));
1188  if (failed(tileAndFuseResult))
1189  return std::nullopt;
1190  // Note: Do not delete the candidateSliceOp, since its passed in from the
1191  // caller.
1192  rewriter.replaceAllUsesWith(candidateSliceOp,
1193  tileAndFuseResult->tiledValues[0]);
1194  rewriter.eraseOp(clonedCandidateSliceOp);
1195  rewriter.eraseOp(clonedProducerOp);
1196 
1197  // 3. If the slice is for a destination operand, for example,
1198  //
1199  // ```mlir
1200  // %0 = linalg.init
1201  // %1 = linalg.fill .. outs(%0 : )
1202  // %2 = scf.for .. iter_args(%arg0 = %1) {
1203  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1204  // %4 = tensor.extract_slice %arg1 [..]
1205  // .. = linalg.matmul .. outs(%4 : )
1206  // }
1207  // }
1208  // ```
1209  //
1210  // the IR is currently
1211  //
1212  // ```
1213  // %0 = linalg.init
1214  // %1 = linalg.fill
1215  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1216  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1217  // %4 = tensor.extract_slice %arg1[..]
1218  // %5 = linalg.fill .. outs(%4 : )
1219  // .. = linalg.matmul .. outs(%5 : )
1220  // }
1221  // }
1222  // ```
1223  //
1224  // The untiled `linalg.fill` is still used as the `init_value` since it
1225  // was originally a destination operand of the untiled `linalg.matmul`.
1226  // When fusing an operand that is a destination operand, the iter_arg of
1227  // the outer most loop should be changed to use the destination of the
1228  // fused operation. With this the IR will be.
1229  //
1230  // ```
1231  // %0 = linalg.init
1232  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1233  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
1234  // %3 = tensor.extract_slice %arg1[..]
1235  // %4 = linalg.fill .. outs(%3 : )
1236  // .. = linalg.matmul .. outs(%4 : )
1237  // }
1238  // }
1239  // ```
1240  if (destinationInitArg &&
1241  isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1242  loops.front()
1243  ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1244  .set(origDestinationTensors[resultNumber]);
1245  }
1247  fusableProducer, tileAndFuseResult->tiledValues[0],
1248  tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1249 }
1250 
1251 /// Reconstruct the fused producer from within the tiled-and-fused code.
1252 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1253  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1254  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1256  ArrayRef<unsigned> yieldResultNumber) {
1257  if (loops.empty())
1258  return success();
1259 
1260  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1261  *tiledOwner = fusedProducerInfo.tiledOps[0];
1262 
1263  Location loc = originalOwner->getLoc();
1264  // a. collect all init Value to be appended
1265  SmallVector<unsigned> initNumberList =
1266  yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1267  0, originalOwner->getNumResults()))
1268  : llvm::to_vector(yieldResultNumber);
1269  SmallVector<Value> initValueList;
1270  for (const auto &resultNumber : initNumberList) {
1271  FailureOr<Value> initValue = tensor::getOrCreateDestination(
1272  rewriter, loc, originalOwner->getResult(resultNumber));
1273  if (succeeded(initValue)) {
1274  initValueList.push_back(initValue.value());
1275  } else {
1276  return failure();
1277  }
1278  }
1279 
1280  SmallVector<Operation *> generatedSlices;
1281  YieldTiledValuesFn newYieldValuesFn =
1282  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1283  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1285  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1286  OpBuilder::InsertionGuard g(innerRewriter);
1287 
1288  // get sliceOp tile information
1289  SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1290  sliceSizes = sliceOp.getMixedSizes();
1291 
1292  // expect all strides of sliceOp being 1
1293  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1294  return !isConstantIntValue(ofr, 1);
1295  }))
1296  return failure();
1297 
1298  unsigned sliceResultNumber =
1299  fusedProducerInfo.origProducer.getResultNumber();
1300 
1301  auto tilableOp = cast<TilingInterface>(originalOwner);
1302  // b. get iterDomain Offset and Sizes based on sliceOp tile
1303  SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1304  // skip tensor.pack/unpack/pad, which expects single opResult
1305  if (tilableOp->getNumResults() > 1 &&
1306  failed(tilableOp.getIterationDomainTileFromResultTile(
1307  rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1308  iterDomainOffset, iterDomainSizes))) {
1309  // In theory, it is unnecessary to raise an error here. Actually
1310  // although it fails to reconstruct the result tensor, it should not
1311  // broke current fusion anyway. The reason why we must return failure
1312  // currently is that the callback function `newYieldValuesFn` will be
1313  // called after new init operand(s) has already been appended. It will
1314  // take more refactoring to make sure the init operands are added
1315  // consistently in the future. For more details, please refer to:
1316  // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1317  return failure();
1318  }
1319 
1320  // c. calculate offsets and sizes info of all OpResults respectively based
1321  // on iteration Domain Tile
1322  SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1323  for (const auto &resultNumber : initNumberList) {
1324  if (resultNumber == sliceResultNumber) {
1325  offsetList.push_back(sliceOffset);
1326  sizesList.push_back(sliceSizes);
1327  } else {
1328  assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1329  // infer result tile according to the iteration domain tile
1330  SmallVector<OpFoldResult> offset, sizes;
1331  if (failed(tilableOp.getResultTilePosition(
1332  rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1333  offset, sizes))) {
1334  return failure();
1335  }
1336  offsetList.push_back(offset);
1337  sizesList.push_back(sizes);
1338  }
1339  }
1340 
1341  // d. create `extract_slice` for `iter_args` for DPS operation if
1342  // necessary
1343  if (auto tiledDestStyleOp =
1344  dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1345  rewriter.setInsertionPoint(tiledDestStyleOp);
1346  for (const auto &&[index, newRegionArg] :
1347  llvm::enumerate(newRegionIterArgs)) {
1348  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1349  loc, newRegionArg, offsetList[index], sizesList[index],
1350  SmallVector<OpFoldResult>(offsetList[index].size(),
1351  rewriter.getIndexAttr(1)));
1352  generatedSlices.push_back(destSlice);
1353  unsigned resultNumber = initNumberList[index];
1354  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1355  tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1356  });
1357  }
1358  }
1359 
1360  // e. prepare tiled offset and sizes for later `insert_slice` creation by
1361  // caller
1362  Block *block = rewriter.getInsertionPoint()->getBlock();
1363  rewriter.setInsertionPoint(block->getTerminator());
1364  for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1365  tiledResult.push_back(tiledOwner->getResult(resultNumber));
1366  tiledOffset.emplace_back(offsetList[index]);
1367  tiledSizes.emplace_back(sizesList[index]);
1368  }
1369  return success();
1370  };
1371 
1372  if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1373  newYieldValuesFn))) {
1374  return failure();
1375  }
1376  return generatedSlices;
1377 }
1378 
1379 namespace {
1380 
1381 //===----------------------------------------------------------------------===//
1382 // SliceTrackingListener
1383 //===----------------------------------------------------------------------===//
1384 
/// This class is a listener for tracking the insertion and removal of
/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
/// fusion algorithm to apply cleanup patterns in between fusion steps.
/// Non-slice operations are ignored entirely.
class SliceTrackingListener : public RewriterBase::Listener {
public:
  explicit SliceTrackingListener(
      std::optional<FrozenRewritePatternSet> patterns);
  SliceTrackingListener() = default;

  /// Adds the given list of operations to the worklist, and if present,
  /// applies the list of `patterns` to the newly added operations. This only
  /// processes the given operations and any newly inserted ones by the
  /// pattern set.
  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);

  /// Add to the new operation worklist if it is an extract_slice.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Shared helper for operation removal from the worklist.
  void removeOp(Operation *op);

  /// Remove the operation from the worklist.
  void notifyOperationErased(Operation *op) override;

  /// Remove the operation from the worklist.
  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;

  /// The worklist for this transformation keeps track of the slices to visit
  /// next for fusion. Processed FIFO via front/pop_front by the fusion
  /// driver; public so the driver can drain it directly.
  std::deque<tensor::ExtractSliceOp> worklist;

private:
  /// Optional pattern set to apply when adding new operations to the
  /// worklist.
  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
};
1422 
1423 SliceTrackingListener::SliceTrackingListener(
1424  std::optional<FrozenRewritePatternSet> p) {
1425  patterns = std::move(p);
1426 }
1427 
1428 LogicalResult
1429 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1430  for (Operation *op : ops) {
1431  if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1432  worklist.push_back(slice);
1433  }
1434 
1435  if (!patterns)
1436  return success();
1437 
1438  return applyOpPatternsGreedily(
1439  ops, patterns.value(),
1440  GreedyRewriteConfig().setListener(this).setStrictness(
1442 }
1443 
1444 void SliceTrackingListener::notifyOperationInserted(
1445  Operation *op, OpBuilder::InsertPoint previous) {
1446  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1447  if (!slice)
1448  return;
1449  worklist.push_back(slice);
1450 }
1451 
1452 // Scan the worklist for the given op and remove it if present. The
1453 // expectation is for the worklist to be small and for removal to be
1454 // relatively rare.
1455 void SliceTrackingListener::removeOp(Operation *op) {
1456  if (!isa<tensor::ExtractSliceOp>(op))
1457  return;
1458  auto iter = worklist.begin();
1459  while (iter != worklist.end()) {
1460  if (*iter == op)
1461  break;
1462  iter++;
1463  }
1464  if (iter == worklist.end())
1465  return;
1466 
1467  worklist.erase(iter);
1468 }
1469 
/// An erased op can no longer be visited; drop it from the worklist.
void SliceTrackingListener::notifyOperationErased(Operation *op) {
  removeOp(op);
}
1473 
/// A replaced op is going away; drop it from the worklist. The replacement
/// values are not inspected here.
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
                                                    ValueRange replacement) {
  removeOp(op);
}
1478 
1479 //===----------------------------------------------------------------------===//
1480 // ReplacementListener
1481 //===----------------------------------------------------------------------===//
1482 
1483 /// Listener that tracks updates replacements for values which can be mutated.
1484 /// This listener runs on top of the existing listener for the rewriter,
1485 /// to make sure external users can still run listeners.
1486 class ReplacementListener : public RewriterBase::ForwardingListener {
1487 public:
1488  ReplacementListener(DenseMap<Value, Value> &replacements,
1489  OpBuilder::Listener *listener)
1490  : ForwardingListener(listener), replacements(replacements) {}
1491 
1492  void updateReplacementValues(ValueRange origValues,
1493  ValueRange replaceValues) {
1494  // This can probably be written better, but just iterates over the map
1495  // and the new replacements for now.
1496  for (auto &[key, val] : replacements) {
1497  for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1498  if (val == orig) {
1499  val = replace;
1500  }
1501  }
1502  }
1503  }
1504 
1505  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1506  ForwardingListener::notifyOperationReplaced(op, newOp);
1507  updateReplacementValues(op->getResults(), newOp->getResults());
1508  }
1509 
1510  void notifyOperationReplaced(Operation *op, ValueRange values) override {
1511  ForwardingListener::notifyOperationReplaced(op, values);
1512  updateReplacementValues(op->getResults(), values);
1513  }
1514 
1515 private:
1516  DenseMap<Value, Value> &replacements;
1517 };
1518 
1519 } // namespace
1520 
/// Implementation of tile consumer and fuse producer greedily.
FailureOr<scf::SCFTileAndFuseResult>
// NOTE(review): the declarator line (presumably
// `mlir::scf::tileConsumerAndFuseProducersUsingSCF(`) appears to have been
// dropped from this rendering — confirm against the original source.
    RewriterBase &rewriter, TilingInterface consumer,
// NOTE(review): the trailing parameter line (presumably the
// `SCFTileAndFuseOptions &options` argument used throughout the body) also
// appears to be missing here.
  // This transformation is only valid for ops that return values (i.e. not
  // valid to use with operations that have memref operands).
  if (!consumer->getNumResults()) {
    return rewriter.notifyMatchFailure(
        consumer, "invalid pattern for op with no results");
  }

  // 1. First tile the consumer.
  SetVector<Operation *> fusedProducers, tiledAndFusedOps;

  FailureOr<scf::SCFTilingResult> tilingResult =
      tileUsingSCF(rewriter, consumer, options.tilingOptions);

  if (failed(tilingResult))
    return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
  tiledAndFusedOps.insert_range(tilingResult->tiledOps);

  // Map each untiled consumer result to its replacement from the tiled code.
  DenseMap<Value, Value> replacements;
  for (auto [origVal, replacement] : llvm::zip_equal(
           consumer->getResults(), tilingResult->mergeResult.replacements)) {
    replacements[origVal] = replacement;
  }

  // If there are no loops generated, fusion is immaterial.
  auto &loops = tilingResult->loops;
  if (loops.empty()) {
    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                     replacements};
  }

  // Since the loop gets potentially replaced during fusion, we need to track
  // the mutation of replacement values. To do this, we attach a listener to
  // update the replacements as they happen. The previous listener is
  // restored on every exit path via the scope guard.
  OpBuilder::Listener *previousListener = rewriter.getListener();
  auto resetListener =
      llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
  ReplacementListener replaceListener(replacements, previousListener);
  rewriter.setListener(&replaceListener);

  // 2. Typically, the operands of the tiled operation are slices of the
  // operands of the untiled operation. These are expressed in IR using
  // `tensor.extract_slice` operations with source being the operands of
  // the untiled operation. Create a worklist of these
  // `tensor.extract_slice` operations. If the producers of the source of
  // the `tensor.extract_slice` can be tiled such that the tiled value is
  // generated in-place, that effectively tiles + fuses the operations.
  struct WorklistItem {
    tensor::ExtractSliceOp candidateSlice;
    // NOTE(review): a member declaration (presumably the
    // `SCFTileAndFuseOptions::ControlFnResult controlFnResult;` field
    // accessed below) appears to be missing from this rendering.
  };

  SliceTrackingListener sliceTracker =
      SliceTrackingListener(options.cleanupPatterns);

  // Seed the worklist (and run cleanup patterns) with the slices produced
  // while tiling the consumer.
  if (failed(
          sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
    return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
  }
  OpBuilder::InsertionGuard g(rewriter);
  while (!sliceTracker.worklist.empty()) {
    auto candidateSlice = sliceTracker.worklist.front();
    sliceTracker.worklist.pop_front();

    // Find the untiled producer feeding this slice, walking through
    // iter_args of the loop nest if necessary.
    auto [fusableProducer, destinationInitArg] =
        getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
                                          loops);
    if (!fusableProducer)
      continue;

    // Let the user-supplied control function decide whether (and how) to
    // fuse this candidate.
    std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
        options.fusionControlFn(candidateSlice, fusableProducer,
                                destinationInitArg.has_value());
    if (!controlFnResult)
      continue;

    WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};

    // The operands of the fused producer might themselves be slices of
    // values produced by operations that implement the `TilingInterface`.
    // Add these operations to the worklist.
    std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
        tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
                                   loops);
    if (!fusedResult)
      continue;

    SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;

    if (worklistItem.controlFnResult.yieldProducerReplacement) {
      // Reconstruct and yield all opResult of fusableProducerOp by default.
      // The caller can specify which one to yield by designating optional
      // argument named `yieldResultNumber` of
      // `yieldReplacementForFusedProducer`.
      Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
      FailureOr<SmallVector<Operation *>> newSlices =
      // NOTE(review): the call line (presumably
      // `yieldReplacementForFusedProducer(rewriter,`) appears to be missing
      // from this rendering.
              worklistItem.candidateSlice,
              fusedResult.value(), loops);
      if (failed(newSlices)) {
        return rewriter.notifyMatchFailure(
            fusableProducerOp, "failed to replacement value for this "
                               "operation from within the tiled loop");
      }
      worklistCandidates.append(newSlices.value());
      // The producer's results are now yielded as the trailing results of
      // the outermost loop; record them as the replacements.
      for (auto [index, result] :
           llvm::enumerate(fusableProducerOp->getResults())) {
        replacements[result] = loops.front()->getResult(
            loops.front()->getNumResults() -
            fusableProducerOp->getNumResults() + index);
      }
    }
    if (Operation *tiledAndFusedOp =
            fusedResult->tiledAndFusedProducer.getDefiningOp()) {
      fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
      tiledAndFusedOps.insert(tiledAndFusedOp);
    }

    // Run cleanup patterns on (and enqueue) the newly generated slices.
    if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
      return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
    }
  }

  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                   replacements};
}
1651 
1652 //===----------------------------------------------------------------------===//
1653 // tileAndFuseConsumerUsingSCF implementation.
1654 //===----------------------------------------------------------------------===//
1655 
1656 /// A utility function that checks whether the only use of the result of a
1657 /// tensor.insert_slice op is in a scf.yield op.
1658 static LogicalResult
1659 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1660  Value result = candidateSliceOp.getResult();
1661  Value::use_range uses = result.getUses();
1662  if (!llvm::hasSingleElement(uses)) {
1663  LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1664  return failure();
1665  }
1666  OpOperand &operandUse = (*uses.begin());
1667  Operation *userOp = operandUse.getOwner();
1668  if (!isa<scf::YieldOp>(userOp)) {
1669  LLVM_DEBUG(llvm::dbgs()
1670  << "Expected scf.yield to be the only user, but got -> "
1671  << (*userOp));
1672  return failure();
1673  }
1674  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1675  LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1676  "be in the same block\n");
1677  return failure();
1678  }
1679  return success();
1680 }
1681 
1682 /// An utility to get the first user of the given loopOp. If any of user stay
1683 /// in different block of loopOp, return failure.
1684 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1685  if (!isa<LoopLikeOpInterface>(loopOp))
1686  return failure();
1687  Operation *firstUserOfLoop = nullptr;
1688  for (Operation *userOp : loopOp->getUsers()) {
1689  // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1690  // block with any other types of operation. Thus, just redirecting to its
1691  // parent `InParallelOp`. E.g.
1692  //
1693  // ```
1694  // %1 = scf.for {
1695  // ...
1696  // }
1697  // %2 = consumerOp ins(%1, ...)
1698  // scf.forall.in_parallel {
1699  // tensor.parallel_insert_slice %1
1700  // }
1701  // ```
1702  // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
1703  // same block with `consumerOp`.
1704  if (isa<tensor::ParallelInsertSliceOp>(userOp))
1705  userOp = userOp->getParentOfType<scf::InParallelOp>();
1706 
1707  if (loopOp->getBlock() != userOp->getBlock())
1708  return failure();
1709 
1710  if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1711  firstUserOfLoop = userOp;
1712  }
1713  return firstUserOfLoop;
1714 }
1715 
1716 /// This utility currently checks whether the first userOp of loop is NOT
1717 /// before the last defineOp of consumer operand. Because that we need to move
1718 /// the whole loop structure right before the `firstUserOfLoop`. This utility
1719 /// thus helps ensuring that no invalid IR is formed, i.e. no backward slice
1720 /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
1721 ///
1722 /// ```
1723 /// %0 = scf.for() {
1724 /// ...
1725 /// }
1726 /// ...
1727 /// %1 = firstUserOfLoop(%0)
1728 /// ...
1729 /// %2 = lastDefOfConsumerOperand
1730 /// ...
1731 /// %3 = consumerOp(%2)
1732 /// ```
1733 ///
1734 /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
1735 /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
1736 /// a.k.a. use-def chain violation:
1737 ///
1738 /// ```
1739 /// %0:2 = scf.for() {
1740 /// // use before define error
1741 /// %3 = tiledConsumerOp(%2)
1742 /// }
1743 /// %1 = firstUserOfLoop(%0)
1744 /// ...
1745 /// %2 = lastDefOfConsumerOperand
1746 /// ```
1747 ///
1748 /// @param loopOp: loop operation
1749 /// @param consumerOp: consumer operation
1750 /// @param reorderOperations: the flag controls whether to reorder the
1751 /// backward slice w.r.t. the defineOp of `consumerOp` operands.
1752 /// @return: computed backward slice of consumerOp, but excluding those
1753 /// already dominates `firstUserOfLoop`.
1754 static FailureOr<llvm::SetVector<Operation *>>
1756  bool reorderOperations) {
1757  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1758  if (failed(firstUserOfLoop))
1759  return failure();
1760 
1762  DominanceInfo dominanceInfo;
1763  options.inclusive = true;
1764  options.omitBlockArguments = true;
1765  bool includeLoopOp = false;
1766  options.filter = [&](Operation *op) {
1767  if (op == loopOp) {
1768  includeLoopOp = true;
1769  return false;
1770  }
1771  // Cut off the slice to not include any operation that already dominates
1772  // firstUserOfLoop.
1773  return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1774  };
1776  for (auto operand : consumerOp->getOperands()) {
1777  getBackwardSlice(operand, &slice, options);
1778  }
1779 
1780  if (!slice.empty()) {
1781  // If consumerOp has one producer, which is also the user of loopOp.
1782  // E.g.
1783  // ```
1784  // %0 = %loopOp
1785  // %1 = consumerOp1 ins(%0)
1786  // %2 = consumerOp2 ins(%0, %1)
1787  // ```
1788  // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1789  // consumerOp1 has already been fused into loopOp before.
1790  if (includeLoopOp || !reorderOperations)
1791  return failure();
1792  }
1793 
1794  return slice;
1795 }
1796 
1797 /// Fetches the OpOperand of the first valid user (and use) of the value `val`
1798 /// which implements `TilingInterface` and `DestinationStyleOpInterface`.
1799 /// Returns failure otherwise.
1800 static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1801  Operation *loopOp,
1802  unsigned resultNumber) {
1803  if (!isa<LoopLikeOpInterface>(loopOp))
1804  return failure();
1805  Value val = loopOp->getResult(resultNumber);
1806  Block *loopBlock = loopOp->getBlock();
1807  for (OpOperand &opOperand : val.getUses()) {
1808  Operation *consumerOp = opOperand.getOwner();
1809  // Step 1. Check if the user is tilable.
1810  if (!isa<TilingInterface>(consumerOp) ||
1811  !isa<DestinationStyleOpInterface>(consumerOp)) {
1812  // TODO: We have to init result of consumer before scf.for, use
1813  // DestinationStyleOpInterface to get result shape from init for now.
1814  // Add support for other op such as op has InferTypeOpInterface.
1815  continue;
1816  }
1817  // Step 2. Check if user stay in the same block.
1818  if (loopBlock != consumerOp->getBlock())
1819  continue;
1820  // Step 3. Check if user has succeeding user. Otherwise, it usually
1821  // represents already tiled.
1822  if (consumerOp->use_empty())
1823  continue;
1824  // Step 4. Check assumption for loop with `reorderOperations` enabled.
1825  FailureOr<llvm::SetVector<Operation *>> slice =
1826  checkAssumptionForLoop(loopOp, consumerOp, true);
1827  if (failed(slice))
1828  continue;
1829  // Step 5. If backward sice is not empty, move them before
1830  // firstUserOfLoop.
1831  if (!slice->empty()) {
1832  mlir::topologicalSort(*slice);
1833  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1834  assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1835  for (auto op : *slice) {
1836  rewriter.moveOpBefore(op, *firstUserOfLoop);
1837  }
1838  }
1839  return &opOperand;
1840  }
1841  return failure();
1842 }
1843 
1844 /// Check that the loop is perfectly nested.
1845 /// The loops are expected to be ordered from outer most to inner most.
1846 /// For example:
1847 /// ```
1848 /// %0 = scf.for()
1849 /// %1 = scf.for()
1850 /// %2 = scf.for()
1851 /// %3 = ...
1852 /// yield %3
1853 /// yield %2
1854 /// yield %1
1855 /// ```
1856 /// Here loops should be [%0, %1].
1857 static bool
1859  assert(!loops.empty() && "unexpected empty loop nest");
1860  if (loops.size() == 1) {
1861  return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
1862  }
1863  for (auto [outerLoop, innerLoop] :
1864  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
1865  auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
1866  auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
1867  if (!outerFor || !innerFor) {
1868  return false;
1869  }
1870  auto outerBBArgs = outerFor.getRegionIterArgs();
1871  auto innerIterArgs = innerFor.getInitArgs();
1872  if (outerBBArgs.size() != innerIterArgs.size()) {
1873  return false;
1874  }
1875 
1876  for (auto [outerBBArg, innerIterArg] :
1877  llvm::zip_equal(outerBBArgs, innerIterArgs)) {
1878  if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
1879  innerIterArg != outerBBArg) {
1880  return false;
1881  }
1882  }
1883 
1884  ValueRange outerYields =
1885  cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
1886  ValueRange innerResults = innerFor.getResults();
1887  if (outerYields.size() != innerResults.size()) {
1888  return false;
1889  }
1890  for (auto [outerYield, innerResult] :
1891  llvm::zip_equal(outerYields, innerResults)) {
1892  if (!llvm::hasSingleElement(innerResult.getUses()) ||
1893  outerYield != innerResult) {
1894  return false;
1895  }
1896  }
1897  }
1898  return true;
1899 }
1900 
1901 /// Fetch the untiled consumer of the outermost scf.for's result which is
1902 /// yielded by a tensor.insert_slice from the innermost scf.for. This function
1903 /// makes the following assumptions :
1904 /// 1. tensor.insert_slice has scf.yield as its only user.
1905 /// 2. scf.for's corresponding result has only one use.
1906 /// 3. The `loops` passed in are perfectly nested `scf.for` operations.
1907 static FailureOr<OpOperand *>
1909  tensor::InsertSliceOp candidateSliceOp,
1911  assert(!loops.empty() && "unexpected loops to be empty");
1912  // 1. Expect slice to be part of the body of the inner most loop.
1913  Operation *containingOp = candidateSliceOp->getParentOp();
1914  if (containingOp != loops.back()) {
1915  return rewriter.notifyMatchFailure(
1916  candidateSliceOp,
1917  "expected slice to be within body of inner-most loop");
1918  }
1919 
1920  // 2. Check that the loop is perfectly nested.
1921  if (!isPerfectlyNestedForLoops(loops)) {
1922  return rewriter.notifyMatchFailure(
1923  candidateSliceOp, "expected passed loops to be perfectly nested.");
1924  }
1925 
1926  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1927  return failure();
1928  Value sliceResult = candidateSliceOp.getResult();
1929 
1930  // 3. Fetch the corresponding output.
1931  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1932  unsigned resultNumber = yieldOpOperand.getOperandNumber();
1933 
1934  scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
1935 
1936  return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
1937 }
1938 
1939 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1940 /// by a tensor.parallel_insert_slice.
1941 static FailureOr<OpOperand *>
1943  tensor::ParallelInsertSliceOp candidateSliceOp,
1945  assert(!loops.empty() && "unexpected loops to be empty");
1946  // 1. Check that the surrounding loop is a single scf.forall loop.
1947  if (loops.size() != 1) {
1948  return rewriter.notifyMatchFailure(
1949  candidateSliceOp, "expected single surrounding scf.forall");
1950  }
1951  auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
1952  if (!forallOp) {
1953  return rewriter.notifyMatchFailure(
1954  candidateSliceOp, "expected single surrounding scf.forall");
1955  }
1956 
1957  // 2. Fetch the corresponding output
1958  Value sliceDest = candidateSliceOp.getDest();
1959  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1960  if (!iterArg)
1961  return failure();
1962  if (iterArg.getOwner()->getParentOp() != forallOp)
1963  return failure();
1964 
1965  unsigned resultNumber =
1966  forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1967  .getResultNumber();
1968 
1969  return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
1970 }
1971 
1972 /// A utility to fetch an untiled consumer of
1973 /// tensor.insert_slice/tensor.parallel_insert_slice.
1974 static FailureOr<OpOperand *>
1977  assert(!loops.empty() && "unexpected empty loops");
1978  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1979  return getUntiledConsumerFromSlice(rewriter, insertSlice, loops);
1980  } else if (auto parallelInsertSlice =
1981  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1982  return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice, loops);
1983  } else {
1984  return failure();
1985  }
1986 }
1987 
1988 /// Implementation of fusing consumer of a single slice by computing the
1989 /// slice of the consumer in-place for scf loop.
1990 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1992  RewriterBase &rewriter, Operation *candidateSliceOp,
1994  // Return if `loops` is empty, return an error for now. Caller is expected
1995  // to handle this case.
1996  if (loops.empty()) {
1997  return candidateSliceOp->emitOpError(
1998  "cannot call tile and fuse consumer with an empty loop nest");
1999  }
2000  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2001  candidateSliceOp))
2002  return failure();
2003 
2004  // 1. Get the consumer of scf.for for the result yielded by
2005  // tensor.insert_slice/parallel_insert_slice.
2006  FailureOr<OpOperand *> maybeConsumerOpOperand =
2007  getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
2008  if (failed(maybeConsumerOpOperand)) {
2009  return rewriter.notifyMatchFailure(candidateSliceOp,
2010  "could not fetch consumer to fuse");
2011  }
2012  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
2013  Operation *consumerOp = consumerOpOperand->getOwner();
2014  unsigned operandNumber = consumerOpOperand->getOperandNumber();
2015  unsigned resultNumber = 0;
2016  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
2017  resultNumber = producerResult.getResultNumber();
2018  } else {
2019  return rewriter.notifyMatchFailure(
2020  consumerOp, "consumer op's operand doesn't seem to be an OpResult");
2021  }
2022 
2023  LoopLikeOpInterface outerMostLoop = loops.front();
2024  LoopLikeOpInterface innerMostLoop = loops.back();
2025 
2026  // Check assumption for loop with `reorderOperations` disabled.
2027  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
2028  return rewriter.notifyMatchFailure(
2029  outerMostLoop, "the first user of loop should not dominate any define "
2030  "of consumer operand(s)");
2031  }
2032 
2033  OpBuilder::InsertionGuard g(rewriter);
2034 
2035  // 2. Check consumer is not using scf loop's output as init.
2036  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2037  if (!dstOp)
2038  return rewriter.notifyMatchFailure(consumerOp,
2039  "consumer op is not DPS operation");
2040  SmallVector<Value> dpsInits =
2041  llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
2042  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
2043  return rewriter.notifyMatchFailure(
2044  consumerOp,
2045  "consumer op taking the result of scf.for as init is not supported");
2046  }
2047  SmallVector<Value> newInits = dpsInits;
2048 
2049  Location loc = outerMostLoop->getLoc();
2050 
2051  // 3. Move the whole loop structure right before firstUserOfLoop, the
2052  // dominance should be already ensured by `checkAssumptionForLoop`.
2053  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
2054  if (failed(firstUserOfLoop)) {
2055  return rewriter.notifyMatchFailure(
2056  outerMostLoop, "could not find the first user of outer most loop");
2057  }
2058  rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2059 
2060  // 4. Set insertion point before terminator op of the loop and create a new
2061  // tensor.insert_slice. In the scf.for case this is a clone of the
2062  // candidateSliceOp whereas in the scf.forall case this is created from the
2063  // operands of tensor.parallel_insert_slice.
2064  tensor::InsertSliceOp clonedInsertSliceOp;
2065  if (auto sliceOp =
2066  dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2067  auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2068  rewriter.setInsertionPoint(newForallOp.getTerminator());
2069  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
2070  loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2071  sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2072  } else {
2073  rewriter.setInsertionPoint(candidateSliceOp);
2074  clonedInsertSliceOp =
2075  cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
2076  }
2077 
2078  // 5.a. Clone consumer op.
2079  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2080 
2081  // 5.b. Replace all uses of the loop result with the result of the cloned
2082  // tensor.insert_slice.
2083  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2084  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
2085  operandToReplace.set(clonedInsertSliceOp.getResult());
2086  });
2087 
2088  // 6. Perform tiling of the cloned consumer and replace the operand at
2089  // `operandNumber` with the source of the cloned tensor.insert_slice op.
2090  auto ossSliceOp =
2091  cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
2092  FailureOr<TilingResult> tileAndFuseResult =
2094  rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2095  if (failed(tileAndFuseResult)) {
2096  return failure();
2097  }
2098  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2099  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
2100  clonedInsertSliceOp.getSource());
2101 
2102  // 7. Reconstruct [nested] loop with new inits.
2103  YieldTiledValuesFn newYieldValuesFn =
2104  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2105  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2107  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2108  OpBuilder::InsertionGuard g(innerRewriter);
2109  // 8. Set inner insertPoint right before tiled consumer op.
2110  innerRewriter.setInsertionPoint(tiledConsumerOp);
2111 
2112  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
2113  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
2114  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
2115 
2116  // 9. Check all insert stride is 1.
2117  if (llvm::any_of(strides, [](OpFoldResult stride) {
2118  return !isConstantIntValue(stride, 1);
2119  })) {
2120  return rewriter.notifyMatchFailure(
2121  candidateSliceOp, "containingOp's result yield with stride");
2122  }
2123 
2124  // 10. Try to get iter domain position from input position. Use
2125  // clonedConsumerOp instead of tiledConsumerOp, because the iteration
2126  // domain may require index computation based on the result size. The
2127  // sizes and offsets should be the same either way, but using
2128  // tiledConsumerOp could lead to some chained unnecessary extra index
2129  // computation.
2130  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2131  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2132  rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2133  iterDomainSizes))) {
2134  return rewriter.notifyMatchFailure(
2135  clonedConsumerOp,
2136  "can't get iter domain position from input position");
2137  }
2138 
2139  // 11. Try to fetch the offset and size for all results of the cloned
2140  // consumer. This would then be used to form the corresponding
2141  // tensor.insert_slice/parallel_insert_slice later.
2142  unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2144  totalNumResultsOfConsumer);
2146  totalNumResultsOfConsumer);
2147  for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2148  if (failed(tiledConsumerOp.getResultTilePosition(
2149  rewriter, idx, iterDomainOffsets, iterDomainSizes,
2150  resultOffsets[idx], resultSizes[idx]))) {
2151  return rewriter.notifyMatchFailure(
2152  tiledConsumerOp,
2153  "can't get result domain position from iter domain position");
2154  }
2155  }
2156 
2157  // 12. Create `extract_slice` for `iter_args` for DPS operation if
2158  // necessary.
2159  if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2160  tiledConsumerOp.getOperation())) {
2161  rewriter.setInsertionPoint(tiledDestStyleOp);
2162  for (const auto &&[index, newRegionArg] :
2163  llvm::enumerate(newRegionIterArgs)) {
2164  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2165  loc, newRegionArg, resultOffsets[index], resultSizes[index],
2166  SmallVector<OpFoldResult>(resultOffsets[index].size(),
2167  rewriter.getIndexAttr(1)));
2168  // Make a copy of index to avoid a capturing structured binding, which
2169  // is a C++20 extension.
2170  auto dstNumber = index;
2171  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2172  tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2173  });
2174  }
2175  }
2176 
2177  // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2178  // caller.
2179  Block *block = rewriter.getInsertionPoint()->getBlock();
2180  rewriter.setInsertionPoint(block->getTerminator());
2181  for (const auto &&[index, result] :
2182  llvm::enumerate(tiledConsumerOp->getResults())) {
2183  tiledResult.push_back(result);
2184  tiledOffset.emplace_back(resultOffsets[index]);
2185  tiledSizes.emplace_back(resultSizes[index]);
2186  }
2187  return success();
2188  };
2189  // 14. Add new inits to [nested] loops.
2190  if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
2191  newYieldValuesFn))) {
2192  return rewriter.notifyMatchFailure(tiledConsumerOp,
2193  "unable to add new inits to nest loop");
2194  }
2195 
2196  // 15. Replace the result of scf loop and consumer op with new loop's
2197  // results.
2198 
2199  for (auto &&[oldResult, newResult] :
2200  llvm::zip(consumerOp->getResults(),
2201  loops.front()->getResults().take_back(newInits.size()))) {
2202  rewriter.replaceAllUsesWith(oldResult, newResult);
2203  }
2204 
2205  // 16. Need to erase the old scf loop and the cloned consumer op.
2206  rewriter.eraseOp(clonedConsumerOp);
2207 
2209  consumerOpOperand,
2210  &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2211  tileAndFuseResult->tiledOps};
2212 }
2213 
2214 //===----------------------------------------------------------------------===//
2215 // lowerToLoopsUsingSCFForOp implementation.
2216 //===----------------------------------------------------------------------===//
2217 
2218 FailureOr<SmallVector<scf::ForOp>>
2220  TilingInterface op) {
2221  // TODO: Handle cases where the op has results if needed.
2222  if (op->getNumResults() > 0) {
2223  return rewriter.notifyMatchFailure(
2224  op, "unable to lower to loops operations with return values");
2225  }
2226 
2227  SmallVector<Range> domain = op.getIterationDomain(rewriter);
2228  SmallVector<Value> ivs;
2230  Location loc = op.getLoc();
2231  for (auto loopRange : domain) {
2232  Value offsetVal =
2233  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2234  Value sizeVal =
2235  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2236  Value strideVal =
2237  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2238  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2239  strideVal, ValueRange{});
2240  loops.push_back(loop);
2241  ivs.push_back(loop.getInductionVar());
2242  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2243  }
2244  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2245  return failure();
2246  }
2247  return loops;
2248 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
static bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check that the loop is perfectly nested.
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specifed in options.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
An utility to get the first user of the given loopOp.
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ValueRange partialResults, const scf::SCFTilingOptions &options)
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, const scf::SCFTilingOptions &options)
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:968
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:324
This class allows control over how the GreedyPatternRewriteDriver works.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class represents a saved insertion point.
Definition: Builders.h:324
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:442
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:549
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:313
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:433
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:317
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:409
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:243
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:433
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:442
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:445
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:602
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1335
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1328
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1224
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:79
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:114
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterf...
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 of a ConstantIndexOp with attribute with value 0.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Container for the result of merge operation of tiling.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:282
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:421
Container for result values of tiling.
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Control function to check if a slice needs to be fused or not, The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
ReductionTilingStrategy
Specify how reduction dimensions should be tiled.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.