//===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the tiling using TilingInterface.
//
//===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
32 
33 #define DEBUG_TYPE "tile-using-interface"
34 
35 using namespace mlir;
36 
39  assert(!tileSizeComputationFunction && "tile sizes already set");
40  auto tileSizes = llvm::to_vector(ts);
41  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
42  return tileSizes;
43  };
44  return *this;
45 }
46 
49  assert(!numThreadsComputationFunction && "num tiles already set");
50  auto numThreads = llvm::to_vector(nt);
51  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
52  return numThreads;
53  };
54  return *this;
55 }
56 
57 /// Helper method to adjust the interchange vector to match the iteration
58 /// domain.
61  size_t iterationDomainSize) {
62  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
63  if (filledVector.size() < iterationDomainSize) {
64  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
65  filledVector.append(range.begin(), range.end());
66  }
67  if (filledVector.size() > iterationDomainSize)
68  filledVector.resize(iterationDomainSize);
69  return filledVector;
70 }
71 
//===----------------------------------------------------------------------===//
// tileUsingSCF implementation.
//===----------------------------------------------------------------------===//
75 
76 /// Verify the tile size options are set in a consistent manner.
77 static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc,
79  // Specifying number of threads is only supported on `scf.forall` op.
80  if (options.numThreadsComputationFunction &&
82  return rewriter.notifyMatchFailure(
83  loc, "number of threads can only by specified when loop type is "
84  "set to use `scf.forall`");
85  }
86 
87  // If specified, check that the interchange vector is a permutation.
88  if (!options.interchangeVector.empty()) {
89  if (!isPermutationVector(options.interchangeVector)) {
90  return rewriter.notifyMatchFailure(
91  loc, "invalid interchange vector, not a permutation of the entire "
92  "iteration space");
93  }
94  }
95  return success();
96 }
97 
98 /// Method to instantiate the tile sizes and/or number of threads specified
99 /// by the user.
100 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
101 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
102  ArrayRef<Range> iterationDomain,
104  OpFoldResult zero = rewriter.getIndexAttr(0);
105  SmallVector<OpFoldResult> tileSizes, numThreads;
106  size_t numLoops = iterationDomain.size();
107 
108  // Check whether the number of tiles to use is specified.
109  if (options.numThreadsComputationFunction) {
110  numThreads = options.numThreadsComputationFunction(rewriter, op);
111  numThreads.resize(numLoops, zero);
112 
113  // If the number of tiles is also specified, use that.
114  if (options.tileSizeComputationFunction) {
115  tileSizes = options.tileSizeComputationFunction(rewriter, op);
116  tileSizes.resize(numLoops, zero);
117  return {tileSizes, numThreads};
118  }
119 
120  // Compute the tile sizes from the iteration domain and number
121  // of tiles as follows
122  // - niters = ceilDiv(ub - lb, step)
123  // - tileSize = ceilDiv(niters, numThreads)
124  AffineExpr s0, s1, s2;
125  bindSymbols(rewriter.getContext(), s0, s1, s2);
126  // TODO: The step here is assumed to be 1.
127  AffineExpr numItersExpr = (s1 - s0);
128  AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
129  tileSizes.resize(numLoops, zero);
130  for (auto [index, range, nt] :
131  llvm::enumerate(iterationDomain, numThreads)) {
132  if (isZeroInteger(nt))
133  continue;
134 
135  tileSizes[index] = affine::makeComposedFoldedAffineApply(
136  rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
137  }
138  tileSizes.resize(numLoops, zero);
139  return {tileSizes, numThreads};
140  }
141 
142  // Enforce the convention that "tiling by zero"
143  // skips tiling a particular dimension. This convention is significantly
144  // simpler to handle instead of adjusting affine maps to account for missing
145  // dimensions.
146  assert(options.tileSizeComputationFunction &&
147  "expected tile sizes to be specified");
148  tileSizes = options.tileSizeComputationFunction(rewriter, op);
149  tileSizes.resize(numLoops, zero);
150 
151  return {tileSizes, numThreads};
152 }
153 
154 /// Checks if any of the tiled loops are not parallel.
155 static LogicalResult checkTileSizes(TilingInterface op,
157  ReductionTilingStrategy reductionStrategy,
158  ArrayRef<OpFoldResult> tileSizes,
159  ArrayRef<OpFoldResult> numThreads) {
160  auto iterators = op.getLoopIteratorTypes();
161  assert(iterators.size() == tileSizes.size() &&
162  "expected as many tile size values as number of loops");
163  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
164  "when specified, expected number of threads to use for each loop");
165 
166  bool isParallelTiling = false;
167  for (auto [index, iterator, tileSize] :
168  llvm::enumerate(iterators, tileSizes)) {
169  if (!isConstantIntValue(tileSize, 0)) {
170  isParallelTiling |= iterator == utils::IteratorType::parallel;
171  }
172 
174  reductionStrategy == ReductionTilingStrategy::FullReduction) {
175  // If num threads is specified, check that it is greater than one only for
176  // parallel dimensions.
177  if (!numThreads.empty()) {
178  if (std::optional<int64_t> constNumThreads =
179  getConstantIntValue(numThreads[index])) {
180  if (constNumThreads.value() > 1 &&
181  iterator != utils::IteratorType::parallel) {
182  op.emitWarning() << "tiling is not thread safe at axis #" << index;
183  }
184  }
185  continue;
186  }
187 
188  if (std::optional<int64_t> constTileSize =
189  getConstantIntValue(tileSize)) {
190  if (constTileSize.value() > 0 &&
191  iterator != utils::IteratorType::parallel) {
192  op.emitWarning() << "tiling is not thread safe at axis #" << index;
193  }
194  }
195  }
196  }
197 
198  if (reductionStrategy != ReductionTilingStrategy::FullReduction) {
199  if (isParallelTiling) {
200  return op->emitOpError("tiling parallel dimensions is not supported with "
201  "partial reduction tiling strategies");
202  }
203  }
204  return success();
205 }
206 
207 /// Get the reduction dims that are tiled. This accounts for reduction dims
208 /// that are specified as tiled, but the tile size is 0.
209 static SetVector<unsigned>
212  SetVector<unsigned> reductionDims;
213  for (auto dim : options.reductionDims) {
214  if (isConstantIntValue(tileSizes[dim], 0))
215  continue;
216  reductionDims.insert(dim);
217  }
218  return reductionDims;
219 }
220 
221 /// Check if `stride` evenly divides the trip count `size - offset`.
222 static bool tileDividesIterationDomain(Range loopRange) {
223  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
224  if (!offsetAsInt)
225  return false;
226  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
227  if (!sizeAsInt)
228  return false;
229  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
230  if (!strideAsInt)
231  return false;
232  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
233 }
234 
235 /// Returns the bounded tile size given the current `offset`, `loopRange` and
236 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
238  Range loopRange, OpFoldResult offset,
239  OpFoldResult tileSize) {
240  std::optional<int64_t> ts = getConstantIntValue(tileSize);
241  if (ts && ts.value() == 1)
242  return tileSize;
243 
245  Range{loopRange.offset, loopRange.size, tileSize}))
246  return tileSize;
247 
248  // The tile size to use (to avoid out of bounds access) is minimum of
249  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
250  // loop.
251  AffineExpr s0, s1, d0;
252  bindDims(b.getContext(), d0);
253  bindSymbols(b.getContext(), s0, s1);
254  AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
255  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
257  b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
258 }
259 
260 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
261 /// than `iterationSize`.
263  OpFoldResult numThreads,
264  OpFoldResult iterationSize) {
265  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
266  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
267  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
268  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
269  return false;
270  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
271 }
272 
273 /// Compute the `OpFoldResult`s that represents the multi-dimensional
274 /// `offset`s and `size`s of the tile of the iteration space that the
275 /// innermost loop body of the generated tiled loops corresponds to.
276 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
278  ReductionTilingStrategy strategy, ValueRange ivs,
279  ArrayRef<Range> iterationDomain,
280  ArrayRef<OpFoldResult> tileSizes,
281  ArrayRef<OpFoldResult> numThreads,
282  const llvm::SetVector<unsigned> &reductionDims) {
283  SmallVector<OpFoldResult> offsets, sizes;
284  int materializedLoopNum = 0;
285 
286  if (!numThreads.empty()) {
287  AffineExpr d0, d1, s0, s1;
288  AffineExpr offsetExpr, residualTileSizeExpr;
289  bindDims(rewriter.getContext(), d0, d1);
290  bindSymbols(rewriter.getContext(), s0, s1);
291  offsetExpr = d0 + d1 * s0;
292  residualTileSizeExpr = s1 - (d0 + d1 * s0);
293 
294  for (auto [index, nt, tileSize, loopRange] :
295  llvm::enumerate(numThreads, tileSizes, iterationDomain)) {
296 
297  // Non-tiled cases, set the offset and size to the
298  // `loopRange.offset/size`.
299  if (isZeroInteger(nt)) {
300  offsets.push_back(loopRange.offset);
301  sizes.push_back(loopRange.size);
302  continue;
303  }
304 
305  Value iv = ivs[materializedLoopNum++];
307  rewriter, loc, offsetExpr,
308  ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
310  rewriter, loc, residualTileSizeExpr,
311  {loopRange.offset, nt, tileSize, loopRange.size});
312 
313  OpFoldResult size = tileSize;
314  if (!isZeroInteger(residualTileSize)) {
315  OpFoldResult sizeMinusOffsetPerThread =
316  affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
317  {offset, loopRange.size});
319  rewriter, loc,
321  {sizeMinusOffsetPerThread, tileSize});
322  }
323 
324  // Consider the case where the original loop was `[0, 100)`.
325  // If number of threads are `7`, the tile size would be computed as
326  // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
327  // - `offset = 0 + 6 * 15 = 105`
328  // - `tileSize = min(15, 100 - 105) = -5`
329  // To avoid negative tile sizes, we need to do a further
330  // `nonNegativeTileSize = affine.max(0, tileSize)`.
331  // This `max` can be avoided if
332  // `offset + tileSize * (numThreads - 1) < (ub - lb)`
333  if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
334  AffineMap maxMap =
337  rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
338  }
339 
340  offsets.push_back(offset);
341  sizes.push_back(size);
342  }
343  return {offsets, sizes};
344  } else {
345  for (auto [tileSize, loopRange] :
346  llvm::zip_equal(tileSizes, iterationDomain)) {
347 
348  // Non-tiled cases, set the offset and size to the
349  // `loopRange.offset/size`.
350  if (isZeroInteger(tileSize)) {
351  offsets.push_back(loopRange.offset);
352  sizes.push_back(loopRange.size);
353  continue;
354  }
355 
356  Value iv = ivs[materializedLoopNum++];
357  OpFoldResult offset = getAsOpFoldResult(iv);
358  offsets.push_back(offset);
359  OpFoldResult size =
360  getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
361  sizes.push_back(size);
362  }
363  return {offsets, sizes};
364  }
365 }
366 
367 /// Function to return the bounds of the loops to be generated.
368 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
371  ArrayRef<OpFoldResult> tileSizes) {
372  SmallVector<OpFoldResult> lbs, ubs, steps;
373  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
374  // No loop if the tile size is 0.
375  if (isZeroInteger(tileSize))
376  continue;
377  lbs.push_back(loopRange.offset);
378  ubs.push_back(loopRange.size);
379  steps.push_back(tileSize);
380  }
381  return {lbs, ubs, steps};
382 }
383 
384 /// A function that allows returning additional yielded values during
385 /// `yieldTiledValuesAndReplace`.
386 /// - `ivs` induction variable for the loop.
387 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
388 /// - `tiledValues` the tiled values to return. Must be of same size as
389 /// `newbbArgs`, each element of this array is inserted into the corresponding
390 /// element in `newbbArgs`.
391 /// - `resultOffsets` is of the same size as `tiledValues` and represents
392 /// the offsets to use when inserting corresponding element from `tiledValues`
393 /// into the element from `newBbArgs`.
394 /// - `resultSizes` is of the same size as `tiledValues` and represents
395 /// the size of the corresponding element from `tiledValues` inserted into
396 /// the element from `newBbArgs`.
397 /// In case the method needs to return `failure()` the method is expected
398 /// to clean up any inserted operations.
399 using YieldTiledValuesFn = std::function<LogicalResult(
400  RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
401  SmallVector<Value> &tiledValues,
402  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
403  SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
404 
405 /// Clones the operation and updates the destination if the operation
406 /// implements the `DestinationStyleOpInterface`.
408  Operation *op,
409  ValueRange newDestArgs) {
410  Operation *clonedOp = rewriter.clone(*op);
411  if (newDestArgs.empty())
412  return clonedOp;
413  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
414  destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
415  return clonedOp;
416 }
417 
418 /// Generate the tile-loop nest using `scf.for` operation.
419 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
420 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
421 /// - `destinationTensors` are the init values to use for the outer most loop.
422 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
423 /// most
424 /// loop.
425 /// - `loops` is an in-out parameter into which the generated loops are
426 /// populated.
427 static LogicalResult generateLoopNestUsingForOp(
428  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
429  ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
430  YieldTiledValuesFn yieldTiledValuesFn,
432  assert(!loopRanges.empty() && "unexpected empty loop ranges");
433  assert(loopRanges.size() == tileSizes.size() &&
434  "expected as many tile sizes as loop ranges");
435  OpBuilder::InsertionGuard guard(rewriter);
436 
437  SmallVector<OpFoldResult> lbs, ubs, steps;
438  std::tie(lbs, ubs, steps) =
439  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
440  SmallVector<Value> lbVals =
441  getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
442  SmallVector<Value> ubVals =
443  getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
444  SmallVector<Value> stepVals =
445  getValueOrCreateConstantIndexOp(rewriter, loc, steps);
446 
447  SmallVector<Value> ivs;
448  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
449  auto loop =
450  scf::ForOp::create(rewriter, loc, lb, ub, step, destinationTensors,
451  [](OpBuilder &bodyBuilder, Location bodyLoc,
452  Value iv, ValueRange /*iterArgs*/) {});
453  loops.push_back(loop);
454  ivs.push_back(loop.getInductionVar());
455  rewriter.setInsertionPointToEnd(loop.getBody());
456  destinationTensors = loop.getRegionIterArgs();
457  }
458 
459  SmallVector<Value> tiledResults;
460  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
461  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
462  tiledResults, resultOffsets, resultSizes))) {
463  return rewriter.notifyMatchFailure(
464  loc, "failed to generate inner tile loop body");
465  }
466  if (loops.empty())
467  return success();
468 
469  assert(tiledResults.size() == destinationTensors.size() &&
470  "Number of results of body should be equal to number of iter args");
471 
472  // 6. Yield all the results of the tiled operation.
473  SmallVector<Value> yieldedValues;
474  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
475  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
476  resultSizes)) {
477  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
478  rewriter.getIndexAttr(1));
479  auto insertSlice = tensor::InsertSliceOp::create(
480  rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize,
481  resultStride);
482  yieldedValues.push_back(insertSlice);
483  }
484  scf::YieldOp::create(rewriter, loc, yieldedValues);
485 
486  // Add the scf.yield operations for all the outer loops.
487  for (auto [outerLoop, innerLoop] :
488  llvm::zip_equal(MutableArrayRef(loops).drop_back(),
489  MutableArrayRef(loops).drop_front())) {
490  rewriter.setInsertionPointToEnd(
491  cast<scf::ForOp>(outerLoop.getOperation()).getBody());
492  scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
493  }
494  return success();
495 }
496 
497 /// Generate the tile-loop nest using `scf.forall` operation.
498 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
499 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
500 /// - `destinationTensors` are the init values to use for the outer most loop.
501 /// - `mappingVector` is the mapping attributes to use for loop construction.
502 /// Can be empty.
503 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
504 /// most
505 /// loop.
506 /// - `loops` is an in-out parameter into which the generated loops are
507 /// populated.
508 static LogicalResult generateLoopNestUsingForallOp(
509  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
510  ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
511  ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
513  assert(!loopRanges.empty() && "unexpected empty loop ranges");
514  assert(loopRanges.size() == tileSizes.size() &&
515  "expected as many tile sizes as loop ranges");
516  OpBuilder::InsertionGuard guard(rewriter);
517 
518  std::optional<ArrayAttr> mappingAttr;
519  if (!mappingVector.empty())
520  mappingAttr = rewriter.getArrayAttr(mappingVector);
521 
522  scf::ForallOp forallOp;
523  bool useNumThreads = !numThreads.empty();
524 
525  if (useNumThreads) {
526  // Prune the zero numthreads.
527  SmallVector<OpFoldResult> nonZeroNumThreads;
528  for (auto nt : numThreads) {
529  if (isZeroInteger(nt))
530  continue;
531  nonZeroNumThreads.push_back(nt);
532  }
533  forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads,
534  destinationTensors, mappingAttr);
535  } else {
536  SmallVector<OpFoldResult> lbs, ubs, steps;
537  std::tie(lbs, ubs, steps) =
538  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
539  forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
540  destinationTensors, mappingAttr);
541  }
542  loops.push_back(forallOp);
543 
544  rewriter.setInsertionPoint(forallOp.getTerminator());
545  destinationTensors = forallOp.getRegionOutArgs();
546 
547  SmallVector<Value> tiledResults;
548  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
549  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
550  destinationTensors, tiledResults, resultOffsets,
551  resultSizes)))
552  return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
553 
554  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
555  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
556  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
557  resultSizes)) {
558  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
559  rewriter.getIndexAttr(1));
560 
561  tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue,
562  destinationTensor, resultOffset,
563  resultSize, resultStride);
564  }
565  return success();
566 }
567 
568 /// Generate the tile-loop nest using the loop construct specifed in `options`.
569 /// - `options`: Tiling options specified.
570 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
571 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
572 /// - `destinationTensors` are the init values to use for the outer most loop.
573 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
574 /// most
575 /// loop.
576 /// - `loops` is an in-out parameter into which the generated loops are
577 /// populated.
578 static LogicalResult generateLoopNest(
579  RewriterBase &rewriter, Location loc,
581  ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
582  ValueRange destinationTensors, ArrayRef<Attribute> mappingVector,
584  // If the tile sizes are all zero, no loops are generated. Just call the
585  // callback function to handle untiled case.
586  if (llvm::all_of(tileSizes, isZeroInteger)) {
587  SmallVector<Value> tiledResults;
588  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
589  return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
590  tiledResults, resultOffsets, resultSizes);
591  }
592  if (loopType == scf::SCFTilingOptions::LoopType::ForOp) {
593  return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
594  destinationTensors, tiledBodyFn, loops);
595  }
598  rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector,
599  destinationTensors, tiledBodyFn, loops);
600  }
601  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
602 }
603 
604 static FailureOr<SmallVector<Value>> createInitialTensorsForTiling(
605  RewriterBase &rewriter, TilingInterface op,
606  ReductionTilingStrategy reductionStrategy, ArrayRef<Range> iterationDomain,
607  ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
608  const SetVector<unsigned> &reductionDims) {
609  SmallVector<Value> initTensors;
610  Location loc = op->getLoc();
611  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
612  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
613  return failure();
614  return initTensors;
615  }
616 
617  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
618  if (!redOp) {
619  return op->emitOpError(
620  "PartialReductionOuterReduction tiling strategy is only supported for "
621  "operations implementing PartialReductionOpInterface");
622  }
623  SmallVector<OpFoldResult> sizes(iterationDomain.size());
624  AffineExpr s0, s1, s2;
625  bindSymbols(rewriter.getContext(), s0, s1, s2);
626  AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
627  AffineExpr divExpr = s0.ceilDiv(s1);
628  for (auto [index, domain, tileSize] :
629  llvm::enumerate(iterationDomain, tileSizes)) {
630  if (!numThreads.empty()) {
631  // Untiled case.
632  if (isConstantIntValue(numThreads[index], 0)) {
634  rewriter, op.getLoc(), sizeExpr,
635  {domain.size, domain.offset, domain.stride});
636  continue;
637  }
638  sizes[index] = numThreads[index];
639  continue;
640  }
641 
642  // Non reduction dimensions/non-tiled dimensions.
643  if (!reductionDims.contains(index) || isConstantIntValue(tileSize, 0)) {
645  rewriter, op.getLoc(), sizeExpr,
646  {domain.size, domain.offset, domain.stride});
647  continue;
648  }
649 
650  if (reductionStrategy ==
652  sizes[index] = tileSize;
653  continue;
654  }
655 
656  assert(reductionStrategy ==
659  rewriter, op.getLoc(), sizeExpr,
660  {domain.size, domain.offset, domain.stride});
662  rewriter, op.getLoc(), divExpr, {normalizedRange, tileSize});
663  }
664  return redOp.generateInitialTensorForPartialReduction(rewriter, loc, sizes,
665  reductionDims);
666 }
667 
668 /// For the case of `ReductionTilingStrategy::PartialReductionOuterParallel`
669 /// the `PartialReductionOpInterface` methods need the index of the parallel
670 /// split reduction being executed.
673  ReductionTilingStrategy reductionStrategy, ValueRange ivs,
674  ArrayRef<OpFoldResult> numThreads,
675  ArrayRef<OpFoldResult> tileSizes,
676  const SetVector<unsigned> &reductionDims) {
677  SmallVector<OpFoldResult> splitReductionIvs;
678  splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
679  AffineExpr s0, s1;
680  bindSymbols(rewriter.getContext(), s0, s1);
681  AffineExpr divExpr = s0.floorDiv(s1);
682  int ivIndex = 0;
683  if (reductionStrategy ==
685  for (auto [index, reductionDim] : llvm::enumerate(reductionDims)) {
686  if (!numThreads.empty()) {
687  splitReductionIvs[index] = ivs[ivIndex++];
688  continue;
689  }
690  splitReductionIvs[index] = affine::makeComposedFoldedAffineApply(
691  rewriter, loc, divExpr,
692  ArrayRef<OpFoldResult>{ivs[ivIndex++], tileSizes[reductionDim]});
693  }
694  }
695  return splitReductionIvs;
696 }
697 
698 static FailureOr<TilingResult>
699 getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
700  ReductionTilingStrategy reductionStrategy,
701  ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
703  ArrayRef<OpFoldResult> numThreads,
704  ArrayRef<OpFoldResult> tileSizes,
705  const SetVector<unsigned> &reductionDims) {
706  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
707  return op.getTiledImplementation(rewriter, offsets, sizes);
708  }
709 
710  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
711  if (!redOp) {
712  return rewriter.notifyMatchFailure(
713  op, "PartialReductionOuterReduction tiling strategy is only "
714  "supported for operations "
715  "implementing PartialReductionOpInterface");
716  }
717 
718  SmallVector<OpFoldResult> splitReductionIvs =
719  getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
720  numThreads, tileSizes, reductionDims);
721  return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
722  regionIterArg, offsets, sizes,
723  reductionDims, splitReductionIvs);
724 }
725 
726 static LogicalResult getResultTilePosition(
727  RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy,
728  int64_t index, Value tiledResult, TilingInterface op,
730  ValueRange ivs, ArrayRef<OpFoldResult> numThreads,
731  ArrayRef<OpFoldResult> tileSizes, const SetVector<unsigned> &reductionDims,
732  SmallVector<OpFoldResult> &resultOffset,
733  SmallVector<OpFoldResult> &resultSize) {
734 
735  if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
736  return op.getResultTilePosition(rewriter, index, offsets, sizes,
737  resultOffset, resultSize);
738  }
739  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
740  if (!redOp) {
741  return rewriter.notifyMatchFailure(
742  op, "PartialReductionOuterReduction tiling strategy is only supported"
743  "for operations implementing PartialReductionOpInterface");
744  }
745  SmallVector<OpFoldResult> splitReductionIvs =
746  getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
747  numThreads, tileSizes, reductionDims);
748  return redOp.getPartialResultTilePosition(
749  rewriter, index, reductionStrategy, offsets, sizes, reductionDims,
750  splitReductionIvs, resultOffset, resultSize);
751 }
752 
753 static FailureOr<MergeResult>
754 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
755  ReductionTilingStrategy reductionStrategy,
756  const SetVector<unsigned> &reductionDims,
757  ValueRange partialResults) {
758  assert(reductionStrategy != ReductionTilingStrategy::FullReduction &&
759  "expected merge to be called for only partial reduction cases");
760 
761  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
762  if (!redOp) {
763  return rewriter.notifyMatchFailure(
764  op, "PartialReductionOuterReduction tiling strategy is only "
765  "supported for operations "
766  "implementing PartialReductionOpInterface");
767  }
768  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
769  reductionDims);
770 }
771 
772 /// Append the specified additional `newInitOperands` operands to the
773 /// loops existing `init` operands (or similar), and replace `loopOp` with
774 /// the new loop that has the additional init operands. The loop body of
775 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
776 /// is called to get the new tiled values returned, and the offset
777 /// and sizes at which the tiled value is inserted into the
778 /// new region iter_args that correspond to the newly added init operands.
779 template <typename LoopType>
780 FailureOr<LoopLikeOpInterface>
781 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
782  ValueRange newInitOperands,
783  YieldTiledValuesFn yieldTiledValuesFn) {
784  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
785 }
786 
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
///
/// Recreates `loopOp` with `newInitOperands` appended to its iter_args,
/// invokes `yieldTiledValuesFn` to materialize the values (and their
/// offsets/sizes) to yield for the new iter_args, inserts those values into
/// the new iter_args via `tensor.insert_slice`, and replaces the original
/// loop. Returns the new loop on success.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
    scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);

  // Create a new scf.for whose inits are the original inits plus
  // `newInitOperands`. The body builder is a no-op; the original body is
  // merged in below.
  auto inits = llvm::to_vector(loopOp.getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = scf::ForOp::create(
      rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
      loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
      loopOp.getUnsignedCmp());

  // Move the loop body to the new op. Only the leading block arguments
  // (induction variable + original iter_args) are remapped; the appended
  // iter_args are new.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(yieldOp);

  // Ask the callback for the tiled values to yield through the appended
  // iter_args, along with the offsets/sizes at which each value must be
  // inserted back into its destination.
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  ValueRange newRegionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
                                newRegionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // Clean up the partially-built loop before reporting failure.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
  }

  // Insert each tiled value into its destination iter_arg (unit strides) and
  // yield the result alongside the original yielded values.
  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
       llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    Value insert = tensor::InsertSliceOp::create(
        rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset,
        resultSize, resultStride);
    newYieldValues.push_back(insert);
  }

  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
  // Replace only the original results; the appended results belong to the
  // caller via the returned loop.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
