TileUsingInterface.cpp
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
14 
22 #include "mlir/IR/Dominance.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
26 #include "mlir/Interfaces/TilingInterface.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include <optional>
30 
31 #define DEBUG_TYPE "tile-using-interface"
32 
33 using namespace mlir;
34 
35 scf::SCFTilingOptions &
36 scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
37  assert(!tileSizeComputationFunction && "tile sizes already set");
38  auto tileSizes = llvm::to_vector(ts);
39  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
40  return tileSizes;
41  };
42  return *this;
43 }
44 
45 /// Helper method to adjust the interchange vector to match the iteration
46 /// domain.
47 static SmallVector<int64_t>
48 fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
49                       size_t iterationDomainSize) {
50  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
51  if (filledVector.size() < iterationDomainSize) {
52  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
53  filledVector.append(range.begin(), range.end());
54  }
55  if (filledVector.size() > iterationDomainSize)
56  filledVector.resize(iterationDomainSize);
57  return filledVector;
58 }
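// For example (illustrative): with an interchange vector of {1, 0} and an
// iteration domain of size 3, the missing trailing dimension is filled in to
// give {1, 0, 2}; a vector longer than the iteration domain is truncated.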
59 
60 //===----------------------------------------------------------------------===//
61 // tileUsingSCF implementation.
62 //===----------------------------------------------------------------------===//
63 
64 // Check if `stride` evenly divides the trip count `size - offset`.
65 static bool tileDividesIterationDomain(Range loopRange) {
66  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
67  if (!offsetAsInt)
68  return false;
69  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
70  if (!sizeAsInt)
71  return false;
72  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
73  if (!strideAsInt)
74  return false;
75  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
76 }
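// For example (illustrative): at the call site below the stride is the tile
// size, so a loop range with offset 0 and size 128 tiled by 32 divides evenly
// (128 % 32 == 0) and needs no bounding, while a size of 100 does not.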
77 
78 /// Returns the bounded tile size given the current `iv`, `loopRange` and
79 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
80 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
81                                        Range loopRange, Value iv,
82  OpFoldResult tileSize) {
83  std::optional<int64_t> ts = getConstantIntValue(tileSize);
84  if (ts && ts.value() == 1)
85  return tileSize;
86 
87  if (tileDividesIterationDomain(
88          Range{loopRange.offset, loopRange.size, tileSize}))
89  return tileSize;
90 
91  // The tile size to use (to avoid out of bounds access) is minimum of
92  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
93  // loop.
94  AffineExpr s0, s1, d0;
95  bindDims(b.getContext(), d0);
96  bindSymbols(b.getContext(), s0, s1);
97  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
98  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
99  return affine::makeComposedFoldedAffineMin(
100      b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
101 }
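// The clamped size produced above is roughly (sketch of the unfolded form):
//
//   %size = affine.min affine_map<(d0)[s0, s1] -> (s0, s1 - d0)>
//               (%iv)[%tileSize, %ub]
//
// i.e. min(tileSize, ub - iv), which shrinks the last tile when the tile size
// does not divide the iteration domain evenly.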
102 
103 /// A function that allows returning additional yielded values during
104 /// `yieldTiledValuesAndReplace`.
105 /// - `ivs` induction variables of the generated loop nest.
106 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
107 /// - `tiledValues` the tiled values to return. Must be of the same size as
108 /// `newBbArgs`; each element of this array is inserted into the corresponding
109 /// element in `newBbArgs`.
110 /// - `resultOffsets` is of the same size as `tiledValues` and represents
111 /// the offsets to use when inserting corresponding element from `tiledValues`
112 /// into the element from `newBbArgs`.
113 /// - `resultSizes` is of the same size as `tiledValues` and represents
114 /// the size of the corresponding element from `tiledValues` inserted into
115 /// the element from `newBbArgs`.
116 /// In case the method needs to return `failure()` the method is expected
117 /// to clean up any inserted operations.
118 using YieldTiledValuesFn = std::function<LogicalResult(
119  RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
120  SmallVector<Value> &tiledValues,
121  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
122  SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
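// A minimal sketch of a conforming callback (illustrative only; the real
// callbacks are the lambdas defined further below). It yields every new
// iter_arg back unchanged, inserted at offset zero with its full size:
//
//   YieldTiledValuesFn yieldIterArgsUnchanged =
//       [](RewriterBase &rewriter, Location loc, ValueRange ivs,
//          ValueRange newBbArgs, SmallVector<Value> &tiledValues,
//          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
//          SmallVector<SmallVector<OpFoldResult>> &resultSizes) {
//         for (Value bbArg : newBbArgs) {
//           auto type = cast<RankedTensorType>(bbArg.getType());
//           tiledValues.push_back(bbArg);
//           resultOffsets.emplace_back(type.getRank(),
//                                      OpFoldResult(rewriter.getIndexAttr(0)));
//           SmallVector<OpFoldResult> sizes;
//           for (int64_t dim = 0, rank = type.getRank(); dim < rank; ++dim)
//             sizes.push_back(tensor::getMixedSize(rewriter, loc, bbArg, dim));
//           resultSizes.push_back(sizes);
//         }
//         return success();
//       };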
123 
124 /// Clones the operation and updates the destination if the operation
125 /// implements the `DestinationStyleOpInterface`.
126 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
127                                                   Operation *op,
128  ValueRange newDestArgs) {
129  Operation *clonedOp = rewriter.clone(*op);
130  if (newDestArgs.empty())
131  return clonedOp;
132  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
133  destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
134  return clonedOp;
135 }
136 
137 /// Generate the tile-loop nest using `scf.for` operation.
138 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
139 /// - `tileSizes` is the tile sizes to use. Zero represents an untiled loop.
140 /// - `destinationTensors` are the init values to use for the outer most loop.
141 /// - `yieldTiledValuesFn` is called to generate the body of the innermost
142 /// loop.
144 /// - `loops` is an in-out parameter into which the generated loops are
145 /// populated.
146 static LogicalResult generateLoopNestUsingForOp(
147  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
148  ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
149  YieldTiledValuesFn yieldTiledValuesFn,
150     SmallVector<LoopLikeOpInterface> &loops) {
151  assert(!loopRanges.empty() && "unexpected empty loop ranges");
152  assert(loopRanges.size() == tileSizes.size() &&
153  "expected as many tile sizes as loop ranges");
154  OpBuilder::InsertionGuard guard(rewriter);
155  SmallVector<Value> ivs;
156 
157  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
158  // No loops if tile size is zero. Set offset and size to the loop
159  // offset and size.
160  if (isConstantIntValue(tileSize, 0))
161  continue;
162 
163  Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
164  Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
165  Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
166  auto loop =
167  rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
168  [](OpBuilder &bodyBuilder, Location bodyLoc,
169  Value iv, ValueRange /*iterArgs*/) {});
170  loops.push_back(loop);
171  ivs.push_back(loop.getInductionVar());
172  rewriter.setInsertionPointToEnd(loop.getBody());
173  destinationTensors = loop.getRegionIterArgs();
174  }
175 
176  SmallVector<Value> tiledResults;
177  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
178  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
179  tiledResults, resultOffsets, resultSizes))) {
180  return rewriter.notifyMatchFailure(
181  loc, "failed to generate inner tile loop body");
182  }
183  if (loops.empty())
184  return success();
185 
186  assert(tiledResults.size() == destinationTensors.size() &&
187  "Number of results of body should be equal to number of iter args");
188 
189  // 6. Yield all the results of the tiled operation.
190  SmallVector<Value> yieldedValues;
191  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
192  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
193  resultSizes)) {
194  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
195  rewriter.getIndexAttr(1));
196  auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
197  loc, tiledValue, destinationTensor, resultOffset, resultSize,
198  resultStride);
199  yieldedValues.push_back(insertSlice);
200  }
201  rewriter.create<scf::YieldOp>(loc, yieldedValues);
202 
203  // Add the scf.yield operations for all the outer loops.
204  for (auto [outerLoop, innerLoop] :
205  llvm::zip_equal(MutableArrayRef(loops).drop_back(),
206  MutableArrayRef(loops).drop_front())) {
207  rewriter.setInsertionPointToEnd(
208  cast<scf::ForOp>(outerLoop.getOperation()).getBody());
209  rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
210  }
211  return success();
212 }
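// Schematically, for two tiled loops this generates a nest of the form
// (illustrative only):
//
//   %0 = scf.for %iv0 = %lb0 to %ub0 step %ts0 iter_args(%arg0 = %dest) {
//     %1 = scf.for %iv1 = %lb1 to %ub1 step %ts1 iter_args(%arg1 = %arg0) {
//       // ... body created by `yieldTiledValuesFn` ...
//       %2 = tensor.insert_slice %tiled into %arg1[...] [...] [1, 1]
//       scf.yield %2
//     }
//     scf.yield %1
//   }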
213 
214 /// Generate the tile-loop nest using `scf.forall` operation.
215 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
216 /// - `tileSizes` is the tile sizes to use. Zero represents an untiled loop.
217 /// - `destinationTensors` are the init values to use for the outer most loop.
218 /// - `mappingVector` is the mapping attributes to use for loop construction.
219 /// Can be empty.
220 /// - `yieldTiledValuesFn` is called to generate the body of the innermost
221 /// loop.
223 /// - `loops` is an in-out parameter into which the generated loops are
224 /// populated.
225 static LogicalResult generateLoopNestUsingForallOp(
226  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
227  ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
228  ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
229     SmallVector<LoopLikeOpInterface> &loops) {
230  SmallVector<OpFoldResult> lbs, ubs, steps;
231  assert(!loopRanges.empty() && "unexpected empty loop ranges");
232  assert(loopRanges.size() == tileSizes.size() &&
233  "expected as many tile sizes as loop ranges");
234  OpBuilder::InsertionGuard guard(rewriter);
235  SmallVector<OpFoldResult> offsets(loopRanges.size()),
236  sizes(loopRanges.size());
237 
238  for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
239  if (isConstantIntValue(tileSize, 0))
240  continue;
241  lbs.push_back(loopRange.offset);
242  ubs.push_back(loopRange.size);
243  steps.push_back(tileSize);
244  }
245  assert(!lbs.empty() && "Expected at least one loop range");
246 
247  std::optional<ArrayAttr> mappingAttr;
248  if (!mappingVector.empty())
249  mappingAttr = rewriter.getArrayAttr(mappingVector);
250 
251  auto forallOp = rewriter.create<scf::ForallOp>(
252  loc, lbs, ubs, steps, destinationTensors, mappingAttr);
253  loops.push_back(forallOp);
254 
255  rewriter.setInsertionPoint(forallOp.getTerminator());
256  destinationTensors = forallOp.getRegionOutArgs();
257 
258  SmallVector<Value> tiledResults;
259  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
260  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
261  destinationTensors, tiledResults, resultOffsets,
262  resultSizes)))
263  return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
264 
265  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
266  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
267  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
268  resultSizes)) {
269  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
270  rewriter.getIndexAttr(1));
271 
272  rewriter.create<tensor::ParallelInsertSliceOp>(
273  loc, tiledValue, destinationTensor, resultOffset, resultSize,
274  resultStride);
275  }
276  return success();
277 }
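// Schematically, this generates a single multi-dimensional loop of the form
// (illustrative only):
//
//   %0 = scf.forall (%iv0, %iv1) = (%lb0, %lb1) to (%ub0, %ub1)
//       step (%ts0, %ts1) shared_outs(%out = %dest) {
//     // ... body created by `tiledBodyFn` ...
//     scf.forall.in_parallel {
//       tensor.parallel_insert_slice %tiled into %out[...] [...] [1, 1]
//     }
//   }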
278 
279 /// Generate the tile-loop nest using the loop construct specified in `options`.
280 /// - `options`: Tiling options specified.
281 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
282 /// - `tileSizes` is the tile sizes to use. Zero represents an untiled loop.
283 /// - `destinationTensors` are the init values to use for the outer most loop.
284 /// - `yieldTiledValuesFn` is called to generate the body of the innermost
285 /// loop.
287 /// - `loops` is an in-out parameter into which the generated loops are
288 /// populated.
289 static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
290                                       const scf::SCFTilingOptions &options,
291                                       ArrayRef<Range> loopRanges,
292  ArrayRef<OpFoldResult> tileSizes,
293  ValueRange destinationTensors,
294  YieldTiledValuesFn tiledBodyFn,
295                                       SmallVector<LoopLikeOpInterface> &loops) {
296  // If the tile sizes are all zero, no loops are generated. Just call the
297  // callback function to handle untiled case.
298  if (llvm::all_of(tileSizes, isZeroIndex)) {
299  SmallVector<Value> tiledResults;
300  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
301  return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
302  tiledResults, resultOffsets, resultSizes);
303  }
304  if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
305  return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
306  destinationTensors, tiledBodyFn, loops);
307  }
308  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
309  return generateLoopNestUsingForallOp(
310  rewriter, loc, loopRanges, tileSizes, options.mappingVector,
311  destinationTensors, tiledBodyFn, loops);
312  }
313  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
314 }
315 
316 /// Append the specified additional `newInitOperands` operands to the
317 /// loop's existing `init` operands (or similar), and replace `loopOp` with
318 /// the new loop that has the additional init operands. The loop body of
319 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
320 /// is called to get the new tiled values returned, and the offset
321 /// and sizes at which the tiled value is inserted into the
322 /// new region iter_args that correspond to the newly added init operands.
323 template <typename LoopType>
324 FailureOr<LoopLikeOpInterface>
325 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
326  ValueRange newInitOperands,
327  YieldTiledValuesFn yieldTiledValuesFn) {
328  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
329 }
330 
331 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
332 template <>
333 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
334  scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
335  YieldTiledValuesFn yieldTiledValuesFn) {
336  OpBuilder::InsertionGuard g(rewriter);
337  Location loc = loopOp.getLoc();
338  rewriter.setInsertionPoint(loopOp);
339 
340  auto inits = llvm::to_vector(loopOp.getInitArgs());
341  inits.append(newInitOperands.begin(), newInitOperands.end());
342  auto newLoop = rewriter.create<scf::ForOp>(
343  loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
344  inits, [](OpBuilder &, Location, Value, ValueRange) {});
345 
346  // Move the loop body to the new op.
347  Block *loopBody = loopOp.getBody();
348  Block *newLoopBody = newLoop.getBody();
349  rewriter.mergeBlocks(
350  loopBody, newLoopBody,
351  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
352 
353  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
354  rewriter.setInsertionPoint(yieldOp);
355 
356  SmallVector<Value> tiledValues;
357  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
358  ValueRange newRegionIterArgs =
359  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
360  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
361  newRegionIterArgs, tiledValues, resultOffsets,
362  resultSizes))) {
363  rewriter.eraseOp(newLoop);
364  return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
365  }
366 
367  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
368  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
369  llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
370  resultSizes)) {
371  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
372  rewriter.getIndexAttr(1));
373  Value insert = rewriter.create<tensor::InsertSliceOp>(
374  yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
375  resultStride);
376  newYieldValues.push_back(insert);
377  }
378 
379  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
380  rewriter.replaceOp(loopOp,
381  newLoop->getResults().take_front(loopOp.getNumResults()));
382  return cast<LoopLikeOpInterface>(newLoop.getOperation());
383 }
384 
385 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
386 template <>
387 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
388  scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
389  YieldTiledValuesFn yieldTiledValuesFn) {
390  OpBuilder::InsertionGuard g(rewriter);
391  Location loc = loopOp.getLoc();
392  rewriter.setInsertionPoint(loopOp);
393  auto inits = llvm::to_vector(loopOp.getOutputs());
394  inits.append(newInitOperands.begin(), newInitOperands.end());
395  auto newLoop = rewriter.create<scf::ForallOp>(
396  loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
397  loopOp.getMixedStep(), inits, loopOp.getMapping(),
398  [](OpBuilder &, Location, ValueRange) {});
399 
400  // Move the region of the current block to the newly created op.
401  Block *loopBody = loopOp.getBody();
402  Block *newLoopBody = newLoop.getBody();
403  rewriter.mergeBlocks(
404  loopBody, newLoopBody,
405  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
406 
407  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
408  rewriter.setInsertionPoint(terminator);
409  SmallVector<Value> tiledValues;
410  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
411  ValueRange regionIterArgs =
412  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
413  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
414  regionIterArgs, tiledValues, resultOffsets,
415  resultSizes))) {
416  rewriter.eraseOp(newLoop);
417  return rewriter.notifyMatchFailure(loopOp,
418  "failed to get yielded tiled values");
419  }
420 
421  // Update the terminator.
422  rewriter.setInsertionPointToEnd(terminator.getBody());
423 
424  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
425  tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
426  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
427  rewriter.getIndexAttr(1));
428  rewriter.create<tensor::ParallelInsertSliceOp>(
429  terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
430  resultStride);
431  }
432 
433  rewriter.replaceOp(loopOp,
434  newLoop->getResults().take_front(loopOp.getNumResults()));
435  return cast<LoopLikeOpInterface>(newLoop.getOperation());
436 }
437 
438 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
439 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
440 /// supported loop type.
441 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
442  LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
443  ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
444  return llvm::TypeSwitch<Operation *, FailureOr<LoopLikeOpInterface>>(
445             loopLikeOp.getOperation())
446  .Case<scf::ForOp, scf::ForallOp>(
447  [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
448            return yieldTiledValuesAndReplaceLoop(
449                loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
450  })
451  .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
452  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
453  });
454 }
455 
456 /// Method to add new init values to a loop nest. Updates `loops` in-place with
457 /// new loops that use the `newInitValues`.
458 /// The outer loops are updated to yield the new result values of the inner
459 /// loop. For the innermost loop, the callback `getNewTiledYieldsFn` is invoked
460 /// to get the additional values to yield from the innermost loop.
461 static LogicalResult addInitOperandsToLoopNest(
462     RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
463     ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
464  SmallVector<scf::ForOp> newLoops;
465  if (loops.empty())
466  return success();
467  OpBuilder::InsertionGuard g(rewriter);
468  rewriter.setInsertionPoint(loops.front());
469 
470  SmallVector<Value> ivs;
471  for (auto &loop : loops.drop_back()) {
472  rewriter.setInsertionPoint(loop);
473 
474  // if loops.size() > 1 we assume that scf.for is used for the loops.
475  auto forLoop = cast<scf::ForOp>(loop.getOperation());
476 
477  // Create a new loop with the new init values for this loop.
478  SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
479  newInits.append(newInitValues.begin(), newInitValues.end());
480  auto newLoop = rewriter.create<scf::ForOp>(
481  forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
482  forLoop.getStep(), newInits,
483  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
484 
485  // Merge the body of the new loop with the body of the old loops.
486  SmallVector<Value> sourceBlockArgs;
487  sourceBlockArgs.push_back(newLoop.getInductionVar());
488  auto newRegionIterArgs = newLoop.getRegionIterArgs();
489  sourceBlockArgs.append(
490  newRegionIterArgs.begin(),
491  std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
492  rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
493  rewriter.replaceOp(
494  forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
495  loop = newLoop;
496  ivs.push_back(newLoop.getInductionVar());
497  newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
498  }
499 
500  // Update the loop body of the innermost loop to get new yield values.
501  LoopLikeOpInterface innerMostLoop = loops.back();
502  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
503  yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
504  getNewTiledYieldsFn);
505 
506  if (failed(newInnerMostLoop))
507  return innerMostLoop.emitOpError("failed to return additional yields");
508  loops.back() = newInnerMostLoop.value();
509 
510  // Make all other loops except the innermost loops yield the values returned
511  // by the inner loop.
512  for (auto [outerLoop, innerLoop] :
513  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
514  // Again assume that all the outer loops are scf.for operations.
515  auto outerForLoop = cast<scf::ForOp>(outerLoop);
516  auto outerLoopYield =
517  cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
518  SmallVector<Value> newYields =
519  llvm::to_vector(outerLoopYield.getOperands());
520  ValueRange additionalYields =
521  innerLoop->getResults().take_back(newInitValues.size());
522  newYields.append(additionalYields.begin(), additionalYields.end());
523  rewriter.setInsertionPoint(outerLoopYield);
524  rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
525  }
526  return success();
527 }
528 
529 /// Implementation of tiling transformation of `op` that implements the
530 /// `TilingInterface` using `scf.for` to iterate over the tiles.
531 FailureOr<scf::SCFTilingResult>
532 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
533                         const scf::SCFTilingOptions &options) {
534  OpBuilder::InsertionGuard guard(rewriter);
535  rewriter.setInsertionPointAfter(op);
536 
537  if (!options.tileSizeComputationFunction) {
538  return rewriter.notifyMatchFailure(
539  op, "missing tile size computation function");
540  }
541 
542  // 1. Get the range of the loops that are represented by the operation.
543  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
544  size_t numLoops = iterationDomain.size();
545 
546  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
547  // skips tiling a particular dimension. This convention is significantly
548  // simpler to handle than adjusting affine maps to account for missing
549  // dimensions.
550  SmallVector<OpFoldResult> tileSizes =
551  options.tileSizeComputationFunction(rewriter, op);
552  if (tileSizes.size() < iterationDomain.size()) {
553  auto zero = rewriter.getIndexAttr(0);
554  tileSizes.append(numLoops - tileSizes.size(), zero);
555  }
556 
557  // 3. If there is an interchange specified, permute the iteration domain and
558  // the tile sizes.
559  SmallVector<int64_t> interchangeVector;
560  if (!options.interchangeVector.empty()) {
561  interchangeVector = fillInterchangeVector(options.interchangeVector,
562  iterationDomain.size());
563  }
564  if (!interchangeVector.empty()) {
565  if (!isPermutationVector(interchangeVector)) {
566  return rewriter.notifyMatchFailure(
567  op, "invalid interchange vector, not a permutation of the entire "
568  "iteration space");
569  }
570 
571  applyPermutationToVector(iterationDomain, interchangeVector);
572  applyPermutationToVector(tileSizes, interchangeVector);
573  }
574 
575  FailureOr<TilingResult> tilingResult;
576  // 4. Define the lambda function used later to generate the body of the
577  // innermost tiled loop.
578  YieldTiledValuesFn innerYieldTiledValuesFn =
579  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
580  ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
581         SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
582         SmallVector<SmallVector<OpFoldResult>> &resultSizes)
583      -> LogicalResult {
584  // 4a. Compute the `offsets` and `sizes` to use for tiling.
585  SmallVector<OpFoldResult> offsets, sizes;
586  {
587  int materializedLoopNum = 0;
588  for (auto [tileSize, loopRange] :
589  llvm::zip_equal(tileSizes, iterationDomain)) {
590  if (isConstantIntValue(tileSize, 0)) {
591  offsets.push_back(loopRange.offset);
592  sizes.push_back(loopRange.size);
593  continue;
594  }
595  Value iv = ivs[materializedLoopNum++];
596  offsets.push_back(iv);
597  sizes.push_back(
598  getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
599  }
600  }
601 
602  // 4b. If interchange was provided, apply inverse of the interchange
603  // to get back the offsets/sizes in the order to be specified.
604  if (!interchangeVector.empty()) {
605  auto inversePermutation = invertPermutationVector(interchangeVector);
606      applyPermutationToVector(offsets, inversePermutation);
607      applyPermutationToVector(sizes, inversePermutation);
608  }
609 
610  // 5. Generate the tiled implementation within the inner most loop.
611 
612  // 5a. Clone the operation within the loop body.
613  auto clonedOp = cast<TilingInterface>(
614  cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
615 
616  // 5b. Early return the cloned op if no tiling is happening. We cannot return
617  // the original op because that could lead to
618  // `rewriter.replaceOp(op, op->getResults())`, which would crash users.
619  if (llvm::all_of(tileSizes, isZeroIndex)) {
620  tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
621  tilingResult =
622  TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
623  return success();
624  }
625 
626  // 5c. Tile the cloned operation.
627  tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
628  if (failed(tilingResult)) {
629  rewriter.eraseOp(clonedOp);
630  return op.emitOpError("failed to tile operation");
631  }
632 
633  // 5d. Delete the cloned operation.
634  rewriter.eraseOp(clonedOp);
635 
636  // 5e. Compute the offsets at which the result values are to be inserted
637  // back into its destinations.
638  for (auto [index, tiledValue] :
639  llvm::enumerate(tilingResult->tiledValues)) {
640  tiledResults.push_back(tiledValue);
641  SmallVector<OpFoldResult> resultOffset, resultSize;
642  if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
643  resultOffset, resultSize))) {
644  for (auto op : tilingResult->tiledOps) {
645  rewriter.eraseOp(op);
646  }
647  return rewriter.notifyMatchFailure(
648  op, "failed to get slice of result produced");
649  }
650  resultOffsets.emplace_back(std::move(resultOffset));
651  resultSizes.emplace_back(std::move(resultSize));
652  }
653 
654  return success();
655  };
656 
657  // 6. Find the destination tensors to use for the operation.
658  SmallVector<Value> destinationTensors;
659  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
660  destinationTensors))) {
661  return rewriter.notifyMatchFailure(op,
662  "unable to create destination tensors");
663  }
664 
665  // 7. Generate the tiled loops nest using the callback defined above.
666  SmallVector<LoopLikeOpInterface> loops;
667  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
668  tileSizes, destinationTensors,
669  innerYieldTiledValuesFn, loops)))
670  return op.emitOpError("failed to generate tiling loops");
671  assert(succeeded(tilingResult) &&
672  "expected tiling result to be computed after loop generation");
673 
674  // If loops are empty, the tiled op is used as the replacement for the untiled
675  // op.
676  if (loops.empty()) {
677  return scf::SCFTilingResult{tilingResult->tiledOps, loops,
678  tilingResult->tiledValues};
679  }
680 
681  SmallVector<Value> replacements = llvm::map_to_vector(
682  loops.front()->getResults(), [](OpResult r) -> Value { return r; });
683  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
684 }
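// A minimal usage sketch of this entry point (assuming a caller that already
// has a `RewriterBase &rewriter` and a `TilingInterface op`; the tile sizes
// below are arbitrary):
//
//   scf::SCFTilingOptions tilingOptions;
//   tilingOptions.setTileSizes(
//       {rewriter.getIndexAttr(32), rewriter.getIndexAttr(64)});
//   FailureOr<scf::SCFTilingResult> tiled =
//       scf::tileUsingSCF(rewriter, op, tilingOptions);
//   if (failed(tiled))
//     return failure();
//   // Replace the untiled op with the values reconstructed by the loop nest.
//   rewriter.replaceOp(op, tiled->replacements);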
685 
686 FailureOr<scf::SCFReductionTilingResult>
687 mlir::scf::tileReductionUsingScf(RewriterBase &b,
688                                  PartialReductionOpInterface op,
689  ArrayRef<OpFoldResult> tileSizes) {
690  Location loc = op.getLoc();
691  // Ops implementing PartialReductionOpInterface are expected to implement
692  // TilingInterface.
693  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
694  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
695  auto tileSizesVector = llvm::to_vector(tileSizes);
696  if (tileSizesVector.size() < iterationDomain.size()) {
697  auto zero = b.getIndexAttr(0);
698  tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
699  zero);
700  }
701  SmallVector<utils::IteratorType> iteratorTypes =
702      tilingInterfaceOp.getLoopIteratorTypes();
703 
704  SmallVector<int> reductionDims;
705  for (auto [idx, iteratorType] :
706  llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
707  if (iteratorType == utils::IteratorType::reduction)
708  reductionDims.push_back(idx);
709  }
710 
711  // 2. Create the initial tensor value.
712  FailureOr<SmallVector<Value>> maybeInitTensors =
713  op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
714  reductionDims);
715  if (failed(maybeInitTensors)) {
716  return b.notifyMatchFailure(op, "Failed to create initial tensors.");
717  }
718  SmallVector<Value> &initTensors = maybeInitTensors.value();
719 
720  // 3. Define the callback to use for generating the inner most tile loop body.
721  SmallVector<Operation *> parallelTiledOps;
722  auto innerYieldTiledValuesFn =
723  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
724  ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
725         SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
726         SmallVector<SmallVector<OpFoldResult>> &resultSizes)
727      -> LogicalResult {
728  SmallVector<OpFoldResult> offsets, sizes;
729  {
730  int materializedLoopNum = 0;
731  for (auto [tileSize, loopRange] :
732  llvm::zip_equal(tileSizesVector, iterationDomain)) {
733  if (isConstantIntValue(tileSize, 0)) {
734  offsets.push_back(loopRange.offset);
735  sizes.push_back(loopRange.size);
736  continue;
737  }
738  Value iv = ivs[materializedLoopNum++];
739  offsets.push_back(iv);
740  sizes.push_back(
741  getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
742  }
743  }
744 
745  // 4a. Clone the operation.
746  {
747  auto clonedOp = cast<PartialReductionOpInterface>(
748  cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
749 
750  // 4b. Tile the cloned operation.
751  FailureOr<TilingResult> partialTilingResult =
752  clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
753  sizes, reductionDims);
754  if (failed(partialTilingResult)) {
755  return failure();
756  }
757  std::swap(parallelTiledOps, partialTilingResult->tiledOps);
758  std::swap(tiledResult, partialTilingResult->tiledValues);
759 
760  // 4c. Delete the cloned operation.
761  b.eraseOp(clonedOp);
762  }
763 
764  // 4d. Compute the offsets and sizes needed to insert the result of the
765  // tiled value back into destination before yielding the destination.
766  for (auto result : tiledResult) {
767  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
768  resultOffsets.emplace_back(std::move(outOffsets));
769 
770  SmallVector<OpFoldResult> outSizes;
771  for (size_t i = 0; i < offsets.size(); i++) {
772  outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
773  }
774  resultSizes.emplace_back(std::move(outSizes));
775  }
776  return success();
777  };
778 
779  // 5. Generate the tiled implementation using the destination tensors.
780  scf::SCFTilingOptions options;
781  options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
782  SmallVector<LoopLikeOpInterface> loops;
783  if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
784  initTensors, innerYieldTiledValuesFn, loops)))
785  return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
786 
787  SmallVector<Value> replacements = llvm::map_to_vector(
788  loops.front()->getResults(), [](OpResult r) -> Value { return r; });
789 
790  // 6. Apply the merge reduction to combine all the partial values.
791  b.setInsertionPointAfter(*loops.begin());
792  FailureOr<MergeResult> mergeResult =
793  op.mergeReductions(b, loc, replacements, reductionDims);
794  if (failed(mergeResult)) {
795  return failure();
796  }
797  b.replaceOp(op, mergeResult->replacements);
798 
799  SCFReductionTilingResult reductionTilingResult;
800  std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
801  std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
802  std::swap(reductionTilingResult.initialValues, initTensors);
803  std::swap(reductionTilingResult.loops, loops);
804  std::swap(reductionTilingResult.replacements, mergeResult->replacements);
805 
806  return reductionTilingResult;
807 }
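// A minimal usage sketch (assuming `reductionOp` implements
// `PartialReductionOpInterface` with one parallel and one reduction dimension;
// the non-zero tile size on the reduction dimension is arbitrary):
//
//   FailureOr<scf::SCFReductionTilingResult> result =
//       scf::tileReductionUsingScf(
//           rewriter, reductionOp,
//           {rewriter.getIndexAttr(0), rewriter.getIndexAttr(8)});
//   if (failed(result))
//     return failure();
//   // The original op has already been replaced by the merged partial
//   // reductions; `result->parallelTiledOps` and `result->mergeOps` give
//   // access to the newly created operations.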
808 
809 //===----------------------------------------------------------------------===//
810 // tileConsumerAndFuseProducersUsingSCF implementation.
811 //===----------------------------------------------------------------------===//
812 
813 /// Return the untiled producer whose slice is used in a tiled consumer. The
814 /// method traverses the tile loop nest (`loops`) if needed, and returns the
815 /// `iter_args` of the outermost loop that is encountered. Traversing the iter_args
816 /// indicates that this is a destination operand of the consumer. If there was
817 /// no loop traversal needed, the second value of the returned tuple is empty.
818 static std::tuple<OpResult, std::optional<OpOperand *>>
819 getUntiledProducerFromSliceSource(OpOperand *source,
820                                   ArrayRef<LoopLikeOpInterface> loops) {
821  std::optional<OpOperand *> destinationIterArg;
822  auto loopIt = loops.rbegin();
823  while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
824  auto loop = *loopIt;
825  if (iterArg.getOwner()->getParentOp() != loop)
826  break;
827  source = loop.getTiedLoopInit(iterArg);
828  loopIt++;
829  }
830  if (loopIt == loops.rend())
831  destinationIterArg = source;
832  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
833 }
834 
835 /// Implementation of fusing producer of a single slice by computing the
836 /// slice of the producer in-place.
837 std::optional<scf::SCFFuseProducerOfSliceResult>
838 mlir::scf::tileAndFuseProducerOfSlice(
839     RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
840     MutableArrayRef<LoopLikeOpInterface> loops) {
841  // 1. Get the producer of the source (potentially walking through
842  // `iter_args` of nested `scf.for`)
843  auto [fusableProducer, destinationInitArg] =
844  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
845  loops);
846  if (!fusableProducer)
847  return std::nullopt;
848  unsigned resultNumber = fusableProducer.getResultNumber();
849 
850  OpBuilder::InsertionGuard g(rewriter);
851  rewriter.setInsertionPoint(candidateSliceOp);
852 
853  // 2. Clone the fused producer
854  // 2a. Compute the destination operands to use for the cloned operation.
855  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
856  Operation *fusableProducerOp = fusableProducer.getOwner();
857  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
858      failed(tensor::getOrCreateDestinations(
859          rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
860  origDestinationTensors)))
861  return std::nullopt;
862 
863  clonedOpDestinationTensors = origDestinationTensors;
864  if (destinationInitArg &&
865  isa<DestinationStyleOpInterface>(fusableProducerOp)) {
866  // 2b. If the producer is also destination style, then to maintain the
867  // destination passing style, update the destination of the producer to be
868  // the source of the slice.
869  clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
870  }
871  // 2c. Clone the fused producer.
872  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
873  rewriter, fusableProducerOp, clonedOpDestinationTensors);
874  // 2d. Update the source of the candidateSlice to be the cloned producer.
875  // It is easier to just clone the slice with a different source, since
876  // replacement and DCE of the cloned ops become simpler.
877  SmallVector<Value> candidateSliceOpOperands =
878  llvm::to_vector(candidateSliceOp->getOperands());
879  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
880  tensor::ExtractSliceOp clonedCandidateSliceOp =
881  mlir::clone(rewriter, candidateSliceOp,
882  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
883 
884  // 3. Generate the tiled implementation of the producer of the source
885  FailureOr<TilingResult> tileAndFuseResult =
886      tensor::replaceExtractSliceWithTiledProducer(
887          rewriter, clonedCandidateSliceOp,
888  clonedProducerOp->getResult(resultNumber));
889  if (failed(tileAndFuseResult))
890  return std::nullopt;
891  // Note: Do not delete the candidateSliceOp, since it is passed in from the
892  // caller.
893  rewriter.replaceAllUsesWith(candidateSliceOp,
894  tileAndFuseResult->tiledValues[0]);
895  rewriter.eraseOp(clonedCandidateSliceOp);
896  rewriter.eraseOp(clonedProducerOp);
897 
898  // 4. If the slice is for a destination operand, for example,
899  //
900  // ```mlir
901  // %0 = linalg.init
902  // %1 = linalg.fill .. outs(%0 : )
903  // %2 = scf.for .. iter_args(%arg0 = %1) {
904  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
905  // %4 = tensor.extract_slice %arg1 [..]
906  // .. = linalg.matmul .. outs(%4 : )
907  // }
908  // }
909  // ```
910  //
911  // the IR is currently
912  //
913  // ```
914  // %0 = linalg.init
915  // %1 = linalg.fill
916  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
917  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
918  // %4 = tensor.extract_slice %arg1[..]
919  // %5 = linalg.fill .. outs(%4 : )
920  // .. = linalg.matmul .. outs(%5 : )
921  // }
922  // }
923  // ```
924  //
925  // The untiled `linalg.fill` is still used as the `init_value` since it
926  // was originally a destination operand of the untiled `linalg.matmul`.
927  // When fusing an operand that is a destination operand, the iter_arg of
928  // the outermost loop should be changed to use the destination of the
929  // fused operation. With this, the IR will be:
930  //
931  // ```
932  // %0 = linalg.init
933  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
934  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
935  // %3 = tensor.extract_slice %arg1[..]
936  // %4 = linalg.fill .. outs(%3 : )
937  // .. = linalg.matmul .. outs(%4 : )
938  // }
939  // }
940  // ```
941  if (destinationInitArg &&
942  isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
943  loops.front()
944  ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
945  .set(origDestinationTensors[resultNumber]);
946  }
947  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
948  tileAndFuseResult->tiledValues[0],
949  tileAndFuseResult->tiledOps};
950 }
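// A minimal usage sketch (assuming `loops` is the loop nest produced by an
// earlier `tileUsingSCF` call and `sliceOp` is a `tensor.extract_slice`
// feeding the tiled consumer):
//
//   std::optional<scf::SCFFuseProducerOfSliceResult> fused =
//       scf::tileAndFuseProducerOfSlice(rewriter, sliceOp, loops);
//   if (!fused)
//     return failure();
//   // The slice has been replaced by the tiled producer computed in-place.
//   Value tiledProducer = fused->tiledAndFusedProducer;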
951 
952 /// Reconstruct the fused producer from within the tiled-and-fused code.
953 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
954     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
955  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
956     MutableArrayRef<LoopLikeOpInterface> loops,
957     ArrayRef<unsigned> yieldResultNumber) {
958  if (loops.empty())
959  return success();
960 
961  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
962  *tiledOwner = fusedProducerInfo.tiledOps[0];
963 
964  Location loc = originalOwner->getLoc();
965  // a. Collect all init values to be appended.
966  SmallVector<unsigned> initNumberList =
967  yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
968  0, originalOwner->getNumResults()))
969  : llvm::to_vector(yieldResultNumber);
970  SmallVector<Value> initValueList;
971  for (const auto &resultNumber : initNumberList) {
972  FailureOr<Value> initValue = tensor::getOrCreateDestination(
973  rewriter, loc, originalOwner->getResult(resultNumber));
974  if (succeeded(initValue)) {
975  initValueList.push_back(initValue.value());
976  } else {
977  return failure();
978  }
979  }
980 
981  YieldTiledValuesFn newYieldValuesFn =
982  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
983  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
984         SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
985         SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
986  OpBuilder::InsertionGuard g(innerRewriter);
987 
988  // get sliceOp tile information
989  SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
990  sliceSizes = sliceOp.getMixedSizes();
991 
992  // expect all strides of sliceOp being 1
993  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
994  return !isConstantIntValue(ofr, 1);
995  }))
996  return failure();
997 
998  unsigned sliceResultNumber =
999  fusedProducerInfo.origProducer.getResultNumber();
1000 
1001  auto tilableOp = cast<TilingInterface>(originalOwner);
1002  // b. get iterDomain Offset and Sizes based on sliceOp tile
1003  SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
1004  // Skip tensor.pack/unpack/pad, which expect a single opResult.
1005  if (tilableOp->getNumResults() > 1 &&
1006  failed(tilableOp.getIterationDomainTileFromResultTile(
1007  rewriter, sliceResultNumber, sliceOffset, sliceSizes,
1008  iterDomainOffset, iterDomainSizes))) {
1009  // In theory, it is unnecessary to raise an error here, because even if
1010  // reconstructing the result tensor fails, it should not break the current
1011  // fusion anyway. The reason we must return failure for now is that the
1012  // callback `newYieldValuesFn` is called after the new init operand(s)
1013  // have already been appended. It will take more refactoring to make sure
1014  // the init operands are added consistently in the future. For more
1015  // details, please refer to:
1016  // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
1017  return failure();
1018  }
1019 
1020  // c. Calculate the offsets and sizes of all OpResults based on the
1021  // iteration domain tile.
1022  SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
1023  for (const auto &resultNumber : initNumberList) {
1024  if (resultNumber == sliceResultNumber) {
1025  offsetList.push_back(sliceOffset);
1026  sizesList.push_back(sliceSizes);
1027  } else {
1028  assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
1029  // infer result tile according to the iteration domain tile
1030  SmallVector<OpFoldResult> offset, sizes;
1031  if (failed(tilableOp.getResultTilePosition(
1032  rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
1033  offset, sizes))) {
1034  return failure();
1035  }
1036  offsetList.push_back(offset);
1037  sizesList.push_back(sizes);
1038  }
1039  }
1040 
1041  // d. create `extract_slice` for `iter_args` for DPS operation if necessary
1042  if (auto tiledDestStyleOp =
1043  dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
1044  rewriter.setInsertionPoint(tiledDestStyleOp);
1045  for (const auto &&[index, newRegionArg] :
1046  llvm::enumerate(newRegionIterArgs)) {
1047  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
1048  loc, newRegionArg, offsetList[index], sizesList[index],
1049  SmallVector<OpFoldResult>(offsetList[index].size(),
1050  rewriter.getIndexAttr(1)));
1051  unsigned resultNumber = initNumberList[index];
1052  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
1053  tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
1054  });
1055  }
1056  }
1057 
1058  // e. Prepare the tiled offsets and sizes for the later `insert_slice`
1059  // creation by the caller.
1060  Block *block = rewriter.getInsertionPoint()->getBlock();
1061  rewriter.setInsertionPoint(block->getTerminator());
1062  for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
1063  tiledResult.push_back(tiledOwner->getResult(resultNumber));
1064  tiledOffset.emplace_back(offsetList[index]);
1065  tiledSizes.emplace_back(sizesList[index]);
1066  }
1067  return success();
1068  };
1069 
1070  return addInitOperandsToLoopNest(rewriter, loops, initValueList,
1071  newYieldValuesFn);
1072 }
1073 
1074 /// Implementation of tile consumer and fuse producer greedily.
1075 FailureOr<scf::SCFTileAndFuseResult>
1076 mlir::scf::tileConsumerAndFuseProducersUsingSCF(
1077     RewriterBase &rewriter, TilingInterface consumer,
1078     const scf::SCFTileAndFuseOptions &options) {
1079  // This transformation is only valid for ops that return values (i.e. not
1080  // valid to use with operations that have memref operands).
1081  if (!consumer->getNumResults()) {
1082  return rewriter.notifyMatchFailure(
1083  consumer, "invalid pattern for op with no results");
1084  }
1085 
1086  // 1. First tile the consumer.
1087  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1088  llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1089 
1090  FailureOr<scf::SCFTilingResult> tilingResult =
1091  tileUsingSCF(rewriter, consumer, options.tilingOptions);
1092 
1093  if (failed(tilingResult))
1094  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1095  for (auto *tiledOp : tilingResult->tiledOps)
1096  tiledAndFusedOps.insert(tiledOp);
1097 
1098  // If there are no loops generated, fusion is immaterial.
1099  auto &loops = tilingResult->loops;
1100  if (loops.empty()) {
1101  DenseMap<Value, Value> replacements;
1102  for (auto [origVal, replacement] :
1103  llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1104  replacements[origVal] = replacement;
1105  }
1106  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1107  replacements};
1108  }
1109 
1110  // To keep track of replacements, for now just record the map from the
1111  // original untiled value to the result number of the for loop. Since the loop
1112  // potentially gets replaced during fusion, keeping the value directly won't work.
1113  DenseMap<Value, size_t> origValToResultNumber;
1114  for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1115  origValToResultNumber[result] = index;
1116  }
1117 
1118  // 2. Typically, the operands of the tiled operation are slices of the
1119  // operands of the untiled operation. These are expressed in IR using
1120  // `tensor.extract_slice` operations with source being the operands of the
1121  // untiled operation. Create a worklist of these `tensor.extract_slice`
1122  // operations. If the producers of the source of the `tensor.extract_slice`
1123  // can be tiled such that the tiled value is generated in-place, that
1124  // effectively tiles + fuses the operations.
1125  auto addCandidateSlices = [](Operation *fusedOp,
1126  std::deque<tensor::ExtractSliceOp> &candidates) {
1127  for (Value operand : fusedOp->getOperands())
1128  if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
1129  candidates.push_back(sliceOp);
1130  };
1131 
1132  std::deque<tensor::ExtractSliceOp> candidates;
1133  addCandidateSlices(tiledAndFusedOps.back(), candidates);
1134  OpBuilder::InsertionGuard g(rewriter);
1135  while (!candidates.empty()) {
1136  // Traverse the slices in BFS fashion.
1137  tensor::ExtractSliceOp candidateSliceOp = candidates.front();
1138  candidates.pop_front();
1139 
1140  // Find the original producer of the slice.
1141  auto [fusableProducer, destinationInitArg] =
1142  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1143  loops);
1144  if (!fusableProducer)
1145  continue;
1146 
1147  auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
1148  candidateSliceOp, fusableProducer, destinationInitArg.has_value());
1149  if (!fuseSlice)
1150  continue;
1151 
1152  // The operands of the fused producer might themselves be slices of
1153  // values produced by operations that implement the `TilingInterface`.
1154  // Add these operations to the worklist.
1155  std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1156  tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
1157  if (!fusedResult)
1158  continue;
1159 
1160  if (yieldReplacement) {
1161  // Reconstruct and yield all opResults of fusableProducerOp by default. The
1162  // caller can specify which ones to yield via the optional argument
1163  // `yieldResultNumber` of `yieldReplacementForFusedProducer`.
1164  Operation *fusableProducerOp = fusableProducer.getOwner();
1165      if (failed(yieldReplacementForFusedProducer(
1166              rewriter, candidateSliceOp, fusedResult.value(), loops))) {
1167  return rewriter.notifyMatchFailure(
1168  fusableProducerOp, "failed to yield replacement value for this "
1169  "operation from within the tiled loop");
1170  }
1171  for (auto [index, result] :
1172  llvm::enumerate(fusableProducerOp->getResults())) {
1173  origValToResultNumber[result] = loops.front()->getNumResults() -
1174  fusableProducerOp->getNumResults() +
1175  index;
1176  }
1177  }
1178 
1179  if (Operation *tiledAndFusedOp =
1180  fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1181  fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1182  tiledAndFusedOps.insert(tiledAndFusedOp);
1183  addCandidateSlices(tiledAndFusedOp, candidates);
1184  }
1185  }
1186 
1187  DenseMap<Value, Value> replacements;
1188  for (auto [origVal, resultNumber] : origValToResultNumber) {
1189  replacements[origVal] = loops.front()->getResult(resultNumber);
1190  }
1191 
1192  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1193  replacements};
1194 }
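// A minimal usage sketch (assuming `consumer` implements `TilingInterface`;
// the tile size is arbitrary):
//
//   scf::SCFTileAndFuseOptions tileAndFuseOptions;
//   tileAndFuseOptions.tilingOptions.setTileSizes({rewriter.getIndexAttr(16)});
//   FailureOr<scf::SCFTileAndFuseResult> result =
//       scf::tileConsumerAndFuseProducersUsingSCF(rewriter, consumer,
//                                                 tileAndFuseOptions);
//   if (failed(result))
//     return failure();
//   for (auto [origValue, replacement] : result->replacements)
//     rewriter.replaceAllUsesWith(origValue, replacement);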
1195 
1196 //===----------------------------------------------------------------------===//
1197 // tileAndFuseConsumerUsingSCF implementation.
1198 //===----------------------------------------------------------------------===//
1199 
1200 /// A utility function that checks whether the only use of the result of a
1201 /// tensor.insert_slice op is in a scf.yield op.
1202 static LogicalResult
1203 checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
1204  Value result = candidateSliceOp.getResult();
1205  Value::use_range uses = result.getUses();
1206  if (!llvm::hasSingleElement(uses)) {
1207  LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
1208  return failure();
1209  }
1210  OpOperand &operandUse = (*uses.begin());
1211  Operation *userOp = operandUse.getOwner();
1212  if (!isa<scf::YieldOp>(userOp)) {
1213  LLVM_DEBUG(llvm::dbgs()
1214  << "Expected scf.yield to be the only user, but got -> "
1215  << (*userOp));
1216  return failure();
1217  }
1218  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
1219  LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
1220  "be in the same block\n");
1221  return failure();
1222  }
1223  return success();
1224 }
1225 
1226 /// Fetches the OpOperand of the only user (and use) of the value `val` which
1227 /// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
1228 /// failure otherwise.
1229 static FailureOr<OpOperand *> getConsumerFromUses(Value val,
1230  Block *containingOpBlock) {
1231  // Step 1. Check that the value has exactly one use.
1232  if (!llvm::hasSingleElement(val.getUses()))
1233  return failure();
1234  // Step 2. Get uses.
1235  OpOperand &operand = (*val.getUses().begin());
1236  Operation *consumerOp = operand.getOwner();
1237  // TODO: We have to initialize the result of the consumer before the scf.for;
1238  // use DestinationStyleOpInterface to get the result shape from the init for
1239  // now. Add support for other ops, such as ops with InferTypeOpInterface.
1240  if (!isa<TilingInterface>(consumerOp) ||
1241  !isa<DestinationStyleOpInterface>(consumerOp))
1242  return failure();
1243  if (containingOpBlock != consumerOp->getBlock())
1244  return failure();
1245  return &operand;
1246 }
1247 
1248 /// Fetch the untiled consumer of a scf.for's result which is yielded by a
1249 /// tensor.insert_slice. This function makes the following assumptions :
1250 /// 1. tensor.insert_slice has scf.yield as its only user.
1251 /// 2. scf.for's corresponding result has only one use.
1252 static FailureOr<OpOperand *>
1253 getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
1254  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
1255  return failure();
1256  Value sliceResult = candidateSliceOp.getResult();
1257  // Step 1. Fetch the corresponding output.
1258  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
1259  unsigned resultNumber = yieldOpOperand.getOperandNumber();
1260  // Step 2. Check containing op is scf.for.
1261  Operation *containingOp = candidateSliceOp->getParentOp();
1262  auto forOp = dyn_cast<scf::ForOp>(containingOp);
1263  if (!forOp)
1264  return failure();
1265  Value resultingValue = forOp->getResult(resultNumber);
1266 
1267  return getConsumerFromUses(resultingValue, containingOp->getBlock());
1268 }
1269 
1270 /// Fetch the first untiled consumer of a scf.forall's result which is yielded
1271 /// by a tensor.parallel_insert_slice.
1272 static FailureOr<OpOperand *>
1273 getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
1274  // Step 1. Fetch the corresponding output
1275  Value sliceDest = candidateSliceOp.getDest();
1276  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
1277  if (!iterArg)
1278  return failure();
1279  Operation *containingOp = iterArg.getOwner()->getParentOp();
1280  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
1281  return failure();
1282  // Step 2. Check that the containing op is scf.forall.
1283  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
1284  if (!forallOp)
1285  return failure();
1286  Value resultingValue =
1287  forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
1288 
1289  return getConsumerFromUses(resultingValue, containingOp->getBlock());
1290 }
1291 
1292 /// This utility currently checks whether the loop either:
1293 /// 1. Yields exactly one result.
1294 /// 2. Has the consumer op as its first user, with all other users in the same
1295 /// block as the consumer op. Currently we clone the loop op right before the
1296 /// consumer op in order to maintain a valid def-use chain. This utility thus
1297 /// helps ensure that no invalid IR is formed.
1299 static LogicalResult checkAssumptionForLoop(Operation *loopOp,
1300  Operation *consumerOp) {
1301  // Check if the loop op yields one result.
1302  if (loopOp->getNumResults() == 1)
1303  return success();
1304  // Check if the consumerOp is the first user of the loopOp and the other users
1305  // are in the same block as the consumerOp.
1306  Block *parentBlock = consumerOp->getBlock();
1307  for (Operation *userOp : loopOp->getUsers()) {
1308  if (userOp == consumerOp)
1309  continue;
1310  if (parentBlock != userOp->getBlock() ||
1311  !consumerOp->isBeforeInBlock(userOp))
1312  return failure();
1313  }
1314  return success();
1315 }
1316 
1317 /// A utility to fetch an untiled consumer of
1318 /// tensor.insert_slice/tensor.parallel_insert_slice.
1319 static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
1320  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
1321  return getUntiledConsumerFromSlice(insertSlice);
1322  } else if (auto parallelInsertSlice =
1323  dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
1324  return getUntiledConsumerFromSlice(parallelInsertSlice);
1325  } else {
1326  return failure();
1327  }
1328 }
1329 
1330 /// After fusing the consumer into scf.for, modify the scf.yield operation to
1331 /// also return the values produced by the tiled consumer.
1332 static void
1333 fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
1334  TilingResult &tilingResult,
1335  ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
1336  ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
1337  ArrayRef<BlockArgument> bbArgs) {
1338  scf::YieldOp oldTerminatorOp =
1339  cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
1340  unsigned totalOldResults = oldTerminatorOp->getNumResults();
1341  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
1342  SmallVector<Value> newYieldOperands;
1343  newYieldOperands.reserve(totalOldResults + totalTiledResults);
1344  for (auto oldResult : oldTerminatorOp.getResults()) {
1345  newYieldOperands.push_back(oldResult);
1346  }
1347  rewriter.setInsertionPointAfter(oldTerminatorOp);
1348  Location loc = newForOp.getLoc();
1349  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
1350  llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
1351  resultOffsets, resultSizes)) {
1352  SmallVector<OpFoldResult> strides(resultOffset.size(),
1353  rewriter.getIndexAttr(1));
1354  Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
1355  loc, tiledResult, bbArg, resultOffset, resultSize, strides);
1356  newYieldOperands.push_back(newInsertSliceOp);
1357  }
1358  rewriter.create<scf::YieldOp>(loc, newYieldOperands);
1359  rewriter.eraseOp(oldTerminatorOp);
1360 }
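// Illustrative result of this rewrite (a sketch; offsets, sizes, and types
// are elided): the originally yielded values are preserved and one
// tensor.insert_slice per tiled-consumer result is appended to the new
// scf.yield.
//
//   %insert = tensor.insert_slice %tiledConsumerResult into %newIterArg
//                 [%offsets...] [%sizes...] [1, ...] : ...
//   scf.yield %originalYieldedValues..., %insert : ...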
1361 
1362 /// After fusing the consumer into scf.forall, yield each value produced by
1363 /// the tiled consumer from within the scf.forall.in_parallel region.
1364 static void
1365 fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
1366  SmallVector<Value> tiledResults,
1367  ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
1368  ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
1369  ArrayRef<BlockArgument> bbArgs) {
1370  scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
1371  rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
1372  Location firstYieldOpLoc =
1373  (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
1374  for (auto [tiledResult, bbArg, resultOffset, resultSize] :
1375  llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
1376  SmallVector<OpFoldResult> strides(resultOffset.size(),
1377  rewriter.getIndexAttr(1));
1378  rewriter.create<tensor::ParallelInsertSliceOp>(
1379  firstYieldOpLoc, tiledResult, bbArg, resultOffset, resultSize, strides);
1380  }
1381 }
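// Illustrative result of this rewrite (a sketch; offsets, sizes, and types
// are elided): each tiled-consumer result is written into its new shared_outs
// argument from within the in_parallel terminator.
//
//   scf.forall.in_parallel {
//     tensor.parallel_insert_slice %tiledConsumerResult into %newSharedOutArg
//         [%offsets...] [%sizes...] [1, ...] : ...
//   }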
1382 
1383 /// Implementation of fusing the consumer of a single slice by computing the
1384 /// required slice of the consumer in-place for the scf loop.
1385 FailureOr<scf::SCFFuseConsumerOfSliceResult>
1386 mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1387  Operation *candidateSliceOp) {
1388  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
1389  candidateSliceOp))
1390  return failure();
1391 
1392  bool isInsertSliceOp = isa<tensor::InsertSliceOp>(candidateSliceOp);
1393 
1394  // 1. Get the consumer of the scf.for/scf.forall result yielded by
1395  // tensor.insert_slice/tensor.parallel_insert_slice.
1396  FailureOr<OpOperand *> maybeConsumerOpOperand =
1397  getUntiledConsumerFromSlice(candidateSliceOp);
1398  if (failed(maybeConsumerOpOperand)) {
1399  return rewriter.notifyMatchFailure(candidateSliceOp,
1400  "could not fetch consumer to fuse");
1401  }
1402  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
1403  Operation *consumerOp = consumerOpOperand->getOwner();
1404  unsigned operandNumber = consumerOpOperand->getOperandNumber();
1405  unsigned resultNumber = 0;
1406  if (auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get())) {
1407  resultNumber = producerResult.getResultNumber();
1408  } else {
1409  return rewriter.notifyMatchFailure(
1410  consumerOp, "consumer op's operand doesn't seem to be an OpResult");
1411  }
1412 
1413  Operation *oldLoopOp = nullptr;
1414  SmallVector<Value> newOuts;
1415  Block *oldLoopBody = nullptr;
1416  unsigned initSize = 0;
1417  unsigned rank = 1;
1418  if (isInsertSliceOp) {
1419  auto forOp = candidateSliceOp->getParentOfType<scf::ForOp>();
1420  oldLoopOp = forOp;
1421  llvm::append_range(newOuts, forOp.getInits());
1422  oldLoopBody = forOp.getBody();
1423  initSize = forOp.getInits().size();
1424  } else {
1425  auto forallOp = candidateSliceOp->getParentOfType<scf::ForallOp>();
1426  oldLoopOp = forallOp;
1427  llvm::append_range(newOuts, forallOp.getOutputs());
1428  oldLoopBody = forallOp.getBody();
1429  initSize = forallOp.getOutputs().size();
1430  rank = forallOp.getRank();
1431  }
1432 
1433  if (failed(checkAssumptionForLoop(oldLoopOp, consumerOp))) {
1434  return rewriter.notifyMatchFailure(
1435  oldLoopOp, "containing loop op should either yield just one value or "
1436  "have the consumer op as its first user");
1437  }
1438 
1439  OpBuilder::InsertionGuard g(rewriter);
1440 
1441  // 2. Check that the consumer is not using the scf loop's output as an init.
1442  auto dstOp = cast<DestinationStyleOpInterface>(consumerOp);
1443  SmallVector<Value> dpsInits =
1444  llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
1445  if (llvm::is_contained(dpsInits, oldLoopOp->getResult(resultNumber))) {
1446  return rewriter.notifyMatchFailure(
1447  consumerOp,
1448  "consumer op taking the result of scf.for as init is not supported");
1449  }
1450  newOuts.append(dpsInits);
1451 
1452  Location loc = oldLoopOp->getLoc();
1453 
1454  // 3. Create new scf loop op.
1455  rewriter.setInsertionPoint(consumerOp);
1456  Operation *newLoopOp = nullptr;
1457  Block *newLoopBody = nullptr;
1458  if (isInsertSliceOp) {
1459  auto forOp = cast<scf::ForOp>(oldLoopOp);
1460  auto newForOp = rewriter.create<scf::ForOp>(loc, forOp.getLowerBound(),
1461  forOp.getUpperBound(),
1462  forOp.getStep(), newOuts);
1463  newLoopOp = newForOp;
1464  newLoopBody = newForOp.getBody();
1465  } else {
1466  auto forallOp = cast<scf::ForallOp>(oldLoopOp);
1467  auto newForallOp = rewriter.create<scf::ForallOp>(
1468  loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
1469  forallOp.getMixedStep(), newOuts, forallOp.getMapping());
1470  newLoopOp = newForallOp;
1471  rewriter.eraseOp(newForallOp.getTerminator());
1472  newLoopBody = newForallOp.getBody();
1473  }
1474 
1475  // 4. Move the loop body to the new op.
1476  unsigned oldNumArguments = oldLoopBody->getNumArguments();
1477  rewriter.mergeBlocks(oldLoopBody, newLoopBody,
1478  newLoopBody->getArguments().take_front(oldNumArguments));
1479 
1480  // 5. Set insertion point before terminator op of the loop and create a new
1481  // tensor.insert_slice. In the scf.for case this is a clone of the
1482  // candidateSliceOp whereas in the scf.forall case this is created from the
1483  // operands of tensor.parallel_insert_slice.
1484  tensor::InsertSliceOp clonedInsertSliceOp;
1485  if (auto sliceOp =
1486  dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
1487  auto newForallOp = cast<scf::ForallOp>(newLoopOp);
1488  rewriter.setInsertionPoint(newForallOp.getTerminator());
1489  clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
1490  loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
1491  sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
1492  } else {
1493  rewriter.setInsertionPoint(candidateSliceOp);
1494  clonedInsertSliceOp =
1495  cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
1496  }
1497 
1498  // 6.a. Clone consumer op.
1499  auto newForOpBlockArgsForConsumerDest =
1500  newLoopBody->getArguments().drop_front(oldNumArguments);
1501  auto clonedConsumerOp = cast<TilingInterface>(cloneOpAndUpdateDestinationArgs(
1502  rewriter, consumerOp, newForOpBlockArgsForConsumerDest));
1503 
1504  // 6.b. Replace the use of the loop result in the cloned consumer with the
1505  // result of the cloned tensor.insert_slice.
1506  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
1507  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
1508  operandToReplace.set(clonedInsertSliceOp.getResult());
1509  });
1510 
1511  // 7. Perform tiling of the cloned consumer and replace the operand at
1512  // `operandNumber` with the source of the cloned tensor.insert_slice op.
1513  auto ossSliceOp =
1514  cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
1515  FailureOr<TilingResult> tileAndFuseResult =
1516  tensor::replaceInsertSliceWithTiledConsumer(
1517  rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
1518  if (failed(tileAndFuseResult)) {
1519  return failure();
1520  }
1521  rewriter.replaceAllUsesWith(
1522  tileAndFuseResult->tiledOps[0]->getOperand(operandNumber),
1523  clonedInsertSliceOp.getSource());
1524 
1525  // 8. Extract the offsets/sizes/strides required to create the
1526  // tensor.insert_slice/parallel_insert_slice for each result of the consumer.
1527  SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
1528  SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();
1529  SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
1530 
1531  // 9. Check that all insert strides are 1.
1532  if (llvm::any_of(strides, [](OpFoldResult stride) {
1533  return !isConstantIntValue(stride, 1);
1534  })) {
1535  return rewriter.notifyMatchFailure(
1536  candidateSliceOp, "containingOp's result yield with stride");
1537  }
1538 
1539  // 10. Try to get iter domain position from input position.
1540  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
1541  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
1542  rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
1543  iterDomainSizes))) {
1544  return rewriter.notifyMatchFailure(
1545  clonedConsumerOp, "can't get iter domain position from input position");
1546  }
1547 
1548  // 11. Try to fetch the offsets and sizes for all results of the cloned
1549  // consumer. These are then used to form the corresponding
1550  // tensor.insert_slice/parallel_insert_slice ops later.
1551  unsigned totalNumResultsOfConsumer = clonedConsumerOp->getNumResults();
1552  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
1553  totalNumResultsOfConsumer);
1554  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
1555  for (auto [idx, v] : llvm::enumerate(clonedConsumerOp->getResults())) {
1556  if (failed(clonedConsumerOp.getResultTilePosition(
1557  rewriter, idx, iterDomainOffsets, iterDomainSizes,
1558  resultOffsets[idx], resultSizes[idx]))) {
1559  return rewriter.notifyMatchFailure(
1560  clonedConsumerOp,
1561  "can't get result domain position from iter domain position");
1562  }
1563  }
1564 
1565  auto arrayRefOffsets = ArrayRef<SmallVector<OpFoldResult>>(resultOffsets);
1566  auto arrayRefSizes = ArrayRef<SmallVector<OpFoldResult>>(resultSizes);
1567  if (isInsertSliceOp) {
1568  auto newForOp = cast<scf::ForOp>(newLoopOp);
1569  fixTerminatorSCFYield(
1570  rewriter, newForOp, *tileAndFuseResult, arrayRefOffsets, arrayRefSizes,
1571  newForOp.getBody()->getArguments().drop_front(1 + initSize));
1572  } else {
1573  auto newForallOp = cast<scf::ForallOp>(newLoopOp);
1574  fixTerminatorSCFInParallel(
1575  rewriter, newForallOp, tileAndFuseResult->tiledOps[0]->getResults(),
1576  arrayRefOffsets, arrayRefSizes,
1577  newForallOp.getBody()->getArguments().drop_front(rank + initSize));
1578  }
1579 
1580  // 12. Replace the results of the old scf loop and the consumer op with the
1581  // new loop's results.
1581  for (auto &&[oldResult, newResult] :
1582  llvm::zip_first(oldLoopOp->getResults(), newLoopOp->getResults())) {
1583  rewriter.replaceAllUsesWith(oldResult, newResult);
1584  }
1585 
1586  for (auto &&[oldResult, newResult] :
1587  llvm::zip(consumerOp->getResults(),
1588  newLoopOp->getResults().drop_front(initSize))) {
1589  rewriter.replaceAllUsesWith(oldResult, newResult);
1590  }
1591 
1592  // 13. Erase the old scf loop and the cloned consumer op.
1593  rewriter.eraseOp(oldLoopOp);
1594  rewriter.eraseOp(clonedConsumerOp);
1595 
1597  consumerOpOperand,
1598  &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
1599  tileAndFuseResult->tiledOps};
1600 }
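// A minimal sketch of how a caller might drive this entry point from a
// rewrite pattern; the pattern name and the surrounding boilerplate are
// illustrative only, not part of this file.
//
//   struct FuseConsumerOfInsertSlice
//       : public OpRewritePattern<tensor::InsertSliceOp> {
//     using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
//     LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
//                                   PatternRewriter &rewriter) const override {
//       FailureOr<scf::SCFFuseConsumerOfSliceResult> fused =
//           scf::tileAndFuseConsumerOfSlice(rewriter, sliceOp.getOperation());
//       return failed(fused) ? failure() : success();
//     }
//   };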
1601 
1602 //===----------------------------------------------------------------------===//
1603 // lowerToLoopsUsingSCFForOp implementation.
1604 //===----------------------------------------------------------------------===//
1605 
1606 FailureOr<SmallVector<scf::ForOp>>
1607 mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
1608  TilingInterface op) {
1609  // TODO: Handle cases where the op has results if needed.
1610  if (op->getNumResults() > 0) {
1611  return rewriter.notifyMatchFailure(
1612  op, "unable to lower to loops operations with return values");
1613  }
1614 
1615  SmallVector<Range> domain = op.getIterationDomain(rewriter);
1616  SmallVector<Value> ivs;
1617  SmallVector<scf::ForOp> loops;
1618  Location loc = op.getLoc();
1619  for (auto loopRange : domain) {
1620  Value offsetVal =
1621  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
1622  Value sizeVal =
1623  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
1624  Value strideVal =
1625  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
1626  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
1627  strideVal, ValueRange{});
1628  loops.push_back(loop);
1629  ivs.push_back(loop.getInductionVar());
1630  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
1631  }
1632  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
1633  return failure();
1634  }
1635  return loops;
1636 }
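// A minimal usage sketch (illustrative only): lowering a zero-result op that
// implements TilingInterface into a nest of scf.for loops around its scalar
// implementation. The names `rewriter` and `tilingInterfaceOp` stand for a
// RewriterBase and a TilingInterface op available at the call site.
//
//   FailureOr<SmallVector<scf::ForOp>> loops =
//       scf::lowerToLoopsUsingSCFForOp(rewriter, tilingInterfaceOp);
//   if (failed(loops))
//     return failure();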