MLIR  21.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/Matchers.h"
26 #include "mlir/IR/PatternMatch.h"
31 #include "llvm/ADT/ScopeExit.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
34 #include <optional>
35 
36 #define DEBUG_TYPE "tile-using-interface"
37 
38 using namespace mlir;
39 
42  assert(!tileSizeComputationFunction && "tile sizes already set");
43  auto tileSizes = llvm::to_vector(ts);
44  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
45  return tileSizes;
46  };
47  return *this;
48 }
49 
52  assert(!numThreadsComputationFunction && "num tiles already set");
53  auto numThreads = llvm::to_vector(nt);
54  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
55  return numThreads;
56  };
57  return *this;
58 }
59 
60 /// Helper method to adjust the interchange vector to match the iteration
61 /// domain.
64  size_t iterationDomainSize) {
65  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
66  if (filledVector.size() < iterationDomainSize) {
67  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
68  filledVector.append(range.begin(), range.end());
69  }
70  if (filledVector.size() > iterationDomainSize)
71  filledVector.resize(iterationDomainSize);
72  return filledVector;
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // tileUsingSCF implementation.
77 //===----------------------------------------------------------------------===//
78 
79 /// Verify the tile size options are set in a consistent manner.
80 static LogicalResult
83  // Specifying number of threads is only supported on `scf.forall` op.
84  if (options.numThreadsComputationFunction &&
86  return rewriter.notifyMatchFailure(
87  loc, "number of threads can only by specified when loop type is "
88  "set to use `scf.forall`");
89  }
90 
91  // If specified, check that the interchange vector is a permutation.
92  if (!options.interchangeVector.empty()) {
93  if (!isPermutationVector(options.interchangeVector)) {
94  return rewriter.notifyMatchFailure(
95  loc, "invalid interchange vector, not a permutation of the entire "
96  "iteration space");
97  }
98  }
99  return success();
100 }
101 
102 /// Method to instantiate the tile sizes and/or number of threads specified
103 /// by the user.
104 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
105 getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
106  ArrayRef<Range> iterationDomain,
108  OpFoldResult zero = rewriter.getIndexAttr(0);
109  SmallVector<OpFoldResult> tileSizes, numThreads;
110  size_t numLoops = iterationDomain.size();
111 
112  // Check whether the number of tiles to use is specified.
113  if (options.numThreadsComputationFunction) {
114  numThreads = options.numThreadsComputationFunction(rewriter, op);
115  numThreads.resize(numLoops, zero);
116 
117  // If the number of tiles is also specified, use that.
118  if (options.tileSizeComputationFunction) {
119  tileSizes = options.tileSizeComputationFunction(rewriter, op);
120  tileSizes.resize(numLoops, zero);
121  return {tileSizes, numThreads};
122  }
123 
124  // Compute the tile sizes from the iteration domain and number
125  // of tiles as follows
126  // - niters = ceilDiv(ub - lb, step)
127  // - tileSize = ceilDiv(niters, numThreads)
128  AffineExpr s0, s1, s2;
129  bindSymbols(rewriter.getContext(), s0, s1, s2);
130  // TODO: The step here is assumed to be 1.
131  AffineExpr numItersExpr = (s1 - s0);
132  AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
133  tileSizes.resize(numLoops, zero);
134  for (auto [index, range, nt] :
135  llvm::enumerate(iterationDomain, numThreads)) {
136  if (isConstantIntValue(nt, 0))
137  continue;
138 
139  tileSizes[index] = affine::makeComposedFoldedAffineApply(
140  rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
141  }
142  tileSizes.resize(numLoops, zero);
143  return {tileSizes, numThreads};
144  }
145 
146  // Enforce the convention that "tiling by zero"
147  // skips tiling a particular dimension. This convention is significantly
148  // simpler to handle instead of adjusting affine maps to account for missing
149  // dimensions.
150  assert(options.tileSizeComputationFunction &&
151  "expected tile sizes to be specified");
152  tileSizes = options.tileSizeComputationFunction(rewriter, op);
153  tileSizes.resize(numLoops, zero);
154 
155  return {tileSizes, numThreads};
156 }
157 
158 /// Checks if any of the tiled loops are not parallel.
159 static void checkSafeToTileToForall(TilingInterface op,
160  ArrayRef<OpFoldResult> tileSizes,
161  ArrayRef<OpFoldResult> numThreads) {
162  auto iterators = op.getLoopIteratorTypes();
163  assert(iterators.size() == tileSizes.size() &&
164  "expected as many tile size values as number of loops");
165  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
166  "when specified, expected number of threads to use for each loop");
167 
168  for (auto [index, iterator, tileSize] :
169  llvm::enumerate(iterators, tileSizes)) {
170  // If num threads is specified, check that it is greater than one only for
171  // parallel dimensions.
172  if (!numThreads.empty()) {
173  if (std::optional<int64_t> constNumThreads =
174  getConstantIntValue(numThreads[index])) {
175  if (constNumThreads.value() > 1 &&
176  iterator != utils::IteratorType::parallel) {
177  op.emitWarning() << "tiling is not thread safe at axis #" << index;
178  }
179  }
180  continue;
181  }
182 
183  if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
184  if (constTileSize.value() > 0 &&
185  iterator != utils::IteratorType::parallel) {
186  op.emitWarning() << "tiling is not thread safe at axis #" << index;
187  }
188  }
189  }
190 }
191 
192 /// Check if `stride` evenly divides the trip count `size - offset`.
193 static bool tileDividesIterationDomain(Range loopRange) {
194  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
195  if (!offsetAsInt)
196  return false;
197  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
198  if (!sizeAsInt)
199  return false;
200  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
201  if (!strideAsInt)
202  return false;
203  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
204 }
205 
206 /// Returns the bounded tile size given the current `offset`, `loopRange` and
207 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
209  Range loopRange, OpFoldResult offset,
210  OpFoldResult tileSize) {
211  std::optional<int64_t> ts = getConstantIntValue(tileSize);
212  if (ts && ts.value() == 1)
213  return tileSize;
214 
216  Range{loopRange.offset, loopRange.size, tileSize}))
217  return tileSize;
218 
219  // The tile size to use (to avoid out of bounds access) is minimum of
220  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
221  // loop.
222  AffineExpr s0, s1, d0;
223  bindDims(b.getContext(), d0);
224  bindSymbols(b.getContext(), s0, s1);
225  AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
226  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
228  b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
229 }
230 
231 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
232 /// than `iterationSize`.
234  OpFoldResult numThreads,
235  OpFoldResult iterationSize) {
236  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
237  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
238  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
239  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
240  return false;
241  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
242 }
243 
244 /// Compute the `OpFoldResult`s that represents the multi-dimensional
245 /// `offset`s and `size`s of the tile of the iteration space that the
246 /// innermost loop body of the generated tiled loops corresponds to.
247 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
249  ArrayRef<Range> iterationDomain,
250  ArrayRef<OpFoldResult> tileSizes,
251  ArrayRef<OpFoldResult> numThreads) {
252  SmallVector<OpFoldResult> offsets, sizes;
253  int materializedLoopNum = 0;
254 
255  if (!numThreads.empty()) {
256  AffineExpr d0, d1, s0, s1;
257  AffineExpr offsetExpr, residualTileSizeExpr;
258  bindDims(rewriter.getContext(), d0, d1);
259  bindSymbols(rewriter.getContext(), s0, s1);
260  offsetExpr = d0 + d1 * s0;
261  residualTileSizeExpr = s1 - (d0 + d1 * s0);
262 
263  for (auto [nt, tileSize, loopRange] :
264  llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
265 
266  // Non-tiled cases, set the offset and size to the
267  // `loopRange.offset/size`.
268  if (isConstantIntValue(nt, 0)) {
269  offsets.push_back(loopRange.offset);
270  sizes.push_back(loopRange.size);
271  continue;
272  }
273 
274  Value iv = ivs[materializedLoopNum++];
276  rewriter, loc, offsetExpr,
277  ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
279  rewriter, loc, residualTileSizeExpr,
280  {loopRange.offset, nt, tileSize, loopRange.size});
281 
282  OpFoldResult size = tileSize;
283  if (!isConstantIntValue(residualTileSize, 0)) {
284  OpFoldResult sizeMinusOffsetPerThread =
285  affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
286  {offset, loopRange.size});
288  rewriter, loc,
290  {sizeMinusOffsetPerThread, tileSize});
291  }
292 
293  // Consider the case where the original loop was `[0, 100)`.
294  // If number of threads are `7`, the tile size would be computed as
295  // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
296  // - `offset = 0 + 6 * 15 = 105`
297  // - `tileSize = min(15, 100 - 105) = -5`
298  // To avoid negative tile sizes, we need to do a further
299  // `nonNegativeTileSize = affine.max(0, tileSize)`.
300  // This `max` can be avoided if
301  // `offset + tileSize * (numThreads - 1) < (ub - lb)`
302  if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
303  AffineMap maxMap =
306  rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
307  }
308 
309  offsets.push_back(offset);
310  sizes.push_back(size);
311  }
312  return {offsets, sizes};
313  } else {
314  for (auto [tileSize, loopRange] :
315  llvm::zip_equal(tileSizes, iterationDomain)) {
316 
317  // Non-tiled cases, set the offset and size to the
318  // `loopRange.offset/size`.
319  if (isConstantIntValue(tileSize, 0)) {
320  offsets.push_back(loopRange.offset);
321  sizes.push_back(loopRange.size);
322  continue;
323  }
324 
325  Value iv = ivs[materializedLoopNum++];
326  OpFoldResult offset = getAsOpFoldResult(iv);
327  offsets.push_back(offset);
328  OpFoldResult size =
329  getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
330  sizes.push_back(size);
331  }
332  return {offsets, sizes};
333  }
334 }
335 
336 /// Function to return the bounds of the loops to be generated.
337 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
340  ArrayRef<OpFoldResult> tileSizes) {
341  SmallVector<OpFoldResult> lbs, ubs, steps;
342  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
343  // No loop if the tile size is 0.
344  if (isConstantIntValue(tileSize, 0))
345  continue;
346  lbs.push_back(loopRange.offset);
347  ubs.push_back(loopRange.size);
348  steps.push_back(tileSize);
349  }
350  return {lbs, ubs, steps};
351 }
352 
353 /// A function that allows returning additional yielded values during
354 /// `yieldTiledValuesAndReplace`.
355 /// - `ivs` induction variable for the loop.
356 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
357 /// - `tiledValues` the tiled values to return. Must be of same size as
358 /// `newbbArgs`, each element of this array is inserted into the corresponding
359 /// element in `newbbArgs`.
360 /// - `resultOffsets` is of the same size as `tiledValues` and represents
361 /// the offsets to use when inserting corresponding element from `tiledValues`
362 /// into the element from `newBbArgs`.
363 /// - `resultSizes` is of the same size as `tiledValues` and represents
364 /// the size of the corresponding element from `tiledValues` inserted into
365 /// the element from `newBbArgs`.
366 /// In case the method needs to return `failure()` the method is expected
367 /// to clean up any inserted operations.
368 using YieldTiledValuesFn = std::function<LogicalResult(
369  RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
370  SmallVector<Value> &tiledValues,
371  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
372  SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
373 
374 /// Clones the operation and updates the destination if the operation
375 /// implements the `DestinationStyleOpInterface`.
377  Operation *op,
378  ValueRange newDestArgs) {
379  Operation *clonedOp = rewriter.clone(*op);
380  if (newDestArgs.empty())
381  return clonedOp;
382  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
383  destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
384  return clonedOp;
385 }
386 
387 /// Generate the tile-loop nest using `scf.for` operation.
388 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
389 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
390 /// - `destinationTensors` are the init values to use for the outer most loop.
/// - `yieldTiledValuesFn` is called to generate the loop body of the
///   innermost loop.
394 /// - `loops` is an in-out parameter into which the generated loops are
395 /// populated.
396 static LogicalResult generateLoopNestUsingForOp(
397  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
398  ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
399  YieldTiledValuesFn yieldTiledValuesFn,
401  assert(!loopRanges.empty() && "unexpected empty loop ranges");
402  assert(loopRanges.size() == tileSizes.size() &&
403  "expected as many tile sizes as loop ranges");
404  OpBuilder::InsertionGuard guard(rewriter);
405 
406  SmallVector<OpFoldResult> lbs, ubs, steps;
407  std::tie(lbs, ubs, steps) =
408  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
409  SmallVector<Value> lbVals =
410  getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
411  SmallVector<Value> ubVals =
412  getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
413  SmallVector<Value> stepVals =
414  getValueOrCreateConstantIndexOp(rewriter, loc, steps);
415 
416  SmallVector<Value> ivs;
417  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
418  auto loop =
419  rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
420  [](OpBuilder &bodyBuilder, Location bodyLoc,
421  Value iv, ValueRange /*iterArgs*/) {});
422  loops.push_back(loop);
423  ivs.push_back(loop.getInductionVar());
424  rewriter.setInsertionPointToEnd(loop.getBody());
425  destinationTensors = loop.getRegionIterArgs();
426  }
427 
428  SmallVector<Value> tiledResults;
429  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
430  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
431  tiledResults, resultOffsets, resultSizes))) {
432  return rewriter.notifyMatchFailure(
433  loc, "failed to generate inner tile loop body");
434  }
435  if (loops.empty())
436  return success();
437 
438  assert(tiledResults.size() == destinationTensors.size() &&
439  "Number of results of body should be equal to number of iter args");
440 
441  // 6. Yield all the results of the tiled operation.
442  SmallVector<Value> yieldedValues;
443  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
444  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
445  resultSizes)) {
446  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
447  rewriter.getIndexAttr(1));
448  auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
449  loc, tiledValue, destinationTensor, resultOffset, resultSize,
450  resultStride);
451  yieldedValues.push_back(insertSlice);
452  }
453  rewriter.create<scf::YieldOp>(loc, yieldedValues);
454 
455  // Add the scf.yield operations for all the outer loops.
456  for (auto [outerLoop, innerLoop] :
457  llvm::zip_equal(MutableArrayRef(loops).drop_back(),
458  MutableArrayRef(loops).drop_front())) {
459  rewriter.setInsertionPointToEnd(
460  cast<scf::ForOp>(outerLoop.getOperation()).getBody());
461  rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
462  }
463  return success();
464 }
465 
466 /// Generate the tile-loop nest using `scf.forall` operation.
467 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
468 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
469 /// - `destinationTensors` are the init values to use for the outer most loop.
470 /// - `mappingVector` is the mapping attributes to use for loop construction.
471 /// Can be empty.
/// - `yieldTiledValuesFn` is called to generate the loop body of the
///   innermost loop.
475 /// - `loops` is an in-out parameter into which the generated loops are
476 /// populated.
477 static LogicalResult generateLoopNestUsingForallOp(
478  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
479  ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
480  ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
482  assert(!loopRanges.empty() && "unexpected empty loop ranges");
483  assert(loopRanges.size() == tileSizes.size() &&
484  "expected as many tile sizes as loop ranges");
485  OpBuilder::InsertionGuard guard(rewriter);
486  SmallVector<OpFoldResult> offsets(loopRanges.size()),
487  sizes(loopRanges.size());
488 
489  std::optional<ArrayAttr> mappingAttr;
490  if (!mappingVector.empty())
491  mappingAttr = rewriter.getArrayAttr(mappingVector);
492 
493  scf::ForallOp forallOp;
494  bool useNumThreads = !numThreads.empty();
495 
496  if (useNumThreads) {
497  // Prune the zero numthreads.
498  SmallVector<OpFoldResult> nonZeroNumThreads;
499  for (auto nt : numThreads) {
500  if (isConstantIntValue(nt, 0))
501  continue;
502  nonZeroNumThreads.push_back(nt);
503  }
504  forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
505  destinationTensors, mappingAttr);
506  } else {
507  SmallVector<OpFoldResult> lbs, ubs, steps;
508  std::tie(lbs, ubs, steps) =
509  getLoopBounds(rewriter, loc, loopRanges, tileSizes);
510  forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
511  destinationTensors, mappingAttr);
512  }
513  loops.push_back(forallOp);
514 
515  rewriter.setInsertionPoint(forallOp.getTerminator());
516  destinationTensors = forallOp.getRegionOutArgs();
517 
518  SmallVector<Value> tiledResults;
519  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
520  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
521  destinationTensors, tiledResults, resultOffsets,
522  resultSizes)))
523  return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
524 
525  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
526  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
527  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
528  resultSizes)) {
529  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
530  rewriter.getIndexAttr(1));
531 
532  rewriter.create<tensor::ParallelInsertSliceOp>(
533  loc, tiledValue, destinationTensor, resultOffset, resultSize,
534  resultStride);
535  }
536  return success();
537 }
538 
/// Generate the tile-loop nest using the loop construct specified in `options`.
540 /// - `options`: Tiling options specified.
541 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
542 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
543 /// - `destinationTensors` are the init values to use for the outer most loop.
/// - `yieldTiledValuesFn` is called to generate the loop body of the
///   innermost loop.
547 /// - `loops` is an in-out parameter into which the generated loops are
548 /// populated.
549 static LogicalResult generateLoopNest(
550  RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
551  ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
552  ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
554  // If the tile sizes are all zero, no loops are generated. Just call the
555  // callback function to handle untiled case.
556  if (llvm::all_of(tileSizes, isZeroIndex)) {
557  SmallVector<Value> tiledResults;
558  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
559  return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
560  tiledResults, resultOffsets, resultSizes);
561  }
563  return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
564  destinationTensors, tiledBodyFn, loops);
565  }
568  rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
569  destinationTensors, tiledBodyFn, loops);
570  }
571  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
572 }
573 
574 static FailureOr<SmallVector<Value>>
575 createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
576  ArrayRef<OpFoldResult> tileSizes,
578  SmallVector<Value> initTensors;
579  Location loc = op->getLoc();
580  switch (options.reductionStrategy) {
582  if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
583  return failure();
584  return initTensors;
587  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
588  if (!redOp) {
589  return rewriter.notifyMatchFailure(
590  op, "PartialReductionOuterReduction tiling strategy is only supported"
591  "for operations implementing PartialReductionOpInterface");
592  }
593  // Get reduction dimensions.
594  // TODO: PartialReductionOpInterface should really query TilingInterface
595  // itself and find reduction dimensions.
596  SmallVector<int> reductionDims;
597  for (auto [idx, iteratorType] :
598  llvm::enumerate(op.getLoopIteratorTypes())) {
599  if (iteratorType == utils::IteratorType::reduction)
600  reductionDims.push_back(idx);
601  }
602  return redOp.generateInitialTensorForPartialReduction(
603  rewriter, loc, tileSizes, reductionDims);
604  }
605  default:
606  return rewriter.notifyMatchFailure(op,
607  "unhandled reduction tiling strategy");
608  }
609 }
610 
611 static FailureOr<TilingResult>
612 getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
613  ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
616  switch (options.reductionStrategy) {
618  return op.getTiledImplementation(rewriter, offsets, sizes);
621  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
622  if (!redOp) {
623  return rewriter.notifyMatchFailure(
624  op, "PartialReductionOuterReduction tiling strategy is only "
625  "supported for operations "
626  "implementing PartialReductionOpInterface");
627  }
628  // Get reduction dimensions.
629  // TODO: PartialReductionOpInterface should really query TilingInterface
630  // itself and find reduction dimensions.
631  SmallVector<int> reductionDims;
632  for (auto [idx, iteratorType] :
633  llvm::enumerate(op.getLoopIteratorTypes())) {
634  if (iteratorType == utils::IteratorType::reduction)
635  reductionDims.push_back(idx);
636  }
637  return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
638  offsets, sizes, reductionDims);
639  }
640  default:
641  return rewriter.notifyMatchFailure(op,
642  "unhandled reduction tiling strategy");
643  }
644 }
645 
646 static LogicalResult
647 getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
648  TilingInterface op, ArrayRef<OpFoldResult> offsets,
650  SmallVector<OpFoldResult> &resultOffset,
651  SmallVector<OpFoldResult> &resultSize,
653 
654  switch (options.reductionStrategy) {
656  return op.getResultTilePosition(rewriter, index, offsets, sizes,
657  resultOffset, resultSize);
660  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
661  if (!redOp) {
662  return rewriter.notifyMatchFailure(
663  op, "PartialReductionOuterReduction tiling strategy is only supported"
664  "for operations implementing PartialReductionOpInterface");
665  }
666  // Get reduction dimensions.
667  // TODO: PartialReductionOpInterface should really query TilingInterface
668  // itself and find reduction dimensions.
669  SmallVector<int> reductionDims;
670  for (auto [idx, iteratorType] :
671  llvm::enumerate(op.getLoopIteratorTypes())) {
672  if (iteratorType == utils::IteratorType::reduction)
673  reductionDims.push_back(idx);
674  }
675  return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
676  resultOffset, resultSize,
677  reductionDims);
678  }
679  default:
680  return rewriter.notifyMatchFailure(op,
681  "unhandled reduction tiling strategy");
682  }
683 }
684 
685 static FailureOr<MergeResult>
686 mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
687  ValueRange partialResults,
689  switch (options.reductionStrategy) {
691  // No need to merge results for reduction tiling strategy.
692  return MergeResult{{}, partialResults};
695  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
696  if (!redOp) {
697  return rewriter.notifyMatchFailure(
698  op, "PartialReductionOuterReduction tiling strategy is only "
699  "supported for operations "
700  "implementing PartialReductionOpInterface");
701  }
702  // Get reduction dimensions.
703  // TODO: PartialReductionOpInterface should really query TilingInterface
704  // itself and find reduction dimensions.
705  SmallVector<int> reductionDims;
706  for (auto [idx, iteratorType] :
707  llvm::enumerate(op.getLoopIteratorTypes())) {
708  if (iteratorType == utils::IteratorType::reduction)
709  reductionDims.push_back(idx);
710  }
711  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
712  reductionDims);
713  }
714  default:
715  return rewriter.notifyMatchFailure(op,
716  "unhandled reduction tiling strategy");
717  }
718 }
719 
/// Append the specified additional `newInitOperands` operands to the
/// loops existing `init` operands (or similar), and replace `loopOp` with
/// the new loop that has the additional init operands. The loop body of
/// this loop is moved over to the new loop. `yieldTiledValuesFn`
/// is called to get the new tiled values returned, and the offset
/// and sizes at which the tiled value is inserted into the
/// new region iter_args that correspond to the newly added init operands.
template <typename LoopType>
FailureOr<LoopLikeOpInterface>
yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
                               ValueRange newInitOperands,
                               YieldTiledValuesFn yieldTiledValuesFn) {
  // Primary template: reached only for loop types without an explicit
  // specialization, which are not supported.
  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
}
734 
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
///
/// Rebuilds the loop with `newInitOperands` appended to its init operands,
/// moves the original body over, asks `yieldTiledValuesFn` for the values to
/// yield into the new iter_args, and replaces `loopOp` with the new loop.
/// On failure of the callback the partially built loop is erased.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
    scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);

  // Create a replacement loop with the same bounds/step but with the
  // additional init operands appended after the existing ones.
  auto inits = llvm::to_vector(loopOp.getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForOp>(
      loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
      inits, [](OpBuilder &, Location, Value, ValueRange) {});

  // Move the loop body to the new op. The old block arguments (induction
  // variable + original iter_args) map to the leading arguments of the new
  // body; the trailing arguments are the newly added iter_args.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(yieldOp);

  // Ask the callback for the tiled values to yield into the new iter_args,
  // together with the offsets/sizes at which to insert them.
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  ValueRange newRegionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
                                newRegionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // Clean up the half-built loop before reporting failure.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
  }

  // Extend the yield: keep the original operands and append one insert_slice
  // per new iter_arg (unit strides) that places the tiled value at the
  // reported offset/size.
  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
       llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    Value insert = rewriter.create<tensor::InsertSliceOp>(
        yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
        resultStride);
    newYieldValues.push_back(insert);
  }

  // Swap in the extended terminator and replace the old loop's results with
  // the leading results of the new loop (the trailing results are new).
  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