841 
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`.
///
/// Recreates `loopOp` with `newInitOperands` appended to its shared outputs,
/// invokes `yieldTiledValuesFn` to materialize the tiled values (and their
/// offsets/sizes), writes them into the new outputs via
/// `tensor.parallel_insert_slice` in the `scf.forall.in_parallel` terminator,
/// and replaces the original loop. Returns the new loop on success.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
    scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);
  // Create a new scf.forall with the original shared outputs plus
  // `newInitOperands`; the body builder is a no-op and the original body is
  // merged in below.
  auto inits = llvm::to_vector(loopOp.getOutputs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = scf::ForallOp::create(
      rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
      loopOp.getMixedStep(), inits, loopOp.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});

  // Move the region of the current block to the newly created op. Only the
  // leading block arguments (induction variables + original outputs) are
  // remapped; the appended outputs are new.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(terminator);
  // Ask the callback for the tiled values corresponding to the appended
  // shared outputs.
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  ValueRange regionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
                                regionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // Clean up the partially-built loop before reporting failure.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp,
                                       "failed to get yielded tiled values");
  }

  // Update the terminator: write each tiled value into its shared output with
  // unit strides.
  rewriter.setInsertionPointToEnd(terminator.getBody());

  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
           tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(),
                                          tiledValue, iterArg, resultOffset,
                                          resultSize, resultStride);
  }

  // Replace only the original results; the appended results belong to the
  // caller via the returned loop.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
894 
895 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
896 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
897 /// supported loop type.
898 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
899  LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
900  ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
902  loopLikeOp.getOperation())
903  .Case<scf::ForOp, scf::ForallOp>(
904  [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
906  loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
907  })
908  .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
909  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
910  });
911 }
912 
913 /// Method to add new init values to a loop nest. Updates `loops` in-place
914 /// with new loops that use the `newInitValues`. The outer-loops are updated
915 /// to yield the new result values of the inner loop. For the innermost loop,
916 /// the call back `getNewYields` is invoked to get the additional values to
917 /// yield form the innermost loop.
918 static LogicalResult addInitOperandsToLoopNest(
920  ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
921  if (loops.empty())
922  return success();
923  OpBuilder::InsertionGuard g(rewriter);
924  rewriter.setInsertionPoint(loops.front());
925 
926  SmallVector<Value> ivs;
927  for (auto &loop : loops.drop_back()) {
928  rewriter.setInsertionPoint(loop);
929 
930  // if loops.size() > 1 we assume that scf.for is used for the loops.
931  auto forLoop = cast<scf::ForOp>(loop.getOperation());
932 
933  // Create a new loop with the new init values for this loop.
934  SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
935  newInits.append(newInitValues.begin(), newInitValues.end());
936  auto newLoop = scf::ForOp::create(
937  rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
938  forLoop.getUpperBound(), forLoop.getStep(), newInits,
939  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {},
940  forLoop.getUnsignedCmp());
941 
942  // Merge the body of the new loop with the body of the old loops.
943  SmallVector<Value> sourceBlockArgs;
944  sourceBlockArgs.push_back(newLoop.getInductionVar());
945  auto newRegionIterArgs = newLoop.getRegionIterArgs();
946  sourceBlockArgs.append(
947  newRegionIterArgs.begin(),
948  std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
949  rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
950  rewriter.replaceOp(
951  forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
952  loop = newLoop;
953  ivs.push_back(newLoop.getInductionVar());
954  newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
955  }
956 
957  // Update the loop body of the innermost loop to get new yield values.
958  LoopLikeOpInterface innerMostLoop = loops.back();
959  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
960  yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
961  getNewTiledYieldsFn);
962 
963  if (failed(newInnerMostLoop))
964  return innerMostLoop.emitOpError("failed to return additional yields");
965  loops.back() = newInnerMostLoop.value();
966 
967  // Make all other loops except the innermost loops yield the values returned
968  // by the inner loop.
969  for (auto [outerLoop, innerLoop] :
970  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
971  // Again assume that all the outer loops are scf.for operations.
972  auto outerForLoop = cast<scf::ForOp>(outerLoop);
973  auto outerLoopYield =
974  cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
975  SmallVector<Value> newYields =
976  llvm::to_vector(outerLoopYield.getOperands());
977  ValueRange additionalYields =
978  innerLoop->getResults().take_back(newInitValues.size());
979  newYields.append(additionalYields.begin(), additionalYields.end());
980  rewriter.setInsertionPoint(outerLoopYield);
981  rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
982  }
983  return success();
984 }
985 
986 /// Implementation of tiling transformation of `op` that implements the
987 /// `TilingInterface` using `scf.for` to iterate over the tiles.
988 FailureOr<scf::SCFTilingResult>
989 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
991  if (failed(verifyOptions(rewriter, op.getLoc(), options))) {
992  return failure();
993  }
994 
995  OpBuilder::InsertionGuard guard(rewriter);
996  rewriter.setInsertionPointAfter(op);
997 
998  // 1. Get the range of the loops that are represented by the operation.
999  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
1000 
1001  // 2. Materialize the tile sizes and/or number of threads;
1002  SmallVector<OpFoldResult> tileSizes, numThreads;
1003  std::tie(tileSizes, numThreads) =
1004  getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
1005 
1006  // Check if it is safe to tile. This is hold over from previous iterations
1007  // of tile to for-all. Consider dropping it.
1008  if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy,
1009  tileSizes, numThreads))) {
1010  return failure();
1011  }
1012 
1013  // Get the reduction dims
1014  SetVector<unsigned> reductionDims =
1015  getSanitizedReductionDims(tileSizes, options);
1016 
1017  // 3. If there is an interchange specified, permute the iteration domain and
1018  // the tile sizes.
1019  SmallVector<int64_t> interchangeVector;
1020  if (!options.interchangeVector.empty()) {
1021  interchangeVector = fillInterchangeVector(options.interchangeVector,
1022  iterationDomain.size());
1023  assert(isPermutationVector(interchangeVector) &&
1024  "expected interchange vector to be a permutation");
1025 
1026  applyPermutationToVector(iterationDomain, interchangeVector);
1027  applyPermutationToVector(tileSizes, interchangeVector);
1028  if (!numThreads.empty())
1029  applyPermutationToVector(numThreads, interchangeVector);
1030  }
1031 
1032  FailureOr<TilingResult> tilingResult;
1033  // 4. Define the lambda function used later to generate the body of the
1034  // innermost tiled loop.
1035  YieldTiledValuesFn innerYieldTiledValuesFn =
1036  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
1037  ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
1038  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
1040  -> LogicalResult {
1041  // 4a. Compute the `offsets` and `sizes` to use for tiling.
1042  SmallVector<OpFoldResult> offsets, sizes;
1043  std::tie(offsets, sizes) = getTileOffsetAndSizes(
1044  rewriter, loc, options.reductionStrategy, ivs, iterationDomain,
1045  tileSizes, numThreads, reductionDims);
1046 
1047  // 4b. If interchange was provided, apply inverse of the interchange
1048  // to get back the offsets/sizes in the order to be specified.
1049  if (!interchangeVector.empty()) {
1050  auto inversePermutation = invertPermutationVector(interchangeVector);
1053  }
1054 
1055  // 5. Generate the tiled implementation within the inner most loop.
1056 
1057  // 5a. Clone the operation within the loop body.
1058  auto clonedOp = cast<TilingInterface>(
1059  cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
1060 
1061  // 5b. Early return cloned op if tiling is not happening. We can not
1062  // return the original op because it could lead to `rewriter.replaceOp(op,
1063  // op->getResults())` and users would get crash.
1064  if (llvm::all_of(tileSizes, isZeroInteger)) {
1065  tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1066  tilingResult =
1067  TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
1068  /*generatedSlices=*/{}};
1069  return success();
1070  }
1071 
1072  // 5c. Tile the cloned operation.
1073  tilingResult = getTiledImplementation(
1074  rewriter, clonedOp, options.reductionStrategy, regionIterArgs, offsets,
1075  sizes, ivs, numThreads, tileSizes, reductionDims);
1076  if (failed(tilingResult)) {
1077  rewriter.eraseOp(clonedOp);
1078  return op.emitOpError("faild to tile operation");
1079  }
1080 
1081  // 5d. Delete the cloned operation.
1082  rewriter.eraseOp(clonedOp);
1083 
1084  // 5e. Compute the offsets at which the result values are to be inserted
1085  // back into its destinations.
1086  for (auto [index, tiledValue] :
1087  llvm::enumerate(tilingResult->tiledValues)) {
1088  tiledResults.push_back(tiledValue);
1089  SmallVector<OpFoldResult> resultOffset, resultSize;
1091  rewriter, options.reductionStrategy, index, tiledValue, op,
1092  offsets, sizes, ivs, numThreads, tileSizes, reductionDims,
1093  resultOffset, resultSize))) {
1094  for (auto op : tilingResult->tiledOps) {
1095  rewriter.eraseOp(op);
1096  }
1097  return rewriter.notifyMatchFailure(
1098  op, "failed to get slice of result produced");
1099  }
1100  resultOffsets.emplace_back(std::move(resultOffset));
1101  resultSizes.emplace_back(std::move(resultSize));
1102  }
1103 
1104  return success();
1105  };
1106 
1107  // 6. Find the destination tensors to use for the operation.
1108  FailureOr<SmallVector<Value>> maybeInits = createInitialTensorsForTiling(
1109  rewriter, op, options.reductionStrategy, iterationDomain, numThreads,
1110  tileSizes, reductionDims);
1111  if (failed(maybeInits)) {
1112  return rewriter.notifyMatchFailure(
1113  op, "unable to create initial tensors for tiling");
1114  }
1115  SmallVector<Value> &initTensors = maybeInits.value();
1116 
1117  // 7. Generate the tiled loops nest using the callback defined above.
1119  if (failed(generateLoopNest(rewriter, op.getLoc(), options.loopType,
1120  iterationDomain, tileSizes, numThreads,
1121  initTensors, options.mappingVector,
1122  innerYieldTiledValuesFn, loops)))
1123  return op.emitOpError("failed to generate tiling loops");
1124  assert(succeeded(tilingResult) &&
1125  "expected tiling result to be computed after loop generation");
1126 
1127  if (loops.empty()) {
1128  // If loops are empty, the tiled op is used as the replacement for the
1129  // untiled op.
1130  return scf::SCFTilingResult{tilingResult->tiledOps,
1131  initTensors,
1132  loops,
1133  tilingResult->tiledValues,
1134  tilingResult->generatedSlices,
1135  {}};
1136  }
1137 
1138  auto loopResults = llvm::map_to_vector(loops.front()->getResults(),
1139  [](OpResult r) -> Value { return r; });
1140 
1141  // For the full reduction case, there is nothing more to do.
1142  if (options.reductionStrategy == ReductionTilingStrategy::FullReduction) {
1143  return scf::SCFTilingResult{
1144  tilingResult->tiledOps, initTensors, loops, loopResults,
1145  tilingResult->generatedSlices, {}};
1146  }
1147 
1148  // The results of the loop needs to be merged.
1149  FailureOr<MergeResult> mergeResult = mergeTilingResults(
1150  rewriter, op, options.reductionStrategy, reductionDims, loopResults);
1151  if (failed(mergeResult)) {
1152  return rewriter.notifyMatchFailure(
1153  op, "Failed to merge partial results from tiling");
1154  }
1155  return scf::SCFTilingResult{tilingResult->tiledOps,
1156  initTensors,
1157  loops,
1158  mergeResult->replacements,
1159  tilingResult->generatedSlices,
1160  mergeResult->mergeOps};
1161 }
1162 
1163 FailureOr<scf::SCFTilingResult>
1165  PartialReductionOpInterface op,
1166  ArrayRef<OpFoldResult> tileSize) {
1169  options.setReductionTilingStrategy(
1171  options.setTileSizes(tileSize);
1172  SmallVector<unsigned> reductionDims;
1173  for (auto [index, iteratorType] : llvm::enumerate(op.getLoopIteratorTypes()))
1174  if (iteratorType == utils::IteratorType::reduction)
1175  reductionDims.push_back(index);
1176  options.setReductionDims(reductionDims);
1177  return tileUsingSCF(b, op, options);
1178 }
1179 
1180 //===----------------------------------------------------------------------===//
1181 // tileConsumerAndFuseProducersUsingSCF implementation.
1182 //===----------------------------------------------------------------------===//
1183 
1184 /// Return the untiled producer whose slice is used in a tiled consumer. The
1185 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1186 /// `iter_args` of the outer most that is encountered. Traversing the
1187 /// iter_args indicates that this is a destination operand of the consumer. If
1188 /// there was no loop traversal needed, the second value of the returned tuple
1189 /// is empty.
1190 static std::tuple<OpResult, std::optional<OpOperand *>>
1193  std::optional<OpOperand *> destinationIterArg;
1194  assert(!loops.empty() && "expected non empty loops container");
1195  auto loopIt = loops.rbegin();
1196  while (loopIt != loops.rend() && isa<BlockArgument>(source->get())) {
1197  auto iterArg = cast<BlockArgument>(source->get());
1198  auto loop = *loopIt;
1199  if (iterArg.getOwner()->getParentOp() != loop)
1200  break;
1201  source = loop.getTiedLoopInit(iterArg);
1202  loopIt++;
1203  }
1204  if (loopIt == loops.rend())
1205  destinationIterArg = source;
1206  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1207 }
1208 
1209 /// Implementation of fusing producer of a single slice by computing the
1210 /// slice of the producer in-place.
1211 std::optional<scf::SCFFuseProducerOfSliceResult>
1213  RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1215  // 1. Get the producer of the source (potentially walking through
1216  // `iter_args` of nested `scf.for`)
1217  auto [fusableProducer, destinationInitArg] =
1218  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1219  loops);
1220  if (!fusableProducer)
1221  return std::nullopt;
1222  unsigned resultNumber = fusableProducer.getResultNumber();
1223 
1224  OpBuilder::InsertionGuard g(rewriter);
1225  rewriter.setInsertionPoint(candidateSliceOp);
1226 
1227  // 2. Clone the fused producer
1228  // 2a. Compute the destination operands to use for the cloned operation.
1229  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1230  Operation *fusableProducerOp = fusableProducer.getOwner();
1231  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1233  rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1234  origDestinationTensors)))
1235  return std::nullopt;
1236 
1237  clonedOpDestinationTensors = origDestinationTensors;
1238  if (destinationInitArg &&
1239  isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1240  // 2b. If the producer is also destination style, then to maintain the
1241  // destination passing style, update the destination of the producer to be
1242  // the source of the slice.
1243  clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1244  }
1245  // 2c. Clone the fused producer.
1246  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1247  rewriter, fusableProducerOp, clonedOpDestinationTensors);
1248  // 2d. Update the source of the candidateSlice to be the cloned producer.
1249  // Easier to just clone the slice with different source since
1250  // replacements and DCE of cloned ops becomes easier
1251  SmallVector<Value> candidateSliceOpOperands =
1252  llvm::to_vector(candidateSliceOp->getOperands());
1253  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1254  tensor::ExtractSliceOp clonedCandidateSliceOp =
1255  mlir::clone(rewriter, candidateSliceOp,
1256  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1257 
1258  // 3. Generate the tiled implementation of the producer of the source
1259  FailureOr<TilingResult> tileAndFuseResult =
1261  rewriter, clonedCandidateSliceOp,
1262  clonedProducerOp->getResult(resultNumber));
1263  if (failed(tileAndFuseResult))
1264  return std::nullopt;
1265  // Note: Do not delete the candidateSliceOp, since its passed in from the
1266  // caller.
1267  rewriter.replaceAllUsesWith(candidateSliceOp,
1268  tileAndFuseResult->tiledValues[0]);
1269  rewriter.eraseOp(clonedCandidateSliceOp);
1270  rewriter.eraseOp(clonedProducerOp);
1271 
1272  // 3. If the slice is for a destination operand, for example,
1273  //
1274  // ```mlir
1275  // %0 = linalg.init
1276  // %1 = linalg.fill .. outs(%0 : )
1277  // %2 = scf.for .. iter_args(%arg0 = %1) {
1278  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1279  // %4 = tensor.extract_slice %arg1 [..]
1280  // .. = linalg.matmul .. outs(%4 : )
1281  // }
1282  // }
1283  // ```
1284  //
1285  // the IR is currently
1286  //
1287  // ```
1288  // %0 = linalg.init
1289  // %1 = linalg.fill
1290  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1291  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1292  // %4 = tensor.extract_slice %arg1[..]
1293  // %5 = linalg.fill .. outs(%4 : )
1294  // .. = linalg.matmul .. outs(%5 : )
1295  // }
1296  // }
1297  // ```
1298  //
1299  // The untiled `linalg.fill` is still used as the `init_value` since it
1300  // was originally a destination operand of the untiled `linalg.matmul`.
1301  // When fusing an operand that is a destination operand, the iter_arg of
1302  // the outer most loop should be changed to use the destination of the
1303  // fused operation. With this the IR will be.
1304  //
1305  // ```
1306  // %0 = linalg.init
1307  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1308  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
1309  // %3 = tensor.extract_slice %arg1[..]
1310  // %4 = linalg.fill .. outs(%3 : )
1311  // .. = linalg.matmul .. outs(%4 : )
1312  // }
1313  // }
1314  // ```
1315  if (destinationInitArg &&
1316  isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1317  loops.front()
1318  ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1319  .set(origDestinationTensors[resultNumber]);
1320  }
1322  fusableProducer, tileAndFuseResult->tiledValues[0],
1323  tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1324 }
1325 
1326 /// Reconstruct the fused producer from within the tiled-and-fused code.
1327 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1328  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1329  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1331  ArrayRef<unsigned> yieldResultNumber) {
1332  if (loops.empty())
1333  return success();
1334 
1335  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1336  *tiledOwner = fusedProducerInfo.tiledOps[0];
1337 
1338  Location loc = originalOwner->getLoc();
1339  // a. collect all init Value to be appended
1340  SmallVector<unsigned> initNumberList =
1341  yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1342  0, originalOwner->getNumResults()))
1343  : llvm::to_vector(yieldResultNumber);
1344  SmallVector<Value> initValueList;
1345  for (const auto &resultNumber : initNumberList) {
1346  FailureOr<Value> initValue = tensor::getOrCreateDestination(
1347  rewriter, loc, originalOwner->getResult(resultNumber));
1348  if (succeeded(initValue)) {
1349  initValueList.push_back(initValue.value());
1350  } else {
1351  return failure();
1352  }
1353  }
1354 
1355  SmallVector<Operation *> generatedSlices;
1356  YieldTiledValuesFn newYieldValuesFn =
1357  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1358  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1360  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1361  OpBuilder::InsertionGuard g(innerRewriter);
1362 
1363  // get sliceOp tile information
1364  SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1365  sliceSizes = sliceOp.getMixedSizes();
1366 
1367  // expect all strides of sliceOp being 1
1368  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
1369  return failure();
1370 
1371  unsigned sliceResultNumber =
1372  fusedProducerInfo.origProducer.getResultNumber();
1373 
1374  auto tilableOp = cast<TilingInterface>(originalOwner);
1375  // b. get iterDomain Offset and Sizes based on sliceOp tile
1376  SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1377  // skip tensor.pack/unpack/pad, which expects single opResult
1378  if (tilableOp->getNumResults() > 1 &&
1379  failed(tilableOp.getIterationDomainTileFromResultTile(
1380  rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1381  iterDomainOffset, iterDomainSizes))) {
1382  // In theory, it is unnecessary to raise an error here. Actually
1383  // although it fails to reconstruct the result tensor, it should not
1384  // broke current fusion anyway. The reason why we must return failure
1385  // currently is that the callback function `newYieldValuesFn` will be
1386  // called after new init operand(s) has already been appended. It will
1387  // take more refactoring to make sure the init operands are added
1388  // consistently in the future. For more details, please refer to:
1389  // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1390  return failure();
1391  }
1392 
1393  // c. calculate offsets and sizes info of all OpResults respectively based
1394  // on iteration Domain Tile
1395  SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1396  for (const auto &resultNumber : initNumberList) {
1397  if (resultNumber == sliceResultNumber) {
1398  offsetList.push_back(sliceOffset);
1399  sizesList.push_back(sliceSizes);
1400  } else {
1401  assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1402  // infer result tile according to the iteration domain tile
1403  SmallVector<OpFoldResult> offset, sizes;
1404  if (failed(tilableOp.getResultTilePosition(
1405  rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1406  offset, sizes))) {
1407  return failure();
1408  }
1409  offsetList.push_back(offset);
1410  sizesList.push_back(sizes);
1411  }
1412  }
1413 
1414  // d. create `extract_slice` for `iter_args` for DPS operation if
1415  // necessary
1416  if (auto tiledDestStyleOp =
1417  dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1418  rewriter.setInsertionPoint(tiledDestStyleOp);
1419  for (const auto &&[index, newRegionArg] :
1420  llvm::enumerate(newRegionIterArgs)) {
1421  auto destSlice = tensor::ExtractSliceOp::create(
1422  rewriter, loc, newRegionArg, offsetList[index], sizesList[index],
1423  SmallVector<OpFoldResult>(offsetList[index].size(),
1424  rewriter.getIndexAttr(1)));
1425  generatedSlices.push_back(destSlice);
1426  unsigned resultNumber = initNumberList[index];
1427  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1428  tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1429  });
1430  }
1431  }
1432 
1433  // e. prepare tiled offset and sizes for later `insert_slice` creation by
1434  // caller
1435  Block *block = rewriter.getInsertionPoint()->getBlock();
1436  rewriter.setInsertionPoint(block->getTerminator());
1437  for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1438  tiledResult.push_back(tiledOwner->getResult(resultNumber));
1439  tiledOffset.emplace_back(offsetList[index]);
1440  tiledSizes.emplace_back(sizesList[index]);
1441  }
1442  return success();
1443  };
1444 
1445  if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1446  newYieldValuesFn))) {
1447  return failure();
1448  }
1449  return generatedSlices;
1450 }
1451 
1452 namespace {
1453 
1454 //===----------------------------------------------------------------------===//
1455 // SliceTrackingListener
1456 //===----------------------------------------------------------------------===//
1457 
/// This class is a listener for tracking the insertion and removal of
/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
/// fusion algorithm to apply cleanup patterns in between fusion steps.
class SliceTrackingListener : public RewriterBase::Listener {
public:
  explicit SliceTrackingListener(
      std::optional<FrozenRewritePatternSet> patterns);
  SliceTrackingListener() = default;

  /// Adds the given list of operations to the worklist, and if present,
  /// applies the list of `patterns` to the newly added operations. This only
  /// processes the given operations and any newly inserted ones by the
  /// pattern set.
  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);

  /// Add to the new operation worklist if it is an extract_slice.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Shared helper for operation removal from the worklist.
  void removeOp(Operation *op);

  /// Remove the operation from the worklist.
  void notifyOperationErased(Operation *op) override;

  /// Remove the operation from the worklist.
  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;

  /// The worklist for this transformation keeps track of the slices to visit
  /// next for fusion. A deque allows FIFO processing while still supporting
  /// removal from the middle (see `removeOp`).
  std::deque<tensor::ExtractSliceOp> worklist;

private:
  /// Optional pattern set to apply when adding new operations to the
  /// worklist.
  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
};
1495 
1496 SliceTrackingListener::SliceTrackingListener(
1497  std::optional<FrozenRewritePatternSet> p) {
1498  patterns = std::move(p);
1499 }
1500 
1501 LogicalResult
1502 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1503  for (Operation *op : ops) {
1504  if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1505  worklist.push_back(slice);
1506  }
1507 
1508  if (!patterns)
1509  return success();
1510 
1511  return applyOpPatternsGreedily(
1512  ops, patterns.value(),
1513  GreedyRewriteConfig().setListener(this).setStrictness(
1515 }
1516 
1517 void SliceTrackingListener::notifyOperationInserted(
1518  Operation *op, OpBuilder::InsertPoint previous) {
1519  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1520  if (!slice)
1521  return;
1522  worklist.push_back(slice);
1523 }
1524 
1525 // Scan the worklist for the given op and remove it if present. The
1526 // expectation is for the worklist to be small and for removal to be
1527 // relatively rare.
1528 void SliceTrackingListener::removeOp(Operation *op) {
1529  if (!isa<tensor::ExtractSliceOp>(op))
1530  return;
1531  auto iter = worklist.begin();
1532  while (iter != worklist.end()) {
1533  if (*iter == op)
1534  break;
1535  iter++;
1536  }
1537  if (iter == worklist.end())
1538  return;
1539 
1540  worklist.erase(iter);
1541 }
1542 
/// Drop an erased op from the worklist (no-op for non-slice ops).
void SliceTrackingListener::notifyOperationErased(Operation *op) {
  removeOp(op);
}
1546 
/// Drop a replaced op from the worklist; the replacement values (if slices)
/// are picked up separately through `notifyOperationInserted`.
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
                                                    ValueRange replacement) {
  removeOp(op);
}
1551 
1552 //===----------------------------------------------------------------------===//
1553 // ReplacementListener
1554 //===----------------------------------------------------------------------===//
1555 
1556 /// Listener that tracks updates replacements for values which can be mutated.
1557 /// This listener runs on top of the existing listener for the rewriter,
1558 /// to make sure external users can still run listeners.
1559 class ReplacementListener : public RewriterBase::ForwardingListener {
1560 public:
1561  ReplacementListener(DenseMap<Value, Value> &replacements,
1562  OpBuilder::Listener *listener)
1563  : ForwardingListener(listener), replacements(replacements) {}
1564 
1565  void updateReplacementValues(ValueRange origValues,
1566  ValueRange replaceValues) {
1567  // This can probably be written better, but just iterates over the map
1568  // and the new replacements for now.
1569  for (auto &[key, val] : replacements) {
1570  for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1571  if (val == orig) {
1572  val = replace;
1573  }
1574  }
1575  }
1576  }
1577 
1578  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1579  ForwardingListener::notifyOperationReplaced(op, newOp);
1580  updateReplacementValues(op->getResults(), newOp->getResults());
1581  }
1582 
1583  void notifyOperationReplaced(Operation *op, ValueRange values) override {
1584  ForwardingListener::notifyOperationReplaced(op, values);
1585  updateReplacementValues(op->getResults(), values);
1586  }
1587 
1588 private:
1589  DenseMap<Value, Value> &replacements;
1590 };
1591 
1592 } // namespace
1593 
1594 /// Implementation of tile consumer and fuse producer greedily.
1595 FailureOr<scf::SCFTileAndFuseResult>
1597     RewriterBase &rewriter, TilingInterface consumer,
1599   // This transformation is only valid for ops that return values (i.e. not
1600   // valid to use with operations that have memref operands).
1601   if (!consumer->getNumResults()) {
1602     return rewriter.notifyMatchFailure(
1603         consumer, "invalid pattern for op with no results");
1604   }
1605 
1606   // 1. First tile the consumer.
1607   SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1608 
1609   FailureOr<scf::SCFTilingResult> tilingResult =
1610       tileUsingSCF(rewriter, consumer, options.tilingOptions);
1611 
1612   if (failed(tilingResult))
1613     return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1614   tiledAndFusedOps.insert_range(tilingResult->tiledOps);
1615 
  // Seed the replacement map: each result of the untiled consumer maps to the
  // value that should replace it after tiling.
1616   DenseMap<Value, Value> replacements;
1617   for (auto [origVal, replacement] :
1618        llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1619     replacements[origVal] = replacement;
1620   }
1621 
1622   // If there are no loops generated, fusion is immaterial.
1623   auto &loops = tilingResult->loops;
1624   if (loops.empty()) {
1625     return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1626                                      replacements};
1627   }
1628 
1629   // Since the loop gets potentially replaced during fusion, we need to track
1630   // the mutation of replacement values. To do this, we attach a listener to
1631   // update the replacements as they happen.
  // The scope-exit guard restores the caller's listener on every return path.
1632   OpBuilder::Listener *previousListener = rewriter.getListener();
1633   auto resetListener =
1634       llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1635   ReplacementListener replaceListener(replacements, previousListener);
1636   rewriter.setListener(&replaceListener);
1637 
1638   // 2. Typically, the operands of the tiled operation are slices of the
1639   // operands of the untiled operation. These are expressed in IR using
1640   // `tensor.extract_slice` operations with source being the operands of
1641   // the untiled operation. Create a worklist of these
1642   // `tensor.extract_slice` operations. If the producers of the source of
1643   // the `tensor.extract_slice` can be tiled such that the tiled value is
1644   // generated in-place, that effectively tiles + fuses the operations.
  // Pairs a candidate slice with the decision returned by the fusion control
  // function for that slice.