788 
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
///
/// Rebuilds the forall with `newInitOperands` appended to its shared outs,
/// moves the original body over, asks `yieldTiledValuesFn` for the values to
/// write into the new outs via `tensor.parallel_insert_slice` in the
/// terminator, and replaces `loopOp` with the new loop. On failure of the
/// callback the partially built loop is erased.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
    scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);
  // Create a replacement forall with the same bounds/step/mapping but with
  // the additional outputs appended after the existing ones.
  auto inits = llvm::to_vector(loopOp.getOutputs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForallOp>(
      loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
      loopOp.getMixedStep(), inits, loopOp.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});

  // Move the region of the current block to the newly created op. The old
  // block arguments map to the leading arguments of the new body; the
  // trailing arguments are the newly added shared outs.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));

  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(terminator);
  // Ask the callback for the tiled values to store into the new outs,
  // together with the offsets/sizes at which to insert them.
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  ValueRange regionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
                                regionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // Clean up the half-built loop before reporting failure.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp,
                                       "failed to get yielded tiled values");
  }

  // Update the terminator: add one parallel_insert_slice (unit strides) per
  // new shared out, writing the tiled value at the reported offset/size.
  rewriter.setInsertionPointToEnd(terminator.getBody());

  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
           tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    rewriter.create<tensor::ParallelInsertSliceOp>(
        terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
        resultStride);
  }

  // Replace the old loop's results with the leading results of the new loop
  // (the trailing results correspond to the newly added outs).
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
841 
842 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
843 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
844 /// supported loop type.
845 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
846  LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
847  ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
849  loopLikeOp.getOperation())
850  .Case<scf::ForOp, scf::ForallOp>(
851  [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
853  loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
854  })
855  .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
856  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
857  });
858 }
859 
860 /// Method to add new init values to a loop nest. Updates `loops` in-place
861 /// with new loops that use the `newInitValues`. The outer-loops are updated
862 /// to yield the new result values of the inner loop. For the innermost loop,
863 /// the call back `getNewYields` is invoked to get the additional values to
864 /// yield form the innermost loop.
865 static LogicalResult addInitOperandsToLoopNest(
867  ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
868  SmallVector<scf::ForOp> newLoops;
869  if (loops.empty())
870  return success();
871  OpBuilder::InsertionGuard g(rewriter);
872  rewriter.setInsertionPoint(loops.front());
873 
874  SmallVector<Value> ivs;
875  for (auto &loop : loops.drop_back()) {
876  rewriter.setInsertionPoint(loop);
877 
878  // if loops.size() > 1 we assume that scf.for is used for the loops.
879  auto forLoop = cast<scf::ForOp>(loop.getOperation());
880 
881  // Create a new loop with the new init values for this loop.
882  SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
883  newInits.append(newInitValues.begin(), newInitValues.end());
884  auto newLoop = rewriter.create<scf::ForOp>(
885  forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
886  forLoop.getStep(), newInits,
887  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
888 
889  // Merge the body of the new loop with the body of the old loops.
890  SmallVector<Value> sourceBlockArgs;
891  sourceBlockArgs.push_back(newLoop.getInductionVar());
892  auto newRegionIterArgs = newLoop.getRegionIterArgs();
893  sourceBlockArgs.append(
894  newRegionIterArgs.begin(),
895  std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
896  rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
897  rewriter.replaceOp(
898  forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
899  loop = newLoop;
900  ivs.push_back(newLoop.getInductionVar());
901  newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
902  }
903 
904  // Update the loop body of the innermost loop to get new yield values.
905  LoopLikeOpInterface innerMostLoop = loops.back();
906  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
907  yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
908  getNewTiledYieldsFn);
909 
910  if (failed(newInnerMostLoop))
911  return innerMostLoop.emitOpError("failed to return additional yields");
912  loops.back() = newInnerMostLoop.value();
913 
914  // Make all other loops except the innermost loops yield the values returned
915  // by the inner loop.
916  for (auto [outerLoop, innerLoop] :
917  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
918  // Again assume that all the outer loops are scf.for operations.
919  auto outerForLoop = cast<scf::ForOp>(outerLoop);
920  auto outerLoopYield =
921  cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
922  SmallVector<Value> newYields =
923  llvm::to_vector(outerLoopYield.getOperands());
924  ValueRange additionalYields =
925  innerLoop->getResults().take_back(newInitValues.size());
926  newYields.append(additionalYields.begin(), additionalYields.end());
927  rewriter.setInsertionPoint(outerLoopYield);
928  rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
929  }
930  return success();
931 }
932 
933 /// Implementation of tiling transformation of `op` that implements the
934 /// `TilingInterface` using `scf.for` to iterate over the tiles.
935 FailureOr<scf::SCFTilingResult>
936 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
938  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
939  return failure();
940  }
941 
942  OpBuilder::InsertionGuard guard(rewriter);
943  rewriter.setInsertionPointAfter(op);
944 
945  // 1. Get the range of the loops that are represented by the operation.
946  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
947 
948  // 2. Materialize the tile sizes and/or number of threads;
949  SmallVector<OpFoldResult> tileSizes, numThreads;
950  std::tie(tileSizes, numThreads) =
951  getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
952 
953  // Check if it is safe to tile. This is hold over from previous iterations
954  // of tile to for-all. Consider dropping it.
956  checkSafeToTileToForall(op, tileSizes, numThreads);
957  }
958 
959  // 3. If there is an interchange specified, permute the iteration domain and
960  // the tile sizes.
961  SmallVector<int64_t> interchangeVector;
962  if (!options.interchangeVector.empty()) {
963  interchangeVector = fillInterchangeVector(options.interchangeVector,
964  iterationDomain.size());
965  assert(isPermutationVector(interchangeVector) &&
966  "expected interchange vector to be a permutation");
967 
968  applyPermutationToVector(iterationDomain, interchangeVector);
969  applyPermutationToVector(tileSizes, interchangeVector);
970  if (!numThreads.empty())
971  applyPermutationToVector(numThreads, interchangeVector);
972  }
973 
974  FailureOr<TilingResult> tilingResult;
975  // 4. Define the lambda function used later to generate the body of the
976  // innermost tiled loop.
977  YieldTiledValuesFn innerYieldTiledValuesFn =
978  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
979  ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
982  -> LogicalResult {
983  // 4a. Compute the `offsets` and `sizes` to use for tiling.
984  SmallVector<OpFoldResult> offsets, sizes;
985  std::tie(offsets, sizes) = getTileOffsetAndSizes(
986  rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
987 
988  // 4b. If interchange was provided, apply inverse of the interchange
989  // to get back the offsets/sizes in the order to be specified.
990  if (!interchangeVector.empty()) {
991  auto inversePermutation = invertPermutationVector(interchangeVector);
994  }
995 
996  // 5. Generate the tiled implementation within the inner most loop.
997 
998  // 5a. Clone the operation within the loop body.
999  auto clonedOp = cast<TilingInterface>(
1000  cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
1001 
1002  // 5b. Early return cloned op if tiling is not happening. We can not
1003  // return the original op because it could lead to `rewriter.replaceOp(op,
1004  // op->getResults())` and users would get crash.
1005  if (llvm::all_of(tileSizes, isZeroIndex)) {
1006  tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
1007  tilingResult =
1008  TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
1009  /*generatedSlices=*/{}};
1010  return success();
1011  }
1012 
1013  // 5c. Tile the cloned operation.
1014  tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
1015  offsets, sizes, options);
1016  if (failed(tilingResult)) {
1017  rewriter.eraseOp(clonedOp);
1018  return op.emitOpError("faild to tile operation");
1019  }
1020 
1021  // 5d. Delete the cloned operation.
1022  rewriter.eraseOp(clonedOp);
1023 
1024  // 5e. Compute the offsets at which the result values are to be inserted
1025  // back into its destinations.
1026  for (auto [index, tiledValue] :
1027  llvm::enumerate(tilingResult->tiledValues)) {
1028  tiledResults.push_back(tiledValue);
1029  SmallVector<OpFoldResult> resultOffset, resultSize;
1030  if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
1031  sizes, resultOffset, resultSize,
1032  options))) {
1033  for (auto op : tilingResult->tiledOps) {
1034  rewriter.eraseOp(op);
1035  }
1036  return rewriter.notifyMatchFailure(
1037  op, "failed to get slice of result produced");
1038  }
1039  resultOffsets.emplace_back(std::move(resultOffset));
1040  resultSizes.emplace_back(std::move(resultSize));
1041  }
1042 
1043  return success();
1044  };
1045 
1046  // 6. Find the destination tensors to use for the operation.
1047  FailureOr<SmallVector<Value>> maybeInits =
1048  createInitialTensorsForTiling(rewriter, op, tileSizes, options);
1049  if (failed(maybeInits)) {
1050  return rewriter.notifyMatchFailure(
1051  op, "unable to create initial tensors for tiling");
1052  }
1053  SmallVector<Value> &initTensors = maybeInits.value();
1054 
1055  // 7. Generate the tiled loops nest using the callback defined above.
1057  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
1058  tileSizes, numThreads, initTensors,
1059  innerYieldTiledValuesFn, loops)))
1060  return op.emitOpError("failed to generate tiling loops");
1061  assert(succeeded(tilingResult) &&
1062  "expected tiling result to be computed after loop generation");
1063 
1064  SmallVector<Value> partialResults;
1065  if (loops.empty()) {
1066  // If loops are empty, the tiled op is used as the replacement for the
1067  // untiled op.
1068  partialResults = tilingResult->tiledValues;
1069  } else {
1070  partialResults = llvm::map_to_vector(loops.front()->getResults(),
1071  [](OpResult r) -> Value { return r; });
1072  }
1073 
1074  FailureOr<MergeResult> mergeResult =
1075  mergeTilingResults(rewriter, op, partialResults, options);
1076  if (failed(mergeResult)) {
1077  return rewriter.notifyMatchFailure(
1078  op, "Failed to merge partial results from tiling");
1079  }
1080 
1081  return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
1082  mergeResult.value(),
1083  tilingResult->generatedSlices};
1084 }
1085 
1086 FailureOr<scf::SCFTilingResult>
1088  PartialReductionOpInterface op,
1089  ArrayRef<OpFoldResult> tileSizes) {
1092  options.setReductionTilingStrategy(SCFTilingOptions::ReductionTilingStrategy::
1093  PartialReductionOuterReduction);
1094  options.setTileSizes(tileSizes);
1095 
1096  TilingInterface tilingInterfaceOp =
1097  dyn_cast<TilingInterface>(op.getOperation());
1098  if (!tilingInterfaceOp) {
1099  return b.notifyMatchFailure(
1100  op,
1101  "Operation implementing PartialReductionOpInterface should implement "
1102  "TilingInterface");
1103  }
1104 
1105  return tileUsingSCF(b, tilingInterfaceOp, options);
1106 }
1107 
1108 //===----------------------------------------------------------------------===//
1109 // tileConsumerAndFuseProducersUsingSCF implementation.
1110 //===----------------------------------------------------------------------===//
1111 
1112 /// Return the untiled producer whose slice is used in a tiled consumer. The
1113 /// method traverses the tile loop nest (`loops`) if needed, and returns the
1114 /// `iter_args` of the outer most that is encountered. Traversing the
1115 /// iter_args indicates that this is a destination operand of the consumer. If
1116 /// there was no loop traversal needed, the second value of the returned tuple
1117 /// is empty.
1118 static std::tuple<OpResult, std::optional<OpOperand *>>
1121  std::optional<OpOperand *> destinationIterArg;
1122  auto loopIt = loops.rbegin();
1123  while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
1124  auto loop = *loopIt;
1125  if (iterArg.getOwner()->getParentOp() != loop)
1126  break;
1127  source = loop.getTiedLoopInit(iterArg);
1128  loopIt++;
1129  }
1130  if (loopIt == loops.rend())
1131  destinationIterArg = source;
1132  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
1133 }
1134 
1135 /// Implementation of fusing producer of a single slice by computing the
1136 /// slice of the producer in-place.
1137 std::optional<scf::SCFFuseProducerOfSliceResult>
1139  RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1141  // 1. Get the producer of the source (potentially walking through
1142  // `iter_args` of nested `scf.for`)
1143  auto [fusableProducer, destinationInitArg] =
1144  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1145  loops);
1146  if (!fusableProducer)
1147  return std::nullopt;
1148  unsigned resultNumber = fusableProducer.getResultNumber();
1149 
1150  OpBuilder::InsertionGuard g(rewriter);
1151  rewriter.setInsertionPoint(candidateSliceOp);
1152 
1153  // 2. Clone the fused producer
1154  // 2a. Compute the destination operands to use for the cloned operation.
1155  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
1156  Operation *fusableProducerOp = fusableProducer.getOwner();
1157  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
1159  rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
1160  origDestinationTensors)))
1161  return std::nullopt;
1162 
1163  clonedOpDestinationTensors = origDestinationTensors;
1164  if (destinationInitArg &&
1165  isa<DestinationStyleOpInterface>(fusableProducerOp)) {
1166  // 2b. If the producer is also destination style, then to maintain the
1167  // destination passing style, update the destination of the producer to be
1168  // the source of the slice.
1169  clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
1170  }
1171  // 2c. Clone the fused producer.
1172  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
1173  rewriter, fusableProducerOp, clonedOpDestinationTensors);
1174  // 2d. Update the source of the candidateSlice to be the cloned producer.
1175  // Easier to just clone the slice with different source since
1176  // replacements and DCE of cloned ops becomes easier
1177  SmallVector<Value> candidateSliceOpOperands =
1178  llvm::to_vector(candidateSliceOp->getOperands());
1179  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
1180  tensor::ExtractSliceOp clonedCandidateSliceOp =
1181  mlir::clone(rewriter, candidateSliceOp,
1182  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
1183 
1184  // 3. Generate the tiled implementation of the producer of the source
1185  FailureOr<TilingResult> tileAndFuseResult =
1187  rewriter, clonedCandidateSliceOp,
1188  clonedProducerOp->getResult(resultNumber));
1189  if (failed(tileAndFuseResult))
1190  return std::nullopt;
1191  // Note: Do not delete the candidateSliceOp, since its passed in from the
1192  // caller.
1193  rewriter.replaceAllUsesWith(candidateSliceOp,
1194  tileAndFuseResult->tiledValues[0]);
1195  rewriter.eraseOp(clonedCandidateSliceOp);
1196  rewriter.eraseOp(clonedProducerOp);
1197 
1198  // 3. If the slice is for a destination operand, for example,
1199  //
1200  // ```mlir
1201  // %0 = linalg.init
1202  // %1 = linalg.fill .. outs(%0 : )
1203  // %2 = scf.for .. iter_args(%arg0 = %1) {
1204  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1205  // %4 = tensor.extract_slice %arg1 [..]
1206  // .. = linalg.matmul .. outs(%4 : )
1207  // }
1208  // }
1209  // ```
1210  //
1211  // the IR is currently
1212  //
1213  // ```
1214  // %0 = linalg.init
1215  // %1 = linalg.fill
1216  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
1217  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
1218  // %4 = tensor.extract_slice %arg1[..]
1219  // %5 = linalg.fill .. outs(%4 : )
1220  // .. = linalg.matmul .. outs(%5 : )
1221  // }
1222  // }
1223  // ```
1224  //
1225  // The untiled `linalg.fill` is still used as the `init_value` since it
1226  // was originally a destination operand of the untiled `linalg.matmul`.
1227  // When fusing an operand that is a destination operand, the iter_arg of
1228  // the outer most loop should be changed to use the destination of the
1229  // fused operation. With this the IR will be.
1230  //
1231  // ```
1232  // %0 = linalg.init
1233  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
1234  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
1235  // %3 = tensor.extract_slice %arg1[..]
1236  // %4 = linalg.fill .. outs(%3 : )
1237  // .. = linalg.matmul .. outs(%4 : )
1238  // }
1239  // }
1240  // ```
1241  if (destinationInitArg &&
1242  isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
1243  loops.front()
1244  ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
1245  .set(origDestinationTensors[resultNumber]);
1246  }
1248  fusableProducer, tileAndFuseResult->tiledValues[0],
1249  tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
1250 }
1251 
1252 /// Reconstruct the fused producer from within the tiled-and-fused code.
1253 FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
1254  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
1255  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
1257  ArrayRef<unsigned> yieldResultNumber) {
1258  if (loops.empty())
1259  return success();
1260 
1261  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
1262  *tiledOwner = fusedProducerInfo.tiledOps[0];
1263 
1264  Location loc = originalOwner->getLoc();
1265  // a. collect all init Value to be appended
1266  SmallVector<unsigned> initNumberList =
1267  yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
1268  0, originalOwner->getNumResults()))
1269  : llvm::to_vector(yieldResultNumber);
1270  SmallVector<Value> initValueList;
1271  for (const auto &resultNumber : initNumberList) {
1272  FailureOr<Value> initValue = tensor::getOrCreateDestination(
1273  rewriter, loc, originalOwner->getResult(resultNumber));
1274  if (succeeded(initValue)) {
1275  initValueList.push_back(initValue.value());
1276  } else {
1277  return failure();
1278  }
1279  }
1280 
1281  SmallVector<Operation *> generatedSlices;
1282  YieldTiledValuesFn newYieldValuesFn =
1283  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
1284  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
1286  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
1287  OpBuilder::InsertionGuard g(innerRewriter);
1288 
1289  // get sliceOp tile information
1290  SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
1291  sliceSizes = sliceOp.getMixedSizes();
1292 
1293  // expect all strides of sliceOp being 1
1294  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1295  return !isConstantIntValue(ofr, 1);
1296  }))
1297  return failure();
1298 
1299  unsigned sliceResultNumber =
1300  fusedProducerInfo.origProducer.getResultNumber();
1301 
1302  auto tilableOp = cast<TilingInterface>(originalOwner);
1303  // b. get iterDomain Offset and Sizes based on sliceOp tile
1304  SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1305  // skip tensor.pack/unpack/pad, which expects single opResult
1306  if (tilableOp->getNumResults() > 1 &&
1307  failed(tilableOp.getIterationDomainTileFromResultTile(
1308  rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1309  iterDomainOffset, iterDomainSizes))) {
1310  // In theory, it is unnecessary to raise an error here. Actually
1311  // although it fails to reconstruct the result tensor, it should not
1312  // broke current fusion anyway. The reason why we must return failure
1313  // currently is that the callback function `newYieldValuesFn` will be
1314  // called after new init operand(s) has already been appended. It will
1315  // take more refactoring to make sure the init operands are added
1316  // consistently in the future. For more details, please refer to:
1317  // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1318  return failure();
1319  }
1320 
1321  // c. calculate offsets and sizes info of all OpResults respectively based
1322  // on iteration Domain Tile
1323  SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1324  for (const auto &resultNumber : initNumberList) {
1325  if (resultNumber == sliceResultNumber) {
1326  offsetList.push_back(sliceOffset);
1327  sizesList.push_back(sliceSizes);
1328  } else {
1329  assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1330  // infer result tile according to the iteration domain tile
1331  SmallVector<OpFoldResult> offset, sizes;
1332  if (failed(tilableOp.getResultTilePosition(
1333  rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1334  offset, sizes))) {
1335  return failure();
1336  }
1337  offsetList.push_back(offset);
1338  sizesList.push_back(sizes);
1339  }
1340  }
1341 
1342  // d. create `extract_slice` for `iter_args` for DPS operation if
1343  // necessary
1344  if (auto tiledDestStyleOp =
1345  dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1346  rewriter.setInsertionPoint(tiledDestStyleOp);
1347  for (const auto &&[index, newRegionArg] :
1348  llvm::enumerate(newRegionIterArgs)) {
1349  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1350  loc, newRegionArg, offsetList[index], sizesList[index],
1351  SmallVector<OpFoldResult>(offsetList[index].size(),
1352  rewriter.getIndexAttr(1)));
1353  generatedSlices.push_back(destSlice);
1354  unsigned resultNumber = initNumberList[index];
1355  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1356  tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1357  });
1358  }
1359  }
1360 
1361  // e. prepare tiled offset and sizes for later `insert_slice` creation by
1362  // caller
1363  Block *block = rewriter.getInsertionPoint()->getBlock();
1364  rewriter.setInsertionPoint(block->getTerminator());
1365  for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1366  tiledResult.push_back(tiledOwner->getResult(resultNumber));
1367  tiledOffset.emplace_back(offsetList[index]);
1368  tiledSizes.emplace_back(sizesList[index]);
1369  }
1370  return success();
1371  };
1372 
1373  if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
1374  newYieldValuesFn))) {
1375  return failure();
1376  }
1377  return generatedSlices;
1378 }
1379 
1380 namespace {
1381 
1382 //===----------------------------------------------------------------------===//
1383 // SliceTrackingListener
1384 //===----------------------------------------------------------------------===//
1385 
/// This class is a listener for tracking the insertion and removal of
/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
/// fusion algorithm to apply cleanup patterns in between fusion steps.
class SliceTrackingListener : public RewriterBase::Listener {
public:
  explicit SliceTrackingListener(
      std::optional<FrozenRewritePatternSet> patterns);
  SliceTrackingListener() = default;

  /// Adds the given list of operations to the worklist, and if present,
  /// applies the list of `patterns` to the newly added operations. This only
  /// processes the given operations and any newly inserted ones by the
  /// pattern set.
  LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);

  /// Add to the new operation worklist if it is an extract_slice.
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override;

  /// Shared helper for operation removal from the worklist.
  void removeOp(Operation *op);

  /// Remove the operation from the worklist.
  void notifyOperationErased(Operation *op) override;

  /// Remove the operation from the worklist.
  void notifyOperationReplaced(Operation *op, ValueRange replacement) override;

  /// The worklist for this transformation keeps track of the slices to visit
  /// next for fusion.
  // Public by design: the fusion driver pops candidates directly from here.
  std::deque<tensor::ExtractSliceOp> worklist;

private:
  /// Optional pattern set to apply when adding new operations to the
  /// worklist.
  std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
};
1423 
1424 SliceTrackingListener::SliceTrackingListener(
1425  std::optional<FrozenRewritePatternSet> p) {
1426  patterns = std::move(p);
1427 }
1428 
1429 LogicalResult
1430 SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1431  for (Operation *op : ops) {
1432  if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1433  worklist.push_back(slice);
1434  }
1435 
1436  if (!patterns)
1437  return success();
1438 
1440  config.listener = this;
1442  return applyOpPatternsGreedily(ops, patterns.value(), config);
1443 }
1444 
1445 void SliceTrackingListener::notifyOperationInserted(
1446  Operation *op, OpBuilder::InsertPoint previous) {
1447  auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1448  if (!slice)
1449  return;
1450  worklist.push_back(slice);
1451 }
1452 
// Scan the worklist for the given op and remove it if present. The
// expectation is for the worklist to be small and for removal to be
// relatively rare.
void SliceTrackingListener::removeOp(Operation *op) {
  if (!isa<tensor::ExtractSliceOp>(op))
    return;
  // Linear scan is fine: the worklist is expected to stay small.
  auto iter = llvm::find_if(
      worklist, [&](tensor::ExtractSliceOp slice) { return slice == op; });
  if (iter != worklist.end())
    worklist.erase(iter);
}
1470 
/// Drop an erased op from the worklist (no-op unless it is an extract_slice).
void SliceTrackingListener::notifyOperationErased(Operation *op) {
  removeOp(op);
}
1474 
/// Drop a replaced op from the worklist (no-op unless it is an
/// extract_slice); replacement values themselves are not re-added here.
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
                                                    ValueRange replacement) {
  removeOp(op);
}
1479 
1480 //===----------------------------------------------------------------------===//
1481 // ReplacementListener
1482 //===----------------------------------------------------------------------===//
1483 
1484 /// Listener that tracks updates replacements for values which can be mutated.
1485 /// This listener runs on top of the existing listener for the rewriter,
1486 /// to make sure external users can still run listeners.
1487 class ReplacementListener : public RewriterBase::ForwardingListener {
1488 public:
1489  ReplacementListener(DenseMap<Value, Value> &replacements,
1490  OpBuilder::Listener *listener)
1491  : ForwardingListener(listener), replacements(replacements) {}
1492 
1493  void updateReplacementValues(ValueRange origValues,
1494  ValueRange replaceValues) {
1495  // This can probably be written better, but just iterates over the map
1496  // and the new replacements for now.
1497  for (auto &[key, val] : replacements) {
1498  for (auto [orig, replace] : llvm::zip_equal(origValues, replaceValues)) {
1499  if (val == orig) {
1500  val = replace;
1501  }
1502  }
1503  }
1504  }
1505 
1506  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
1507  ForwardingListener::notifyOperationReplaced(op, newOp);
1508  updateReplacementValues(op->getResults(), newOp->getResults());
1509  }
1510 
1511  void notifyOperationReplaced(Operation *op, ValueRange values) override {
1512  ForwardingListener::notifyOperationReplaced(op, values);
1513  updateReplacementValues(op->getResults(), values);
1514  }
1515 
1516 private:
1517  DenseMap<Value, Value> &replacements;
1518 };
1519 
1520 } // namespace
1521 
1522 /// Implementation of tile consumer and fuse producer greedily.
1523 FailureOr<scf::SCFTileAndFuseResult>
1525  RewriterBase &rewriter, TilingInterface consumer,
1527  // This transformation is only valid for ops that return values (i.e. not
1528  // valid to use with operations that have memref operands).
1529  if (!consumer->getNumResults()) {
1530  return rewriter.notifyMatchFailure(
1531  consumer, "invalid pattern for op with no results");
1532  }
1533 
1534  // 1. First tile the consumer.
1535  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1536  llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1537 
1538  FailureOr<scf::SCFTilingResult> tilingResult =
1539  tileUsingSCF(rewriter, consumer, options.tilingOptions);
1540 
1541  if (failed(tilingResult))
1542  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1543  for (auto *tiledOp : tilingResult->tiledOps)
1544  tiledAndFusedOps.insert(tiledOp);
1545 
1546  DenseMap<Value, Value> replacements;
1547  for (auto [origVal, replacement] : llvm::zip_equal(
1548  consumer->getResults(), tilingResult->mergeResult.replacements)) {
1549  replacements[origVal] = replacement;
1550  }
1551 
1552  // If there are no loops generated, fusion is immaterial.
1553  auto &loops = tilingResult->loops;
1554  if (loops.empty()) {
1555  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1556  replacements};
1557  }
1558 
1559  // Since the loop gets potentially replaced during fusion, we need to track
1560  // the mutation of replacement values. To do this, we attach a listener to
1561  // update the replacements as they happen.
1562  OpBuilder::Listener *previousListener = rewriter.getListener();
1563  auto resetListener =
1564  llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
1565  ReplacementListener replaceListener(replacements, previousListener);
1566  rewriter.setListener(&replaceListener);
1567 
1568  // 2. Typically, the operands of the tiled operation are slices of the
1569  // operands of the untiled operation. These are expressed in IR using
1570  // `tensor.extract_slice` operations with source being the operands of
1571  // the untiled operation. Create a worklist of these
1572  // `tensor.extract_slice` operations. If the producers of the source of
1573  // the `tensor.extract_slice` can be tiled such that the tiled value is
1574  // generated in-place, that effectively tiles + fuses the operations.
1575  struct WorklistItem {
1576  tensor::ExtractSliceOp candidateSlice;
1578  };
1579 
1580  SliceTrackingListener sliceTracker =
1581  SliceTrackingListener(options.cleanupPatterns);
1582 
1583  if (failed(
1584  sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1585  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1586  }
1587  OpBuilder::InsertionGuard g(rewriter);
1588  while (!sliceTracker.worklist.empty()) {
1589  auto candidateSlice = sliceTracker.worklist.front();
1590  sliceTracker.worklist.pop_front();
1591 
1592  auto [fusableProducer, destinationInitArg] =
1593  getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1594  loops);
1595  if (!fusableProducer)
1596  continue;
1597 
1598  std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1599  options.fusionControlFn(candidateSlice, fusableProducer,
1600  destinationInitArg.has_value());
1601  if (!controlFnResult)
1602  continue;
1603 
1604  WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
1605 
1606  // The operands of the fused producer might themselved be slices of
1607  // values produced by operations that implement the `TilingInterface`.
1608  // Add these operations to the worklist.
1609  std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1610  tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
1611  loops);
1612  if (!fusedResult)
1613  continue;
1614 
1615  SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1616 
1617  if (worklistItem.controlFnResult.yieldProducerReplacement) {
1618  // Reconstruct and yield all opResult of fusableProducerOp by default.
1619  // The caller can specific which one to yield by designating optional
1620  // argument named `yieldResultNumber` of
1621  // `yieldReplacementForFusedProducer`.
1622  Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
1623  FailureOr<SmallVector<Operation *>> newSlices =
1625  worklistItem.candidateSlice,
1626  fusedResult.value(), loops);
1627  if (failed(newSlices)) {
1628  return rewriter.notifyMatchFailure(
1629  fusableProducerOp, "failed to replacement value for this "
1630  "operation from within the tiled loop");
1631  }
1632  worklistCandidates.append(newSlices.value());
1633  for (auto [index, result] :
1634  llvm::enumerate(fusableProducerOp->getResults())) {
1635  replacements[result] = loops.front()->getResult(
1636  loops.front()->getNumResults() -
1637  fusableProducerOp->getNumResults() + index);
1638  }
1639  }
1640  if (Operation *tiledAndFusedOp =
1641  fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1642  fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1643  tiledAndFusedOps.insert(tiledAndFusedOp);
1644  }
1645 
1646  if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1647  return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1648  }
1649  }
1650 
1651  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1652  replacements};
1653 }
1654 
1655 //===----------------------------------------------------------------------===//
1656 // tileAndFuseConsumerUsingSCF implementation.
1657 //===----------------------------------------------------------------------===//
1658 
1659 /// A utility function that checks whether the only use of the result of a
1660 /// tensor.insert_slice op is in a scf.yield op.
1661 static LogicalResult
1662 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1663  Value result = candidateSliceOp.getResult();
1664  Value::use_range uses = result.getUses();
1665  if (!llvm::hasSingleElement(uses)) {
1666  LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1667  return failure();
1668  }
1669  OpOperand &operandUse = (*uses.begin());
1670  Operation *userOp = operandUse.getOwner();
1671  if (!isa<scf::YieldOp>(userOp)) {
1672  LLVM_DEBUG(llvm::dbgs()
1673  << "Expected scf.yield to be the only user, but got -> "
1674  << (*userOp));
1675  return failure();
1676  }
1677  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1678  LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1679  "be in the same block\n");
1680  return failure();
1681  }
1682  return success();
1683 }
1684 
1685 /// An utility to get the first user of the given loopOp. If any of user stay
1686 /// in different block of loopOp, return failure.
1687 static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
1688  if (!isa<LoopLikeOpInterface>(loopOp))
1689  return failure();
1690  Operation *firstUserOfLoop = nullptr;
1691  for (Operation *userOp : loopOp->getUsers()) {
1692  // `ParallelInsertSlice` located inside `InParallelOp` has no same parent
1693  // block with any other types of operation. Thus, just redirecting to its
1694  // parent `InParallelOp`. E.g.
1695  //
1696  // ```
1697  // %1 = scf.for {
1698  // ...
1699  // }
1700  // %2 = consumerOp ins(%1, ...)
1701  // scf.forall.in_parallel {
1702  // tensor.parallel_insert_slice %1
1703  // }
1704  // ```
1705  // where `InParallelOp` but not `ParallelInsertSlice` stays in the same
1706  // same block with `consumerOp`.
1707  if (isa<tensor::ParallelInsertSliceOp>(userOp))
1708  userOp = userOp->getParentOfType<scf::InParallelOp>();
1709 
1710  if (loopOp->getBlock() != userOp->getBlock())
1711  return failure();
1712 
1713  if (!firstUserOfLoop || userOp->isBeforeInBlock(firstUserOfLoop))
1714  firstUserOfLoop = userOp;
1715  }
1716  return firstUserOfLoop;
1717 }
1718 
1719 /// This utility currently checks whether the first userOp of loop is NOT
1720 /// before the last defineOp of consumer operand. Because that we need to move
1721 /// the whole loop structure right before the `firstUserOfLoop`. This utility
1722 /// thus helps ensuring that no invalid IR is formed, i.e. no backward slice
1723 /// of consumerOp is dominated by the `firstUserOfLoop`. Saying that:
1724 ///
1725 /// ```
1726 /// %0 = scf.for() {
1727 /// ...
1728 /// }
1729 /// ...
1730 /// %1 = firstUserOfLoop(%0)
1731 /// ...
1732 /// %2 = lastDefOfConsumerOperand
1733 /// ...
1734 /// %3 = consumerOp(%2)
1735 /// ```
1736 ///
1737 /// If the `firstUserOfLoop` is before `lastDefOfConsumerOperand`, then it
1738 /// would be invalid to move the `loopOp` right before the `firstUserOfLoop`,
1739 /// a.k.a. use-def chain violation:
1740 ///
1741 /// ```
1742 /// %0:2 = scf.for() {
1743 /// // use before define error
1744 /// %3 = tiledConsumerOp(%2)
1745 /// }
1746 /// %1 = firstUserOfLoop(%0)
1747 /// ...
1748 /// %2 = lastDefOfConsumerOperand
1749 /// ```
1750 ///
1751 /// @param loopOp: loop operation
1752 /// @param consumerOp: consumer operation
1753 /// @param reorderOperations: the flag controls whether to reorder the
1754 /// backward slice w.r.t. the defineOp of `consumerOp` operands.
1755 /// @return: computed backward slice of consumerOp, but excluding those
1756 /// already dominates `firstUserOfLoop`.
1757 static FailureOr<llvm::SetVector<Operation *>>
1759  bool reorderOperations) {
1760  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1761  if (failed(firstUserOfLoop))
1762  return failure();
1763 
1765  DominanceInfo dominanceInfo;
1766  options.inclusive = true;
1767  options.omitBlockArguments = true;
1768  bool includeLoopOp = false;
1769  options.filter = [&](Operation *op) {
1770  if (op == loopOp) {
1771  includeLoopOp = true;
1772  return false;
1773  }
1774  // Cut off the slice to not include any operation that already dominates
1775  // firstUserOfLoop.
1776  return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
1777  };
1779  for (auto operand : consumerOp->getOperands()) {
1780  getBackwardSlice(operand, &slice, options);
1781  }
1782 
1783  if (!slice.empty()) {
1784  // If consumerOp has one producer, which is also the user of loopOp.
1785  // E.g.
1786  // ```
1787  // %0 = %loopOp
1788  // %1 = consumerOp1 ins(%0)
1789  // %2 = consumerOp2 ins(%0, %1)
1790  // ```
1791  // We can not fuse consumerOp2 into loopOp due to UD chain, unless
1792  // consumerOp1 has already been fused into loopOp before.
1793  if (includeLoopOp || !reorderOperations)
1794  return failure();
1795  }
1796 
1797  return slice;
1798 }
1799 
1800 /// Fetches the OpOperand of the first valid user (and use) of the value `val`
1801 /// which implements `TilingInterface` and `DestinationStyleOpInterface`.
1802 /// Returns failure otherwise.
1803 static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
1804  Operation *loopOp,
1805  unsigned resultNumber) {
1806  if (!isa<LoopLikeOpInterface>(loopOp))
1807  return failure();
1808  Value val = loopOp->getResult(resultNumber);
1809  Block *loopBlock = loopOp->getBlock();
1810  for (OpOperand &opOperand : val.getUses()) {
1811  Operation *consumerOp = opOperand.getOwner();
1812  // Step 1. Check if the user is tilable.
1813  if (!isa<TilingInterface>(consumerOp) ||
1814  !isa<DestinationStyleOpInterface>(consumerOp)) {
1815  // TODO: We have to init result of consumer before scf.for, use
1816  // DestinationStyleOpInterface to get result shape from init for now.
1817  // Add support for other op such as op has InferTypeOpInterface.
1818  continue;
1819  }
1820  // Step 2. Check if user stay in the same block.
1821  if (loopBlock != consumerOp->getBlock())
1822  continue;
1823  // Step 3. Check if user has succeeding user. Otherwise, it usually
1824  // represents already tiled.
1825  if (consumerOp->use_empty())
1826  continue;
1827  // Step 4. Check assumption for loop with `reorderOperations` enabled.
1828  FailureOr<llvm::SetVector<Operation *>> slice =
1829  checkAssumptionForLoop(loopOp, consumerOp, true);
1830  if (failed(slice))
1831  continue;
1832  // Step 5. If backward sice is not empty, move them before
1833  // firstUserOfLoop.
1834  if (!slice->empty()) {
1835  mlir::topologicalSort(*slice);
1836  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
1837  assert(succeeded(firstUserOfLoop) && "First user of loop is not found");
1838  for (auto op : *slice) {
1839  rewriter.moveOpBefore(op, *firstUserOfLoop);
1840  }
1841  }
1842  return &opOperand;
1843  }
1844  return failure();
1845 }
1846 
1847 /// Find the perfectly nested loops outside of given loop(included) sorted
1848 /// from outer to inner.
1849 ///
1850 /// E.g.
1851 ///
1852 /// ```
1853 /// %0 = scf.for()
1854 /// %1 = scf.for()
1855 /// %2 = scf.for()
1856 /// %3 = ...
1857 /// yield %3
1858 /// yield %2
1859 /// yield %1
1860 /// ```
1861 ///
1862 /// This function will return three perfectly nested loops: %0 + %1 + %2, when
1863 /// target inner loop is %2.
1866  SmallVector<scf::ForOp> nestLoops = {loop};
1867  auto outerLoop = dyn_cast<scf::ForOp>(loop->getParentOp());
1868 
1869  // Check if it is the ForOp that yield the result of inner loop.
1870  auto isForOpYieldResultOfInnerLoop =
1871  [](scf::ForOp outerLoop) -> LogicalResult {
1872  Block *body = outerLoop.getBody();
1873  if (!llvm::hasSingleElement(body->without_terminator()))
1874  return failure();
1875  auto yieldOp = cast<scf::YieldOp>(body->getTerminator());
1876  auto innerForOp = dyn_cast<scf::ForOp>(body->front());
1877  if (!innerForOp)
1878  return failure();
1879  // All of innerForOp results should be yielded.
1880  return success(innerForOp->getNumResults() == yieldOp->getNumOperands());
1881  };
1882 
1883  while (outerLoop && succeeded(isForOpYieldResultOfInnerLoop(outerLoop))) {
1884  nestLoops.push_back(outerLoop);
1885  outerLoop = dyn_cast<scf::ForOp>(outerLoop->getParentOp());
1886  }
1887  // sorted from outer to inner
1888  return {nestLoops.rbegin(), nestLoops.rend()};
1889 }
1890 
1891 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
1892 /// tensor.insert_slice. This function makes the following assumptions :
1893 /// 1. tensor.insert_slice has scf.yield as its only user.
1894 /// 2. scf.for's corresponding result has only one use.
1895 static FailureOr<OpOperand *>
1897  tensor::InsertSliceOp candidateSliceOp) {
1898  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1899  return failure();
1900  Value sliceResult = candidateSliceOp.getResult();
1901  // Step 1. Fetch the corresponding output.
1902  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1903  unsigned resultNumber = yieldOpOperand.getOperandNumber();
1904  // Step 2. Check containing op is scf.for.
1905  Operation *containingOp = candidateSliceOp->getParentOp();
1906  auto forOp = dyn_cast<scf::ForOp>(containingOp);
1907  if (!forOp)
1908  return failure();
1909  scf::ForOp topLevelForOp = getPerfectlyNestedLoopsOutsideOf(forOp).front();
1910 
1911  return getConsumerFromLoopUses(rewriter, topLevelForOp, resultNumber);
1912 }
1913 
1914 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1915 /// by a tensor.parallel_insert_slice.
1916 static FailureOr<OpOperand *>
1918  tensor::ParallelInsertSliceOp candidateSliceOp) {
1919  // Step 1. Fetch the corresponding output
1920  Value sliceDest = candidateSliceOp.getDest();
1921  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1922  if (!iterArg)
1923  return failure();
1924  Operation *containingOp = iterArg.getOwner()->getParentOp();
1925  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1926  return failure();
1927  // Step 2. Check that the containing op is scf.forall.
1928  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1929  if (!forallOp)
1930  return failure();
1931  unsigned resultNumber =
1932  forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg))
1933  .getResultNumber();
1934 
1935  return getConsumerFromLoopUses(rewriter, containingOp, resultNumber);
1936 }
1937 
1938 /// A utility to fetch an untiled consumer of
1939 /// tensor.insert_slice/tensor.parallel_insert_slice.
1940 static FailureOr<OpOperand *>
1942  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1943  return getUntiledConsumerFromSlice(rewriter, insertSlice);
1944  } else if (auto parallelInsertSlice =
1945  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1946  return getUntiledConsumerFromSlice(rewriter, parallelInsertSlice);
1947  } else {
1948  return failure();
1949  }
1950 }
1951 
1952 /// Implementation of fusing consumer of a single slice by computing the
1953 /// slice of the consumer in-place for scf loop.
1954 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1956  Operation *candidateSliceOp) {
1957  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1958  candidateSliceOp))
1959  return failure();
1960 
1961  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1962 
1963  // 1. Get the consumer of scf.for for the result yielded by
1964  // tensor.insert_slice/parallel_insert_slice.
1965  FailureOr<OpOperand *> maybeConsumerOpOperand =
1966  getUntiledConsumerFromSlice(rewriter, candidateSliceOp);
1967  if (failed(maybeConsumerOpOperand)) {
1968  return rewriter.notifyMatchFailure(candidateSliceOp,
1969  "could not fetch consumer to fuse");
1970  }
1971  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1972  Operation *consumerOp = consumerOpOperand->getOwner();
1973  unsigned operandNumber = consumerOpOperand->getOperandNumber();
1974  unsigned resultNumber = 0;
1975  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1976  resultNumber = producerResult.getResultNumber();
1977  } else {
1978  return rewriter.notifyMatchFailure(
1979  consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1980  }
1981 
1982  // There are two possible cases regarding `oldLoopOp` here:
1983  // 1. single `scf.forall` or `scf.for`.
1984  // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
1985  // top-level loop is the outer-most one of these nested loops.
1986  LoopLikeOpInterface innerMostLoop =
1987  candidateSliceOp->getParentOfType<LoopLikeOpInterface>();
1989  if (isInsertSliceOp) {
1990  nestedLoops = llvm::map_to_vector(
1992  cast<scf::ForOp>(innerMostLoop.getOperation())),
1993  [](scf::ForOp forOp) {
1994  return cast<LoopLikeOpInterface>(forOp.getOperation());
1995  });
1996  } else {
1997  nestedLoops = {innerMostLoop};
1998  }
1999 
2000  LoopLikeOpInterface outerMostLoop = nestedLoops.front();
2001 
2002  // Check assumption for loop with `reorderOperations` disabled.
2003  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
2004  return rewriter.notifyMatchFailure(
2005  outerMostLoop, "the first user of loop should not dominate any define "
2006  "of consumer operand(s)");
2007  }
2008 
2009  OpBuilder::InsertionGuard g(rewriter);
2010 
2011  // 2. Check consumer is not using scf loop's output as init.
2012  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
2013  if (!dstOp)
2014  return rewriter.notifyMatchFailure(consumerOp,
2015  "consumer op is not DPS operation");
2016  SmallVector<Value> dpsInits =
2017  llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
2018  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
2019  return rewriter.notifyMatchFailure(
2020  consumerOp,
2021  "consumer op taking the result of scf.for as init is not supported");
2022  }
2023  SmallVector<Value> newInits = dpsInits;
2024 
2025  Location loc = outerMostLoop->getLoc();
2026 
2027  // 3. Move the whole loop structure right before firstUserOfLoop, the
2028  // dominance should be already ensured by `checkAssumptionForLoop`.
2029  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
2030  if (failed(firstUserOfLoop)) {
2031  return rewriter.notifyMatchFailure(
2032  outerMostLoop, "could not find the first user of outer most loop");
2033  }
2034  rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);
2035 
2036  // 4. Set insertion point before terminator op of the loop and create a new
2037  // tensor.insert_slice. In the scf.for case this is a clone of the
2038  // candidateSliceOp whereas in the scf.forall case this is created from the
2039  // operands of tensor.parallel_insert_slice.
2040  tensor::InsertSliceOp clonedInsertSliceOp;
2041  if (auto sliceOp =
2042  dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2043  auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
2044  rewriter.setInsertionPoint(newForallOp.getTerminator());
2045  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
2046  loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
2047  sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
2048  } else {
2049  rewriter.setInsertionPoint(candidateSliceOp);
2050  clonedInsertSliceOp =
2051  cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
2052  }
2053 
2054  // 5.a. Clone consumer op.
2055  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));
2056 
2057  // 5.b. Replace all uses of the loop result with the result of the cloned
2058  // tensor.insert_slice.
2059  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
2060  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
2061  operandToReplace.set(clonedInsertSliceOp.getResult());
2062  });
2063 
2064  // 6. Perform tiling of the cloned consumer and replace the operand at
2065  // `operandNumber` with the source of the cloned tensor.insert_slice op.
2066  auto ossSliceOp =
2067  cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
2068  FailureOr<TilingResult> tileAndFuseResult =
2070  rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
2071  if (failed(tileAndFuseResult)) {
2072  return failure();
2073  }
2074  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
2075  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
2076  clonedInsertSliceOp.getSource());
2077 
2078  // 7. Reconstruct [nested] loop with new inits.
2079  YieldTiledValuesFn newYieldValuesFn =
2080  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
2081  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
2083  SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
2084  OpBuilder::InsertionGuard g(innerRewriter);
2085  // 8. Set inner insertPoint right before tiled consumer op.
2086  innerRewriter.setInsertionPoint(tiledConsumerOp);
2087 
2088  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
2089  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
2090  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
2091 
2092  // 9. Check all insert stride is 1.
2093  if (llvm::any_of(strides, [](OpFoldResult stride) {
2094  return !isConstantIntValue(stride, 1);
2095  })) {
2096  return rewriter.notifyMatchFailure(
2097  candidateSliceOp, "containingOp's result yield with stride");
2098  }
2099 
2100  // 10. Try to get iter domain position from input position. Use
2101  // clonedConsumerOp instead of tiledConsumerOp, because the iteration
2102  // domain may require index computation based on the result size. The
2103  // sizes and offsets should be the same either way, but using
2104  // tiledConsumerOp could lead to some chained unnecessary extra index
2105  // computation.
2106  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2107  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
2108  rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
2109  iterDomainSizes))) {
2110  return rewriter.notifyMatchFailure(
2111  clonedConsumerOp,
2112  "can't get iter domain position from input position");
2113  }
2114 
2115  // 11. Try to fetch the offset and size for all results of the cloned
2116  // consumer. This would then be used to form the corresponding
2117  // tensor.insert_slice/parallel_insert_slice later.
2118  unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
2120  totalNumResultsOfConsumer);
2122  totalNumResultsOfConsumer);
2123  for (auto [idx, v] : llvm::enumerate(tiledConsumerOp->getResults())) {
2124  if (failed(tiledConsumerOp.getResultTilePosition(
2125  rewriter, idx, iterDomainOffsets, iterDomainSizes,
2126  resultOffsets[idx], resultSizes[idx]))) {
2127  return rewriter.notifyMatchFailure(
2128  tiledConsumerOp,
2129  "can't get result domain position from iter domain position");
2130  }
2131  }
2132 
2133  // 12. Create `extract_slice` for `iter_args` for DPS operation if
2134  // necessary.
2135  if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
2136  tiledConsumerOp.getOperation())) {
2137  rewriter.setInsertionPoint(tiledDestStyleOp);
2138  for (const auto &&[index, newRegionArg] :
2139  llvm::enumerate(newRegionIterArgs)) {
2140  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
2141  loc, newRegionArg, resultOffsets[index], resultSizes[index],
2142  SmallVector<OpFoldResult>(resultOffsets[index].size(),
2143  rewriter.getIndexAttr(1)));
2144  // Make a copy of index to avoid a capturing structured binding, which
2145  // is a C++20 extension.
2146  auto dstNumber = index;
2147  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
2148  tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
2149  });
2150  }
2151  }
2152 
2153  // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
2154  // caller.
2155  Block *block = rewriter.getInsertionPoint()->getBlock();
2156  rewriter.setInsertionPoint(block->getTerminator());
2157  for (const auto &&[index, result] :
2158  llvm::enumerate(tiledConsumerOp->getResults())) {
2159  tiledResult.push_back(result);
2160  tiledOffset.emplace_back(resultOffsets[index]);
2161  tiledSizes.emplace_back(resultSizes[index]);
2162  }
2163  return success();
2164  };
2165  // 14. Add new inits to [nested] loops.
2166  if (failed(addInitOperandsToLoopNest(rewriter, nestedLoops, newInits,
2167  newYieldValuesFn))) {
2168  return rewriter.notifyMatchFailure(tiledConsumerOp,
2169  "unable to add new inits to nest loop");
2170  }
2171 
2172  // 15. Replace the result of scf loop and consumer op with new loop's
2173  // results.
2174 
2175  for (auto &&[oldResult, newResult] : llvm::zip(
2176  consumerOp->getResults(),
2177  nestedLoops.front()->getResults().take_back(newInits.size()))) {
2178  rewriter.replaceAllUsesWith(oldResult, newResult);
2179  }
2180 
2181  // 16. Need to erase the old scf loop and the cloned consumer op.
2182  rewriter.eraseOp(clonedConsumerOp);
2183 
2185  consumerOpOperand,
2186  &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
2187  tileAndFuseResult->tiledOps};
2188 }
2189 
2190 //===----------------------------------------------------------------------===//
2191 // lowerToLoopsUsingSCFForOp implementation.
2192 //===----------------------------------------------------------------------===//
2193 
2194 FailureOr<SmallVector<scf::ForOp>>
2196  TilingInterface op) {
2197  // TODO: Handle cases where the op has results if needed.
2198  if (op->getNumResults() > 0) {
2199  return rewriter.notifyMatchFailure(
2200  op, "unable to lower to loops operations with return values");
2201  }
2202 
2203  SmallVector<Range> domain = op.getIterationDomain(rewriter);
2204  SmallVector<Value> ivs;
2206  Location loc = op.getLoc();
2207  for (auto loopRange : domain) {
2208  Value offsetVal =
2209  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
2210  Value sizeVal =
2211  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
2212  Value strideVal =
2213  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
2214  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
2215  strideVal, ValueRange{});
2216  loops.push_back(loop);
2217  ivs.push_back(loop.getInductionVar());
2218  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
2219  }
2220  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
2221  return failure();
2222  }
2223  return loops;
2224 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static LogicalResult verifyTileSizeOptions(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options)
Verify the tile size options are set in a consistent manner.
static LogicalResult checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp)
A utility function that checks whether the only use of the result of a tensor.insert_slice op is in a...
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize, OpFoldResult numThreads, OpFoldResult iterationSize)
Returns true if the maximum tile offset tileSize * numThreads-1 is less than iterationSize.
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< llvm::SetVector< Operation * > > checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp, bool reorderOperations)
This utility currently checks whether the first userOp of loop is NOT before the last defineOp of con...
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef< Range > iterationDomain, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Compute the OpFoldResults that represents the multi-dimensional offsets and sizes of the tile of the ...
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, OpFoldResult offset, OpFoldResult tileSize)
Returns the bounded tile size given the current offset, loopRange and tileSize, i....
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static FailureOr< OpOperand * > getConsumerFromLoopUses(RewriterBase &rewriter, Operation *loopOp, unsigned resultNumber)
Fetches the OpOperand of the first valid user (and use) of the value val which implements TilingInter...
static FailureOr< OpOperand * > getUntiledConsumerFromSlice(RewriterBase &rewriter, tensor::InsertSliceOp candidateSliceOp)
Fetch the untiled consumer of a scf.for's result which is yielded by a tensor.insert_slice.
static void checkSafeToTileToForall(TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads)
Checks if any of the tiled loops are not parallel.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, ArrayRef< Range > iterationDomain, const scf::SCFTilingOptions &options)
Method to instantiate the tile sizes and/or number of threads specified by the user.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specifed in options.
static SmallVector< scf::ForOp > getPerfectlyNestedLoopsOutsideOf(scf::ForOp loop)
Find the perfectly nested loops outside of given loop(included) sorted from outer to inner.
static bool tileDividesIterationDomain(Range loopRange)
Check if stride evenly divides the trip count size - offset.
static std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes)
Function to return the bounds of the loops to be generated.
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static FailureOr< Operation * > getFirstUserOfLoop(Operation *loopOp)
An utility to get the first user of the given loopOp.
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static FailureOr< MergeResult > mergeTilingResults(RewriterBase &rewriter, TilingInterface op, ValueRange partialResults, const scf::SCFTilingOptions &options)
static FailureOr< SmallVector< Value > > createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, const scf::SCFTilingOptions &options)
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > numThreads, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:964
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
A class for computing basic dominance information.
Definition: Dominance.h:140
bool properlyDominates(Operation *a, Operation *b, bool enclosingOpOk=true) const
Return true if operation A properly dominates operation B, i.e.
Definition: Dominance.cpp:324
This class allows control over how the GreedyPatternRewriteDriver works.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
void set(IRValueT newValue)
Set the current value being used by this operand.
Definition: UseDefLists.h:163
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class represents a saved insertion point.
Definition: Builders.h:325
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:544
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:314
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:434
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:853
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void moveOpBefore(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Value.h:212
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1311
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1304
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1198
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< scf::SCFFuseConsumerOfSliceResult > tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp)
Implementation of fusing consumer of a single slice by computing the slice of the consumer in-place f...
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SmallVector< Operation * > > yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops, ArrayRef< unsigned > yieldResultNumber=ArrayRef< unsigned >{})
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
FailureOr< scf::SCFTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Method to swap a tensor.extract_slice with its producer when the producer implements the TilingInterface.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:77
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:112
FailureOr< TilingResult > replaceInsertSliceWithTiledConsumer(OpBuilder &builder, OffsetSizeAndStrideOpInterface sliceOp, OpOperand &consumerOp)
Method to swap a tensor.insert_slice with its consumer when the consumer implements the TilingInterface.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 of a ConstantIndexOp with attribute with value 0.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:348
const FrozenRewritePatternSet GreedyRewriteConfig config
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
LogicalResult applyOpPatternsGreedily(ArrayRef< Operation * > ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr, bool *allErased=nullptr)
Rewrite the specified ops by repeatedly applying the highest benefit patterns in a greedy worklist dr...
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
@ ExistingAndNewOps
Only pre-existing and newly created ops are processed.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Container for the result of merge operation of tiling.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
A listener that forwards all notifications to another listener.
Definition: PatternMatch.h:463
Container for result values of tiling.
Fuse the consumer of the source of candidateSliceOp by computing the required slice of the consumer i...
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Control function to check if a slice needs to be fused or not, The control function receives 1) the s...
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes to use for each loop.
SCFTilingOptions & setNumThreads(ArrayRef< OpFoldResult > numThreads)
Convenience function to set the numThreadsComputationFunction to a function that computes num threads...
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > tileSizes)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
ReductionTilingStrategy
Specify how reduction dimensions should be tiled.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.