1645   struct WorklistItem {
1646     tensor::ExtractSliceOp candidateSlice;
1648   };
1649 
1650   SliceTrackingListener sliceTracker =
1651       SliceTrackingListener(options.cleanupPatterns);
1652 
1653   if (failed(
1654           sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1655     return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1656   }
1657   OpBuilder::InsertionGuard g(rewriter);
  // Greedy fixed-point loop: fusing one producer may generate new slices,
  // which are appended to the worklist and processed in turn.
1658   while (!sliceTracker.worklist.empty()) {
1659     auto candidateSlice = sliceTracker.worklist.front();
1660     sliceTracker.worklist.pop_front();
1661 
1662     auto [fusableProducer, destinationInitArg] =
1663         getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1664                                           loops);
1665     if (!fusableProducer)
1666       continue;
1667 
1668     std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1669         options.fusionControlFn(candidateSlice, fusableProducer,
1670                                 destinationInitArg.has_value());
1671     if (!controlFnResult)
1672       continue;
1673 
1674     WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1675 
1676     // The operands of the fused producer might themselved be slices of
1677     // values produced by operations that implement the `TilingInterface`.
1678     // Add these operations to the worklist.
1679     std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1680         tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1681                                    loops);
1682     if (!fusedResult)
1683       continue;
1684 
1685     SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1686 
1687     if (worklistItem.controlFnResult.yieldProducerReplacement) {
1688       // Reconstruct and yield all opResult of fusableProducerOp by default.
1689       // The caller can specific which one to yield by designating optional
1690       // argument named `yieldResultNumber` of
1691       // `yieldReplacementForFusedProducer`.
1692       Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1693       FailureOr<SmallVector<Operation *>> newSlices =
1695               worklistItem.candidateSlice,
1696               fusedResult.value(), loops);
1697       if (failed(newSlices)) {
1698         return rewriter.notifyMatchFailure(
1699             fusableProducerOp, "failed to replacement value for this "
1700                                "operation from within the tiled loop");
1701       }
1702       worklistCandidates.append(newSlices.value());
      // The newly yielded producer results are appended at the tail of the
      // outermost loop's results; map each original producer result to its
      // corresponding appended loop result.
1703       for (auto [index, result] :
1704            llvm::enumerate(fusableProducerOp->getResults())) {
1705         replacements[result] = loops.front()->getResult(
1706             loops.front()->getNumResults() -
1707             fusableProducerOp->getNumResults() + index);
1708       }
1709     }
1710     if (Operation *tiledAndFusedOp =
1711             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1712       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1713       tiledAndFusedOps.insert(tiledAndFusedOp);
1714     }
1715 
1716     if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1717       return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1718     }
1719   }
1720 
1721   return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1722                                    replacements};
1723 }
1724 
1725 //===----------------------------------------------------------------------===//
1726 // tileAndFuseConsumerUsingSCF implementation.
1727 //===----------------------------------------------------------------------===//
1728 
1729 /// A utility function that checks whether the only use of the result of a
1730 /// tensor.insert_slice op is in a scf.yield op.
1731 static LogicalResult
1732 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1733  Value result = candidateSliceOp.getResult();
1734  Value::use_range uses = result.getUses();
1735  if (!llvm::hasSingleElement(uses)) {
1736  LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1737  return failure();
1738  }
1739  OpOperand &operandUse = (*uses.begin());
1740  Operation *userOp = operandUse.getOwner();
1741  if (!isa<scf::YieldOp>(userOp)) {
1742  LLVM_DEBUG(llvm::dbgs()
1743  << "Expected scf.yield to be the only user, but got -> "
1744  << (*userOp));
1745  return failure();
1746  }
1747  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1748  LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1749  "be in the same block\n");
1750  return failure();
1751  }
1752  return success();
1753 }
1754 
1755 /// An utility to get the first user of the given loopOp. If any of user stay
1756 /// in different block of loopOp, return failure.
1757 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1758  if (!isa<LoopLikeOpInterface>(loopOp))
1759  return failure();
1760  Operation *firstUserOfLoop = nullptr;
1761  for (Operation *userOp : loopOp->getUsers()) {
1762  // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1763  // block with any other types of operation. Thus, just redirecting to its
1764  // parent `InParallelOp`. E.g.
1765  //
1766  // ```
1767  // %1 = scf.for {
1768  // ...
1769  // }
1770  // %2 = consumerOp ins(%1, ...)
1771  // scf.forall.in_parallel {
1772  // tensor.parallel_insert_slice %1
1773  // }
1774  // ```
1775  // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
1776  // same block with `consumerOp`.
1777  if (isa<tensor::ParallelInsertSliceOp>(userOp))
1778  userOp = userOp->getParentOfType<scf::InParallelOp>();
1779 
1780  if (loopOp->getBlock() != userOp->getBlock())
1781  return failure();
1782 
1783  if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1784  firstUserOfLoop = userOp;
1785  }
1786  return firstUserOfLoop;
1787 }
1788 
1789 /// This utility currently checks whether the first userOp of loop is NOT
1790 /// before the last defineOp of consumer operand. Because that we need to move
1791 /// the whole loop structure right before the `firstUserOfLoop`. This utility
1792 /// thus helps ensuring that no invalid IR is formed, i.e. no backward slice
1793 /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
1794 ///
1795 /// ```
1796 /// %0 = scf.for() {
1797 /// ...
1798 /// }
1799 /// ...
1800 /// %1 = firstUserOfLoop(%0)
1801 /// ...
1802 /// %2 = lastDefOfConsumerOperand
1803 /// ...
1804 /// %3 = consumerOp(%2)
1805 /// ```
1806 ///
1807 /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
1808 /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
1809 /// a.k.a. use-def chain violation:
1810 ///
1811 /// ```
1812 /// %0:2 = scf.for() {
1813 /// // use before define error
1814 /// %3 = tiledConsumerOp(%2)
1815 /// }
1816 /// %1 = firstUserOfLoop(%0)
1817 /// ...
1818 /// %2 = lastDefOfConsumerOperand
1819 /// ```
1820 ///
1821 /// @param loopOp: loop operation
1822 /// @param consumerOp: consumer operation
1823 /// @param reorderOperations: the flag controls whether to reorder the
1824 /// backward slice w.r.t. the defineOp of `consumerOp` operands.
1825 /// @return: computed backward slice of consumerOp, but excluding those
1826 /// already dominates `firstUserOfLoop`.
1827 static FailureOr<llvm::SetVector<Operation *>>
1829                        bool reorderOperations) {
1830   FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1831   if (failed(firstUserOfLoop))
1832     return failure();
1833 
1835   DominanceInfo dominanceInfo;
1836   options.inclusive = true;
1837   options.omitBlockArguments = true;
  // Set by the filter below when the backward slice would contain the loop
  // itself; that is a use-def cycle which reordering cannot repair.
1838   bool includeLoopOp = false;
1839   options.filter = [&](Operation *op) {
1840     if (op == loopOp) {
1841       includeLoopOp = true;
1842       return false;
1843     }
1844     // Cut off the slice to not include any operation that already dominates
1845     // firstUserOfLoop.
1846     return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1847   };
  // Accumulate the backward slice over all operands of the consumer; this is
  // the set of ops that would need to move before `firstUserOfLoop`.
1849   for (auto operand : consumerOp->getOperands()) {
1850     LogicalResult result = getBackwardSlice(operand, &slice, options);
1851     assert(result.succeeded() && "expected a backward slice");
1852     (void)result;
1853   }
1854 
1855   if (!slice.empty()) {
1856     // If consumerOp has one producer, which is also the user of loopOp.
1857     // E.g.
1858     // ```
1859     // %0 = %loopOp
1860     // %1 = consumerOp1 ins(%0)
1861     // %2 = consumerOp2 ins(%0, %1)
1862     // ```
1863     // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1864     // consumerOp1 has already been fused into loopOp before.
1865     if (includeLoopOp || !reorderOperations)
1866       return failure();
1867   }
1868 
1869   return slice;
1870 }
1871 
1872 /// Fetches the OpOperand of the first valid user (and use) of the value `val`
1873 /// which implements `TilingInterface` and `DestinationStyleOpInterface`.
1874 /// Returns failure otherwise.
1875 static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1876  Operation *loopOp,
1877  unsigned resultNumber) {
1878  if (!isa<LoopLikeOpInterface>(loopOp))
1879  return failure();
1880  Value val = loopOp->getResult(resultNumber);
1881  Block *loopBlock = loopOp->getBlock();
1882  for (OpOperand &opOperand : val.getUses()) {
1883  Operation *consumerOp = opOperand.getOwner();
1884  // Step 1. Check if the user is tilable.
1885  if (!isa<TilingInterface>(consumerOp) ||
1886  !isa<DestinationStyleOpInterface>(consumerOp)) {
1887  // TODO: We have to init result of consumer before scf.for, use
1888  // DestinationStyleOpInterface to get result shape from init for now.
1889  // Add support for other op such as op has InferTypeOpInterface.
1890  continue;
1891  }
1892  // Step 2. Check if user stay in the same block.
1893  if (loopBlock != consumerOp->getBlock())
1894  continue;
1895  // Step 3. Check if user has succeeding user. Otherwise, it usually
1896  // represents already tiled.
1897  if (consumerOp->use_empty())
1898  continue;
1899  // Step 4. Check assumption for loop with `reorderOperations` enabled.
1900  FailureOr<llvm::SetVector<Operation *>> slice =
1901  checkAssumptionForLoop(loopOp, consumerOp, true);
1902  if (failed(slice))
1903  continue;
1904  // Step 5. If backward sice is not empty, move them before
1905  // firstUserOfLoop.
1906  if (!slice->empty()) {
1907  mlir::topologicalSort(*slice);
1908  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1909  assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1910  for (auto op : *slice) {
1911  rewriter.moveOpBefore(op, *firstUserOfLoop);
1912  }
1913  }
1914  return &opOperand;
1915  }
1916  return failure();
1917 }
1918 
1919 /// Fetch the untiled consumer of the outermost scf.for's result which is
1920 /// yielded by a tensor.insert_slice from the innermost scf.for. This function
1921 /// makes the following assumptions :
1922 /// 1. tensor.insert_slice has scf.yield as its only user.
1923 /// 2. scf.for's corresponding result has only one use.
1924 /// 3. The `loops` passed in are perfectly nested `scf.for` operations.
1925 static FailureOr<OpOperand *>
1927                             tensor::InsertSliceOp candidateSliceOp,
1929   assert(!loops.empty() && "unexpected loops to be empty");
1930   // 1. Expect slice to be part of the body of the inner most loop.
1931   Operation *containingOp = candidateSliceOp->getParentOp();
1932   if (containingOp != loops.back()) {
1933     return rewriter.notifyMatchFailure(
1934         candidateSliceOp,
1935         "expected slice to be within body of inner-most loop");
1936   }
1937 
1938   // 2. Check that the loop is perfectly nested.
1939   if (!isPerfectlyNestedForLoops(loops)) {
1940     return rewriter.notifyMatchFailure(
1941         candidateSliceOp, "expected passed loops to be perfectly nested.");
1942   }
1943 
  // The slice must be yielded by the innermost loop's scf.yield as its only
  // use (checked by the helper below).
1944   if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1945     return failure();
1946   Value sliceResult = candidateSliceOp.getResult();
1947 
1948   // 3. Fetch the corresponding output.
  // The operand number within scf.yield equals the loop result number.
1949   OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1950   unsigned resultNumber = yieldOpOperand.getOperandNumber();
1951 
1952   scf::ForOp topLevelForOp = cast<scf::ForOp>(loops.front().getOperation());
1953 
1954   return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
1955 }
1956 
1957 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1958 /// by a tensor.parallel_insert_slice.
1960 static FailureOr<OpOperand *>
1961                             tensor::ParallelInsertSliceOp candidateSliceOp,
1963   assert(!loops.empty() && "unexpected loops to be empty");
1964   // 1. Check that the surrounding loop is a single scf.forall loop.
1965   if (loops.size() != 1) {
1966     return rewriter.notifyMatchFailure(
1967         candidateSliceOp, "expected single surrounding scf.forall");
1968   }
1969   auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
1970   if (!forallOp) {
1971     return rewriter.notifyMatchFailure(
1972         candidateSliceOp, "expected single surrounding scf.forall");
1973   }
1974 
1975   // 2. Fetch the corresponding output
  // The destination of the parallel_insert_slice must be a shared-output
  // block argument of this very scf.forall; otherwise there is no loop
  // result to chase.
1976   Value sliceDest = candidateSliceOp.getDest();
1977   auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1978   if (!iterArg)
1979     return failure();
1980   if (iterArg.getOwner()->getParentOp() != forallOp)
1981     return failure();
1982 
  // Map the iter_arg back to the forall result it is tied to.
1983   unsigned resultNumber =
1984       forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1985           .getResultNumber();
1986 
1987   return getConsumerFromLoopUses(rewriter, forallOp, resultNumber);
1988 }
1989 
1990 /// A utility to fetch an untiled consumer of
1991 /// tensor.insert_slice/tensor.parallel_insert_slice.
1992 static FailureOr<SmallVector<OpOperand *>> getUntiledConsumerOperandsFromSlices(
1993     RewriterBase &rewriter, ArrayRef<Operation *> sliceOps,
1995   assert(!loops.empty() && "unexpected empty loops");
1996   assert(!sliceOps.empty() && "unexpected empty list of candidate slices");
1997   SmallVector<OpOperand *> fusedOperands;
1998   for (auto sliceOp : sliceOps) {
    // Dispatch on the slice kind; each case resolves the untiled consumer
    // operand that reads the corresponding loop result.
1999     FailureOr<OpOperand *> fusedOperand =
2001             .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2002                 [&](auto op) {
2003                   return getUntiledConsumerFromSlice(rewriter, op, loops);
2004                 })
2005             .Default([&](Operation *op) {
2006               return rewriter.notifyMatchFailure(op, "unhandled slice type");
2007             });
2008     if (failed(fusedOperand)) {
2009       return failure();
2010     }
    // All candidate slices must feed operands of one and the same consumer
    // op; compare against the first one found.
2011     if (!fusedOperands.empty() &&
2012         fusedOperand.value()->getOwner() != fusedOperands.front()->getOwner()) {
2013       return rewriter.notifyMatchFailure(
2014           fusedOperand.value()->getOwner(),
2015           "all candidate slices must be to the same consumer");
2016     }
2017     fusedOperands.push_back(fusedOperand.value());
2018   }
2019   return fusedOperands;
2020 }
2021 
/// Clone `sliceOp` as an equivalent `tensor.insert_slice` op; specialized
/// below for `tensor.insert_slice` and `tensor.parallel_insert_slice`.
2022 template <typename InsertSliceOpTy>
2023 static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter,
2024                                                 InsertSliceOpTy sliceOp);
2025 
2026 template <>
2027 tensor::InsertSliceOp
2028 cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter,
2029  tensor::InsertSliceOp insertSliceOp) {
2030  return cast<tensor::InsertSliceOp>(
2031  rewriter.clone(*insertSliceOp.getOperation()));
2032 }
2033 
2034 template <>
2035 tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>(
2036  RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) {
2037  return tensor::InsertSliceOp::create(
2038  rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(),
2039  insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(),
2040  insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2041 }
2042 
2045     ArrayRef<Operation *> candidateSlices) {
2046   assert(!candidateSlices.empty() &&
2047          "unexpected empty list of slices to clone");
  // Clone each candidate (insert_slice or parallel_insert_slice) as a plain
  // tensor.insert_slice, preserving the input order.
2049   for (auto sliceOp : candidateSlices) {
2050     TypeSwitch<Operation *>(sliceOp)
2051         .Case<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2052             [&](auto op) {
2053               auto clonedOp = cloneAsInsertSlice(rewriter, op);
2054               clonedSlices.push_back(clonedOp);
2055             })
2056         .Default([&](Operation *op) {
2057           // Assert here assuming this has already been checked.
2058           assert(0 && "unexpected slice type while cloning as insert slice");
2059         });
2060   }
2061   return clonedSlices;
2062 }
2063 
2064 /// Implementation of fusing consumer of a single slice by computing the
2065 /// slice of the consumer in-place for scf loop.
2066 FailureOr<scf::SCFFuseConsumerOfSliceResult>
2068     RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
2070   if (candidateSlices.empty()) {
2071     return rewriter.notifyMatchFailure(
2072         rewriter.getUnknownLoc(),
2073         "no candidate slices provided for consumer fusion");
2074   }
2075   // Return if `loops` is empty, return an error for now. Caller is expected
2076   // to handle this case.
2077   if (loops.empty()) {
2078     return rewriter.notifyMatchFailure(
2079         candidateSlices.front(),
2080         "cannot call tile and fuse consumer with an empty loop nest");
2081   }
2082 
  // Mixed slice kinds are rejected: they imply different surrounding loop
  // constructs (scf.for vs scf.forall).
2083   if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
2084         llvm::all_of(candidateSlices,
2085                      llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
2086     return rewriter.notifyMatchFailure(
2087         candidateSlices.front(),
2088         "candidates slices need to be all `tensor.extract_slice`s or "
2089         "`tensor.parallel_insert_slice`s");
2090   }
2091 
2092   // 1. Get the consumer of scf.for for the result yielded by
2093   // tensor.insert_slice/parallel_insert_slice.
2094   SmallVector<OpOperand *> consumerOpOperands;
2095   Operation *consumerOp;
2096   {
2097     FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
2098         getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
2099     if (failed(maybeConsumerOpOperand)) {
2100       return rewriter.notifyMatchFailure(candidateSlices.front(),
2101                                          "could not fetch consumer to fuse");
2102     }
2103     std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
2104     consumerOp = consumerOpOperands.front()->getOwner();
2105   }
2106 
2107   LoopLikeOpInterface outerMostLoop = loops.front();
2108   LoopLikeOpInterface innerMostLoop = loops.back();
2109 
2110   // Check assumption for loop with `reorderOperations` disabled.
2111   if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
2112     return rewriter.notifyMatchFailure(
2113         outerMostLoop, "the first user of loop should not dominate any define "
2114                        "of consumer operand(s)");
2115   }
2116 
2117   OpBuilder::InsertionGuard g(rewriter);
2118 
2119   // 2. Check consumer is not using scf loop's output as init.
2120   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2121   if (!dstOp)
2122     return rewriter.notifyMatchFailure(consumerOp,
2123                                        "consumer op is not DPS operation");
2124   if (llvm::any_of(consumerOpOperands, [&](OpOperand *opOperand) {
2125         return dstOp.isDpsInit(opOperand);
2126       })) {
2127     return rewriter.notifyMatchFailure(
2128         consumerOp,
2129         "consumer op taking the result of scf.for as init is not supported");
2130   }
  // The consumer's DPS inits become the new iter_args of the extended loop
  // nest (added in step 14 below).
2131   SmallVector<Value> newInits = llvm::to_vector(dstOp.getDpsInits());
2132 
2133   // 3. Move the whole loop structure right before firstUserOfLoop, the
2134   // dominance should be already ensured by `checkAssumptionForLoop`.
2135   FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
2136   if (failed(firstUserOfLoop)) {
2137     return rewriter.notifyMatchFailure(
2138         outerMostLoop, "could not find the first user of outer most loop");
2139   }
2140   rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2141 
2142   // 4. Set insertion point before terminator op of the loop and create a new
2143   // tensor.insert_slice. In the scf.for case this is a clone of the
2144   // candidateSliceOp whereas in the scf.forall case this is created from the
2145   // operands of tensor.parallel_insert_slice.
2146   if (auto sliceOp =
2147           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSlices.front())) {
2148     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2149     rewriter.setInsertionPoint(newForallOp.getTerminator());
2150   } else {
2151     rewriter.setInsertionPoint(candidateSlices.front());
2152   }
2153   // 5.a. Clone all the candidate slices as equivalent insert slice ops.
2154   SmallVector<tensor::InsertSliceOp> clonedInsertSlices =
2155       cloneAsInsertSlices(rewriter, candidateSlices);
2156 
2157   // 5.b. Clone consumer op.
2158   auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2159   SmallVector<unsigned> operandNumbers =
2160       llvm::map_to_vector(consumerOpOperands, [](OpOperand *opOperand) {
2161         return opOperand->getOperandNumber();
2162       });
2163   SmallVector<OpOperand *> clonedOpFusedOperandsList =
2164       llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2165         return &clonedConsumerOp->getOpOperand(operandNum);
2166       });
2167 
2168   // 5.c. Replace all uses of the loop result with the result of the cloned
2169   // tensor.insert_slice.
2170   rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
2171     for (auto [operandToReplace, clonedSliceOp] :
2172          llvm::zip_equal(clonedOpFusedOperandsList, clonedInsertSlices)) {
2173       operandToReplace->set(clonedSliceOp.getResult());
2174     }
2175   });
2176 
2177   // 6. Perform tiling of the cloned consumer and replace the operand at
2178   // `operandNumber` with the source of the cloned tensor.insert_slice op.
2179   FailureOr<TilingResult> tileAndFuseResult =
2180       tensor::replaceInsertSlicesWithTiledConsumer(rewriter, clonedInsertSlices,
2181                                                    clonedOpFusedOperandsList);
2182   if (failed(tileAndFuseResult)) {
2183     return failure();
2184   }
2185 
2186   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2187   for (auto [operandNum, clonedSliceOp] :
2188        llvm::zip_equal(operandNumbers, clonedInsertSlices)) {
2189     rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNum),
2190                                 clonedSliceOp.getSource());
2191   }
2192 
2193   // 7. Reconstruct [nested] loop with new inits.
  // This callback is invoked by `addInitOperandsToLoopNest` from within the
  // innermost loop body; it produces the tiled consumer's results together
  // with the offsets/sizes at which the caller must insert them.
2194   YieldTiledValuesFn newYieldValuesFn =
2195       [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2196           ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2198           SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2199     OpBuilder::InsertionGuard g(innerRewriter);
2200     // 8. Set inner insertPoint right before tiled consumer op.
2201     innerRewriter.setInsertionPoint(tiledConsumerOp);
2202 
2203     SmallVector<SmallVector<OpFoldResult>> allOffsets, allSizes;
2204     for (auto candidateSliceOp : clonedInsertSlices) {
2205       SmallVector<OpFoldResult> offsets = candidateSliceOp.getMixedOffsets();
2206       SmallVector<OpFoldResult> sizes = candidateSliceOp.getMixedSizes();
2207       SmallVector<OpFoldResult> strides = candidateSliceOp.getMixedStrides();
2208 
2209       // 9. Check all insert stride is 1.
2210       if (!llvm::all_of(strides, isOneInteger)) {
2211         return rewriter.notifyMatchFailure(
2212             candidateSliceOp, "containingOp's result yield with stride");
2213       }
2214 
2215       allOffsets.emplace_back(std::move(offsets));
2216       allSizes.emplace_back(std::move(sizes));
2217     }
2218 
2219     // 10. Try to get iter domain position from input position. Use
2220     // clonedConsumerOp instead of tiledConsumerOp, because the iteration
2221     // domain may require index computation based on the result size. The
2222     // sizes and offsets should be the same either way, but using
2223     // tiledConsumerOp could lead to some chained unnecessary extra index
2224     // computation.
2225     SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2226     if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTiles(
2227             rewriter, operandNumbers, allOffsets, allSizes, iterDomainOffsets,
2228             iterDomainSizes))) {
2229       return rewriter.notifyMatchFailure(
2230           clonedConsumerOp,
2231           "can't get iter domain position from input position");
2232     }
2233 
2234     // 11. Try to fetch the offset and size for all results of the cloned
2235     // consumer. This would then be used to form the corresponding
2236     // tensor.insert_slice/parallel_insert_slice later.
2237     unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2239                                                       totalNumResultsOfConsumer);
2241                                                     totalNumResultsOfConsumer);
2242     for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2243       if (failed(tiledConsumerOp.getResultTilePosition(
2244               rewriter, idx, iterDomainOffsets, iterDomainSizes,
2245               resultOffsets[idx], resultSizes[idx]))) {
2246         return rewriter.notifyMatchFailure(
2247             tiledConsumerOp,
2248             "can't get result domain position from iter domain position");
2249       }
2250     }
2251 
2252     // 12. Create `extract_slice` for `iter_args` for DPS operation if
2253     // necessary.
2254     if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2255             tiledConsumerOp.getOperation())) {
2256       rewriter.setInsertionPoint(tiledDestStyleOp);
2257       for (const auto &&[index, newRegionArg] :
2258            llvm::enumerate(newRegionIterArgs)) {
2259         auto destSlice = tensor::ExtractSliceOp::create(
2260             rewriter, loc, newRegionArg, resultOffsets[index],
2261             resultSizes[index],
2262             SmallVector<OpFoldResult>(resultOffsets[index].size(),
2263                                       rewriter.getIndexAttr(1)));
2264         // Make a copy of index to avoid a capturing structured binding, which
2265         // is a C++20 extension.
2266         auto dstNumber = index;
2267         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2268           tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2269         });
2270       }
2271     }
2272 
2273     // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2274     // caller.
2275     Block *block = rewriter.getInsertionPoint()->getBlock();
2276     rewriter.setInsertionPoint(block->getTerminator());
2277     for (const auto &&[index, result] :
2278          llvm::enumerate(tiledConsumerOp->getResults())) {
2279       tiledResult.push_back(result);
2280       tiledOffset.emplace_back(resultOffsets[index]);
2281       tiledSizes.emplace_back(resultSizes[index]);
2282     }
2283     return success();
2284   };
2285   // 14. Add new inits to [nested] loops.
2286   if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
2287                                        newYieldValuesFn))) {
2288     return rewriter.notifyMatchFailure(tiledConsumerOp,
2289                                        "unable to add new inits to nest loop");
2290   }
2291 
2292   // 15. Replace the result of scf loop and consumer op with new loop's
2293   // results.
2294 
  // The loop results appended by step 14 correspond 1:1, in order, to the
  // consumer's results.
2295   for (auto &&[oldResult, newResult] :
2296        llvm::zip(consumerOp->getResults(),
2297                  loops.front()->getResults().take_back(newInits.size()))) {
2298     rewriter.replaceAllUsesWith(oldResult, newResult);
2299   }
2300 
2301   // 16. Need to erase the old scf loop and the cloned consumer op.
2302   rewriter.eraseOp(clonedConsumerOp);
2303 
2304   SmallVector<OpOperand *> tiledAndFusedOpOperands =
2305       llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
2306         return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
2307       });
2309       std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
2310       std::move(tileAndFuseResult->tiledOps)};
2311 }
2312 
2313 //===----------------------------------------------------------------------===//
2314 // lowerToLoopsUsingSCFForOp implementation.
2315 //===----------------------------------------------------------------------===//
2316 
2317 FailureOr<SmallVector<scf::ForOp>>
2319                           TilingInterface op) {
2320   // TODO: Handle cases where the op has results if needed.
2321   if (op->getNumResults() > 0) {
2322     return rewriter.notifyMatchFailure(
2323         op, "unable to lower to loops operations with return values");
2324   }
2325 
2326   SmallVector<Range> domain = op.getIterationDomain(rewriter);
2327   SmallVector<Value> ivs;
2329   Location loc = op.getLoc();
  // Materialize one scf.for per iteration-domain dimension. Moving the
  // insertion point into each new loop's body makes the loops nest.
2330   for (auto loopRange : domain) {
2331     Value offsetVal =
2332         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2333     Value sizeVal =
2334         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2335     Value strideVal =
2336         getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2337     auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal,
2338                                    strideVal, ValueRange{});
2339     loops.push_back(loop);
2340     ivs.push_back(loop.getInductionVar());
2341     rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2342   }
  // Emit the op's scalar body at the innermost insertion point, indexed by
  // the collected induction variables.
2343   if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2344     return failure();
2345   }
2346   return loops;
2347 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static SetVector< unsigned > getSanitizedReductionDims(ArrayRef< OpFoldResult > tileSizes, const scf::SCFTilingOptions &options)
Get the reduction dims that are tiled.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, scf::SCFTilingOptions::LoopType loopType, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, ArrayRef< Attribute > mappingVector, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specifed in options.
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, const SetVector< unsigned > &reductionDims, ValueRange partialResults)
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
A utility to get the first user of the given loopOp.
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< SmallVector< OpOperand * > > getUntiledConsumerOperandsFromSlices(RewriterBase &rewriter, ArrayRef< Operation * > sliceOps, MutableArrayRef< LoopLikeOpInterface > loops)
A utility to fetch an untiled consumer of tensor.insert_slice/tensor.parallel_insert_slice.
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims)
static SmallVector< tensor::InsertSliceOp > cloneAsInsertSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices)
static LogicalResult checkTileSizes(TilingInterface op, scf::SCFTilingOptions::LoopType loopType, ReductionTilingStrategy reductionStrategy, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static tensor::InsertSliceOp cloneAsInsertSlice(RewriterBase &rewriter, InsertSliceOpTy sliceOp)
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Fetch the untiled consumer of the outermost scf.for's result which is yielded by a tensor....
static LogicalResult verifyOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ReductionTilingStrategy strategy, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, const llvm::SetVector< unsigned > &reductionDims)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims)
static SmallVector< OpFoldResult > getSplitReductionIvs(RewriterBase &rewriter, Location loc, ReductionTilingStrategy reductionStrategy, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims)
For the case of ReductionTilingStrategy::PartialReductionOuterParallel the PartialReductionOpInterfac...
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:959
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:33
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgListType getArguments()
Definition: Block.h:87
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:24
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:323
This class allows control over how the GreedyPatternRewriteDriver works.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class represents a saved insertion point.
Definition: Builders.h:325
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:548
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:456
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:636
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1439
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1432
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSizes)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlices(RewriterBase &rewriter, ArrayRef< Operation * > candidateSlices, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInter...
FailureOr< TilingResult > replaceInsertSlicesWithTiledConsumer(OpBuilder &builder, ArrayRef< tensor::InsertSliceOp > sliceOps, ArrayRef< OpOperand * > consumerOperands)
Method to swap tensor.insert_slices with their consumers when the consumer implements the TilingInter...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:79
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:114
Include the generated interface declarations.
bool isPerfectlyNestedForLoops(MutableArrayRef< LoopLikeOpInterface > loops)
Check if the provided loops are perfectly nested for-loops.
Definition: Utils.cpp:1516
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
ReductionTilingStrategy
Tiling can be thought of as splitting a dimension into 2 and materializing the outer dimension as a l...
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:421
Container for result values of tiling.
Fuse the consumer candidateSlices by computing the required slice of the consumer in-place.
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Control function to check if a slice needs to be fused or not, The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
LoopType
Specify which loop construct to use for tile and fuse.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.