MLIR  19.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1 //===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements tiling of operations through the TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/Debug.h"
28 #include <optional>
29 
30 #define DEBUG_TYPE "tile-using-interface"
31 
32 using namespace mlir;
33 
// Wrap the fixed tile sizes in a callback. The sizes are copied into the
// closure, so the callback does not reference the caller's array, and it
// returns the same sizes regardless of the builder/operation it is given.
// The assert rejects installing sizes when a tile size computation function
// is already set.
36  assert(!tileSizeComputationFunction && "tile sizes already set");
37  auto tileSizes = llvm::to_vector(ts);
38  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
39  return tileSizes;
40  };
41  return *this;
42 }
43 
44 /// Helper method to adjust the interchange vector to match the iteration
45 /// domain.
/// Missing trailing entries are filled with the identity (entry `i` is `i`),
/// and entries beyond `iterationDomainSize` are dropped. The result is not
/// checked to be a permutation here; `tileUsingSCF` rejects non-permutations
/// with `isPermutationVector` after calling this helper.
48  size_t iterationDomainSize) {
49  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
50  if (filledVector.size() < iterationDomainSize) {
51  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
52  filledVector.append(range.begin(), range.end());
53  }
54  if (filledVector.size() > iterationDomainSize)
55  filledVector.resize(iterationDomainSize);
56  return filledVector;
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // tileUsingSCF implementation.
61 //===----------------------------------------------------------------------===//
62 
63 // Check if `stride` evenly divides the trip count `size - offset`.
64 static bool tileDividesIterationDomain(Range loopRange) {
65  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
66  if (!offsetAsInt)
67  return false;
68  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
69  if (!sizeAsInt)
70  return false;
71  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
72  if (!strideAsInt)
73  return false;
74  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
75 }
76 
77 /// Returns the bounded tile size given the current `iv`, `loopRange` and
78 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
/// The end of the range is `loopRange.size`, which is treated as the exclusive
/// upper bound (the loop generators use it as the loop `ub`). `tileSize` is
/// returned unchanged when it is the constant 1, or when the tile is known to
/// evenly divide the range (presumably via `tileDividesIterationDomain` on
/// `Range{offset, size, tileSize}`). Otherwise the map
/// `(d0)[s0, s1] -> (s0, s1 - d0)` is applied to `(iv, tileSize, size)` and
/// its minimum is returned.
80  Range loopRange, Value iv,
81  OpFoldResult tileSize) {
82  std::optional<int64_t> ts = getConstantIntValue(tileSize);
83  if (ts && ts.value() == 1)
84  return tileSize;
85 
87  Range{loopRange.offset, loopRange.size, tileSize}))
88  return tileSize;
89 
90  // The tile size to use (to avoid out of bounds access) is minimum of
91  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
92  // loop.
93  AffineExpr s0, s1, d0;
94  bindDims(b.getContext(), d0);
95  bindSymbols(b.getContext(), s0, s1);
96  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
97  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
99  b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
100 }
101 
102 /// A function that allows returning additional yielded values during
103 /// `yieldTiledValuesAndReplace`.
104 /// - `ivs` induction variable for the loop.
105 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
106 /// - `tiledValues` the tiled values to return. Must be of same size as
107 /// `newbbArgs`, each element of this array is inserted into the corresponding
108 /// element in `newbbArgs`.
109 /// - `resultOffsets` is of the same size as `tiledValues` and represents
110 /// the offsets to use when inserting corresponding element from `tiledValues`
111 /// into the element from `newBbArgs`.
112 /// - `resultSizes` is of the same size as `tiledValues` and represents
113 /// the size of the corresponding element from `tiledValues` inserted into
114 /// the element from `newBbArgs`.
115 /// In case the method needs to return `failure()` the method is expected
116 /// to clean up any inserted operations.
117 using YieldTiledValuesFn = std::function<LogicalResult(
118  RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
119  SmallVector<Value> &tiledValues,
120  SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
121  SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
122 
123 /// Clones the operation and updates the destination if the operation
124 /// implements the `DestinationStyleOpInterface`.
/// The clone is created at the rewriter's current insertion point. When
/// `newDestArgs` is empty, or the clone does not implement the interface, the
/// clone keeps the original operands. Otherwise all DPS init operands of the
/// clone are replaced by `newDestArgs` (presumably one value per init; TODO
/// confirm that callers always pass a matching count).
126  Operation *op,
127  ValueRange newDestArgs) {
128  Operation *clonedOp = rewriter.clone(*op);
129  if (newDestArgs.empty())
130  return clonedOp;
131  if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
132  destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
133  return clonedOp;
134 }
135 
136 /// Generate the tile-loop nest using `scf.for` operation.
137 /// - `loopRanges` specifies the untiled iteration space; each range's `offset`
/// is used as the loop lower bound and its `size` as the upper bound (the
/// range `stride` is not read here; the tile size is used as the step).
138 /// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
139 /// - `destinationTensors` are the init values to use for the outermost loop.
140 /// - `yieldTiledValuesFn` is called to generate the loop body of the
141 /// innermost loop; it receives the induction variables of the generated
142 /// loops.
143 /// - `loops` is an in-out parameter into which the generated loops are
144 /// populated.
/// One `scf.for` is created per non-zero tile size, each nested in the
/// previous one and taking the previous loop's region iter_args as its inits.
/// The callback's tiled values are inserted into the innermost iter_args with
/// unit-stride `tensor.insert_slice` ops and yielded; every outer loop yields
/// the results of the loop nested in it.
146  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
147  ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
148  YieldTiledValuesFn yieldTiledValuesFn,
150  assert(!loopRanges.empty() && "unexpected empty loop ranges");
151  assert(loopRanges.size() == tileSizes.size() &&
152  "expected as many tile sizes as loop ranges");
153  OpBuilder::InsertionGuard guard(rewriter);
154  SmallVector<Value> ivs;
155 
156  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
157  // No loop is generated for a zero tile size; the body callback is
158  // expected to use the full range of that dimension instead.
159  if (isConstantIntValue(tileSize, 0))
160  continue;
161 
162  Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
163  Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
164  Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
165  auto loop =
166  rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
167  [](OpBuilder &bodyBuilder, Location bodyLoc,
168  Value iv, ValueRange /*iterArgs*/) {});
169  loops.push_back(loop);
170  ivs.push_back(loop.getInductionVar());
171  rewriter.setInsertionPointToEnd(loop.getBody());
// The next (inner) loop uses this loop's iter_args as its inits.
172  destinationTensors = loop.getRegionIterArgs();
173  }
174 
175  SmallVector<Value> tiledResults;
176  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
// NOTE(review): if the callback fails, the loops created above are not
// erased, so a partially built loop nest stays in the IR. Confirm that callers
// discard or roll back the IR in that case.
177  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
178  tiledResults, resultOffsets, resultSizes))) {
179  return rewriter.notifyMatchFailure(
180  loc, "failed to generate inner tile loop body");
181  }
// Without any materialized loop there is nothing to yield into.
182  if (loops.empty())
183  return success();
184 
185  // Yield all the results of the tiled operation from the innermost loop.
186  SmallVector<Value> yieldedValues;
187  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
188  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
189  resultSizes)) {
190  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
191  rewriter.getIndexAttr(1));
192  auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
193  loc, tiledValue, destinationTensor, resultOffset, resultSize,
194  resultStride);
195  yieldedValues.push_back(insertSlice);
196  }
197  rewriter.create<scf::YieldOp>(loc, yieldedValues);
198 
199  // Add the scf.yield operations for all the outer loops.
200  for (auto [outerLoop, innerLoop] :
201  llvm::zip_equal(MutableArrayRef(loops).drop_back(),
202  MutableArrayRef(loops).drop_front())) {
203  rewriter.setInsertionPointToEnd(
204  cast<scf::ForOp>(outerLoop.getOperation()).getBody());
205  rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
206  }
207  return success();
208 }
209 
210 /// Generate the tile-loop nest using `scf.forall` operation.
211 /// - `loopRanges` specifies the untiled iteration space; each range's `offset`
/// and `size` become the lower and upper bound of the corresponding
/// `scf.forall` dimension (the tile size is used as the step).
212 /// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
213 /// - `destinationTensors` are the shared outputs of the generated loop.
214 /// - `mappingVector` is the mapping attributes to use for loop construction.
215 /// Can be empty.
216 /// - `tiledBodyFn` is called to generate the loop body of the innermost
217 /// loop (a single `scf.forall` covers all non-zero tile sizes). Its tiled
218 /// values are written back with `tensor.parallel_insert_slice`.
219 /// - `loops` is an in-out parameter into which the generated loops are
220 /// populated.
222  RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
223  ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
224  ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
226  SmallVector<OpFoldResult> lbs, ubs, steps;
227  assert(!loopRanges.empty() && "unexpected empty loop ranges");
228  assert(loopRanges.size() == tileSizes.size() &&
229  "expected as many tile sizes as loop ranges");
230  OpBuilder::InsertionGuard guard(rewriter);
// NOTE(review): `offsets` and `sizes` below are never used in this function;
// they look like leftovers and could be removed.
231  SmallVector<OpFoldResult> offsets(loopRanges.size()),
232  sizes(loopRanges.size());
233 
234  for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
235  if (isConstantIntValue(tileSize, 0))
236  continue;
237  lbs.push_back(loopRange.offset);
238  ubs.push_back(loopRange.size);
239  steps.push_back(tileSize);
240  }
// The all-zero tile size case is expected to be handled by the caller
// (`generateLoopNest` does so before dispatching here).
241  assert(!lbs.empty() && "Expected at least one loop range");
242 
243  std::optional<ArrayAttr> mappingAttr;
244  if (!mappingVector.empty())
245  mappingAttr = rewriter.getArrayAttr(mappingVector);
246 
247  auto forallOp = rewriter.create<scf::ForallOp>(
248  loc, lbs, ubs, steps, destinationTensors, mappingAttr);
249  loops.push_back(forallOp);
250 
251  rewriter.setInsertionPoint(forallOp.getTerminator());
252  destinationTensors = forallOp.getRegionOutArgs();
253 
254  SmallVector<Value> tiledResults;
255  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
// NOTE(review): on failure the `scf.forall` created above stays in the IR
// (and in `loops`); confirm that callers discard the IR in that case.
256  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
257  destinationTensors, tiledResults, resultOffsets,
258  resultSizes)))
259  return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
260 
261  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
262  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
263  llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
264  resultSizes)) {
265  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
266  rewriter.getIndexAttr(1));
267 
268  rewriter.create<tensor::ParallelInsertSliceOp>(
269  loc, tiledValue, destinationTensor, resultOffset, resultSize,
270  resultStride);
271  }
272  return success();
273 }
274 
275 /// Generate the tile-loop nest using the loop construct specified in `options`.
276 /// - `options`: Tiling options specified.
277 /// - `loopRanges` specifies the untiled iteration space (offset, size, stride).
278 /// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
279 /// - `destinationTensors` are the init values to use for the outermost loop.
280 /// - `tiledBodyFn` is called to generate the loop body of the innermost
281 /// loop, or called directly (with no induction variables) when every tile
282 /// size is zero.
283 /// - `loops` is an in-out parameter into which the generated loops are
284 /// populated.
/// Dispatches to the `scf.for` or `scf.forall` generator (presumably selected
/// via `options.loopType`) and fails with "unhandled loop type" otherwise.
287  ArrayRef<Range> loopRanges,
288  ArrayRef<OpFoldResult> tileSizes,
289  ValueRange destinationTensors,
290  YieldTiledValuesFn tiledBodyFn,
292  // If the tile sizes are all zero, no loops are generated. Just call the
293  // callback function to handle untiled case.
// The tiled results/offsets/sizes produced by the callback are discarded here
// and `loops` is left untouched; the callback must record anything the caller
// needs (e.g. through captured state).
294  if (llvm::all_of(tileSizes, isZeroIndex)) {
295  SmallVector<Value> tiledResults;
296  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
297  return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
298  tiledResults, resultOffsets, resultSizes);
299  }
301  return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
302  destinationTensors, tiledBodyFn, loops);
303  }
306  rewriter, loc, loopRanges, tileSizes, options.mappingVector,
307  destinationTensors, tiledBodyFn, loops);
308  }
309  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
310 }
311 
312 /// Append the specified additional `newInitOperands` operands to the
313 /// loop's existing `init` operands (or similar), and replace `loopOp` with
314 /// the new loop that has the additional init operands. The loop body of
315 /// this loop is moved over to the new loop. `yieldTiledValuesFn`
316 /// is called to get the new tiled values returned, and the offset
317 /// and sizes at which the tiled value is inserted into the
318 /// new region iter_args that correspond to the newly added init operands.
/// This primary template is the fallback for unsupported loop types and always
/// fails with "unhandled loop type"; `scf.for` and `scf.forall` are handled by
/// the explicit specializations.
319 template <typename LoopType>
321 yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
322  ValueRange newInitOperands,
323  YieldTiledValuesFn yieldTiledValuesFn) {
324  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
325 }
326 
327 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
328 template <>
329 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
330  scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
331  YieldTiledValuesFn yieldTiledValuesFn) {
332  OpBuilder::InsertionGuard g(rewriter);
333  Location loc = loopOp.getLoc();
334  rewriter.setInsertionPoint(loopOp);
335 
336  auto inits = llvm::to_vector(loopOp.getInitArgs());
337  inits.append(newInitOperands.begin(), newInitOperands.end());
338  auto newLoop = rewriter.create<scf::ForOp>(
339  loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
340  inits, [](OpBuilder &, Location, Value, ValueRange) {});
341 
342  // Move the loop body to the new op.
343  Block *loopBody = loopOp.getBody();
344  Block *newLoopBody = newLoop.getBody();
345  rewriter.mergeBlocks(
346  loopBody, newLoopBody,
347  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
348 
349  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
350  rewriter.setInsertionPoint(yieldOp);
351 
352  SmallVector<Value> tiledValues;
353  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
354  ValueRange newRegionIterArgs =
355  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
356  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
357  newRegionIterArgs, tiledValues, resultOffsets,
358  resultSizes))) {
359  rewriter.eraseOp(newLoop);
360  return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
361  }
362 
363  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
364  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
365  llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
366  resultSizes)) {
367  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
368  rewriter.getIndexAttr(1));
369  Value insert = rewriter.create<tensor::InsertSliceOp>(
370  yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
371  resultStride);
372  newYieldValues.push_back(insert);
373  }
374 
375  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
376  rewriter.replaceOp(loopOp,
377  newLoop->getResults().take_front(loopOp.getNumResults()));
378  return cast<LoopLikeOpInterface>(newLoop.getOperation());
379 }
380 
381 /// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`
382 template <>
383 FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
384  scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
385  YieldTiledValuesFn yieldTiledValuesFn) {
386  OpBuilder::InsertionGuard g(rewriter);
387  Location loc = loopOp.getLoc();
388  rewriter.setInsertionPoint(loopOp);
389  auto inits = llvm::to_vector(loopOp.getOutputs());
390  inits.append(newInitOperands.begin(), newInitOperands.end());
391  auto newLoop = rewriter.create<scf::ForallOp>(
392  loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
393  loopOp.getMixedStep(), inits, loopOp.getMapping(),
394  [](OpBuilder &, Location, ValueRange) {});
395 
396  // Move the region of the current block to the newly created op.
397  Block *loopBody = loopOp.getBody();
398  Block *newLoopBody = newLoop.getBody();
399  rewriter.mergeBlocks(
400  loopBody, newLoopBody,
401  newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
402 
403  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
404  rewriter.setInsertionPoint(terminator);
405  SmallVector<Value> tiledValues;
406  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
407  ValueRange regionIterArgs =
408  newLoop.getRegionIterArgs().take_back(newInitOperands.size());
409  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
410  regionIterArgs, tiledValues, resultOffsets,
411  resultSizes))) {
412  rewriter.eraseOp(newLoop);
413  return rewriter.notifyMatchFailure(loopOp,
414  "failed to get yielded tiled values");
415  }
416 
417  // Update the terminator.
418  rewriter.setInsertionPointToEnd(terminator.getBody());
419 
420  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
421  tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
422  SmallVector<OpFoldResult> resultStride(resultOffset.size(),
423  rewriter.getIndexAttr(1));
424  rewriter.create<tensor::ParallelInsertSliceOp>(
425  terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
426  resultStride);
427  }
428 
429  rewriter.replaceOp(loopOp,
430  newLoop->getResults().take_front(loopOp.getNumResults()));
431  return cast<LoopLikeOpInterface>(newLoop.getOperation());
432 }
433 
434 /// Implementation of `yieldTiledValuesAndReplaceLoop` for
435 /// `LoopLikeOpInterface`, that just dispatches to the implementation for each
436 /// supported loop type.
/// `scf.for` and `scf.forall` are forwarded to the explicit specializations;
/// any other loop type fails with "unhandled loop type" without modifying the
/// IR.
438  LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
439  ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
441  loopLikeOp.getOperation())
442  .Case<scf::ForOp, scf::ForallOp>(
443  [&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
445  loopOp, rewriter, newInitOperands, yieldTiledValuesFn);
446  })
447  .Default([&](auto loopOp) -> FailureOr<LoopLikeOpInterface> {
448  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
449  });
450 }
451 
452 /// Method to add new init values to a loop nest. Updates `loops` in-place with
453 /// new loops that use the `newInitValues`.
454 /// The outer-loops are updated to yield the new result values of the inner
455 /// loop. For the innermost loop, the callback `getNewTiledYieldsFn` is invoked
456 /// to get the additional values to yield from the innermost loop.
/// All loops except the innermost are assumed to be `scf.for`; the innermost
/// loop is rewritten through `yieldTiledValuesAndReplaceLoop`. Returns success
/// without changes when `loops` is empty. If rewriting the innermost loop
/// fails, an error is emitted, but the outer loops have already been replaced.
459  ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
// NOTE(review): `newLoops` is never used.
460  SmallVector<scf::ForOp> newLoops;
461  if (loops.empty())
462  return success();
463  OpBuilder::InsertionGuard g(rewriter);
// NOTE(review): this insertion point appears to be overwritten before any op
// is created (in the loop below and inside `yieldTiledValuesAndReplaceLoop`);
// confirm and consider removing it.
464  rewriter.setInsertionPoint(loops.front());
465 
// NOTE(review): `ivs` is populated below but never read.
466  SmallVector<Value> ivs;
467  for (auto &loop : loops.drop_back()) {
468  rewriter.setInsertionPoint(loop);
469 
470  // if loops.size() > 1 we assume that scf.for is used for the loops.
471  auto forLoop = cast<scf::ForOp>(loop.getOperation());
472 
473  // Create a new loop with the new init values for this loop.
474  SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
475  newInits.append(newInitValues.begin(), newInitValues.end());
476  auto newLoop = rewriter.create<scf::ForOp>(
477  forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
478  forLoop.getStep(), newInits,
479  [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
480 
481  // Move the body of the old loop into the new loop; the old block arguments
// are mapped to the new induction variable and the leading iter_args.
482  SmallVector<Value> sourceBlockArgs;
483  sourceBlockArgs.push_back(newLoop.getInductionVar());
484  auto newRegionIterArgs = newLoop.getRegionIterArgs();
485  sourceBlockArgs.append(
486  newRegionIterArgs.begin(),
487  std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
488  rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
489  rewriter.replaceOp(
490  forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
491  loop = newLoop;
492  ivs.push_back(newLoop.getInductionVar());
// The next (inner) loop receives this loop's new iter_args as its new inits.
493  newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
494  }
495 
496  // Update the loop body of the innermost loop to get new yield values.
497  LoopLikeOpInterface innerMostLoop = loops.back();
498  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
499  yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
500  getNewTiledYieldsFn);
501 
502  if (failed(newInnerMostLoop))
503  return innerMostLoop.emitOpError("failed to return additional yields");
504  loops.back() = newInnerMostLoop.value();
505 
506  // Make all other loops except the innermost loops yield the values returned
507  // by the inner loop.
508  for (auto [outerLoop, innerLoop] :
509  llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
510  // Again assume that all the outer loops are scf.for operations.
511  auto outerForLoop = cast<scf::ForOp>(outerLoop);
512  auto outerLoopYield =
513  cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
514  SmallVector<Value> newYields =
515  llvm::to_vector(outerLoopYield.getOperands());
516  ValueRange additionalYields =
517  innerLoop->getResults().take_back(newInitValues.size());
518  newYields.append(additionalYields.begin(), additionalYields.end());
519  rewriter.setInsertionPoint(outerLoopYield);
520  rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
521  }
522  return success();
523 }
524 
525 /// Implementation of tiling transformation of `op` that implements the
526 /// `TilingInterface` using `scf.for` or `scf.forall` loops (as selected by
/// `options`) to iterate over the tiles.
/// The loop nest is created right after `op`; `op` itself is not replaced
/// here. The returned result carries the replacement values: the results of
/// the outermost loop, or the tiled values directly when no loop is generated
/// (all tile sizes zero).
528 mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
530  OpBuilder::InsertionGuard guard(rewriter);
531  rewriter.setInsertionPointAfter(op);
532 
533  if (!options.tileSizeComputationFunction) {
534  return rewriter.notifyMatchFailure(
535  op, "missing tile size computation function");
536  }
537 
538  // 1. Get the range of the loops that are represented by the operation.
539  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
540  size_t numLoops = iterationDomain.size();
541 
542  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
543  // skips tiling a particular dimension. This convention is significantly
544  // simpler to handle instead of adjusting affine maps to account for missing
545  // dimensions.
546  SmallVector<OpFoldResult> tileSizes =
547  options.tileSizeComputationFunction(rewriter, op);
548  if (tileSizes.size() < iterationDomain.size()) {
549  auto zero = rewriter.getIndexAttr(0);
550  tileSizes.append(numLoops - tileSizes.size(), zero);
551  }
// NOTE(review): if the callback returns more tile sizes than there are loops,
// nothing here truncates or rejects them; the `zip_equal` in the body callback
// and the size asserts in the loop generators would then fire.
552 
553  // 3. If there is an interchange specified, permute the iteration domain and
554  // the tile sizes.
555  SmallVector<int64_t> interchangeVector;
556  if (!options.interchangeVector.empty()) {
557  interchangeVector = fillInterchangeVector(options.interchangeVector,
558  iterationDomain.size());
559  }
560  if (!interchangeVector.empty()) {
561  if (!isPermutationVector(interchangeVector)) {
// NOTE(review): typo "intechange" in the message below (should be
// "interchange").
562  return rewriter.notifyMatchFailure(
563  op, "invalid intechange vector, not a permutation of the entire "
564  "iteration space");
565  }
566 
567  applyPermutationToVector(iterationDomain, interchangeVector);
568  applyPermutationToVector(tileSizes, interchangeVector);
569  }
570 
571  FailureOr<TilingResult> tilingResult;
572  // 4. Define the lambda function used later to generate the body of the
573  // innermost tiled loop.
574  YieldTiledValuesFn innerYieldTiledValuesFn =
575  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
576  ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
579  -> LogicalResult {
580  // 4a. Compute the `offsets` and `sizes` to use for tiling.
581  SmallVector<OpFoldResult> offsets, sizes;
582  {
583  int materializedLoopNum = 0;
584  for (auto [tileSize, loopRange] :
585  llvm::zip_equal(tileSizes, iterationDomain)) {
586  if (isConstantIntValue(tileSize, 0)) {
587  offsets.push_back(loopRange.offset);
588  sizes.push_back(loopRange.size);
589  continue;
590  }
591  Value iv = ivs[materializedLoopNum++];
592  offsets.push_back(iv);
593  sizes.push_back(
594  getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
595  }
596  }
597 
598  // 4b. If interchange was provided, apply inverse of the interchange
599  // to get back the offsets/sizes in the order to be specified.
600  if (!interchangeVector.empty()) {
601  auto inversePermutation = invertPermutationVector(interchangeVector);
604  }
605 
606  // 5. Generate the tiled implementation within the innermost loop.
607 
608  // 5a. Clone the operation within the loop body.
609  auto clonedOp = cast<TilingInterface>(
610  cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));
611 
612  // 5b. Early return cloned op if tiling is not happening. We can not return
613  // the original op because it could lead to
614  // `rewriter.replaceOp(op, op->getResults())` and users would hit a crash.
615  if (llvm::all_of(tileSizes, isZeroIndex)) {
616  tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
617  tilingResult =
618  TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
619  return success();
620  }
621 
622  // 5c. Tile the cloned operation.
623  tilingResult = clonedOp.getTiledImplementation(rewriter, offsets, sizes);
624  if (failed(tilingResult)) {
625  rewriter.eraseOp(clonedOp);
// NOTE(review): typo "faild" in the message below (should be "failed").
626  return op.emitOpError("faild to tile operation");
627  }
628 
629  // 5d. Delete the cloned operation.
630  rewriter.eraseOp(clonedOp);
631 
632  // 5e. Compute the offsets at which the result values are to be inserted
633  // back into its destinations.
634  for (auto [index, tiledValue] :
635  llvm::enumerate(tilingResult->tiledValues)) {
636  tiledResults.push_back(tiledValue);
637  SmallVector<OpFoldResult> resultOffset, resultSize;
638  if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
639  resultOffset, resultSize))) {
// NOTE(review): the loop variable below shadows the captured `op`; the `op`
// used after the loop is the original TilingInterface op again.
640  for (auto op : tilingResult->tiledOps) {
641  rewriter.eraseOp(op);
642  }
643  return rewriter.notifyMatchFailure(
644  op, "failed to get slice of result produced");
645  }
646  resultOffsets.emplace_back(std::move(resultOffset));
647  resultSizes.emplace_back(std::move(resultSize));
648  }
649 
650  return success();
651  };
652 
653  // 6. Find the destination tensors to use for the operation.
654  SmallVector<Value> destinationTensors;
655  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
656  destinationTensors))) {
657  return rewriter.notifyMatchFailure(op,
658  "unable to create destination tensors");
659  }
660 
661  // 7. Generate the tiled loops nest using the callback defined above.
663  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
664  tileSizes, destinationTensors,
665  innerYieldTiledValuesFn, loops)))
666  return op.emitOpError("failed to generate tiling loops");
667  assert(succeeded(tilingResult) &&
668  "expected tiling result to be computed after loop generation");
669 
670  // If loops are empty, the tiled op is used as the replacement for the untiled
671  // op.
672  if (loops.empty()) {
673  return scf::SCFTilingResult{tilingResult->tiledOps, loops,
674  tilingResult->tiledValues};
675  }
676 
677  SmallVector<Value> replacements = llvm::map_to_vector(
678  loops.front()->getResults(), [](OpResult r) -> Value { return r; });
679  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
680 }
681 
684  PartialReductionOpInterface op,
685  ArrayRef<OpFoldResult> tileSizes) {
// Tiles the reduction `op` so that each tile computes a partial result into a
// tensor created by `generateInitialTensorForPartialReduction`. After the loop
// nest, the partial results are combined with `mergeReductions` and `op` is
// replaced with the merged values. Only single-result ops are supported.
686  Location loc = op.getLoc();
687  // Ops implementing PartialReductionOpInterface are expected to implement
688  // TilingInterface.
689  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
690  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
691  auto tileSizesVector = llvm::to_vector(tileSizes);
692  if (tileSizesVector.size() < iterationDomain.size()) {
693  auto zero = b.getIndexAttr(0);
694  tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
695  zero);
696  }
697  if (op->getNumResults() != 1)
698  return b.notifyMatchFailure(
699  op, "don't support ops with multiple results for now");
701  tilingInterfaceOp.getLoopIteratorTypes();
702 
703  SmallVector<int> reductionDims;
704  for (auto [idx, iteratorType] :
705  llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
706  if (iteratorType == utils::IteratorType::reduction)
707  reductionDims.push_back(idx);
708  }
709 
710  // 2. create the initial tensor value.
711  FailureOr<Operation *> identityTensor =
712  op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
713  reductionDims);
714  if (failed(identityTensor))
715  return b.notifyMatchFailure(op,
716  "cannot create a tensor of identity value.");
717 
718  // 3. Define the callback to use for generating the innermost tile loop body.
719  Operation *parallelOp = nullptr;
720  auto innerYieldTiledValuesFn =
721  [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
722  ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
725  -> LogicalResult {
726  SmallVector<OpFoldResult> offsets, sizes;
727  {
728  int materializedLoopNum = 0;
729  for (auto [tileSize, loopRange] :
730  llvm::zip_equal(tileSizesVector, iterationDomain)) {
731  if (isConstantIntValue(tileSize, 0)) {
732  offsets.push_back(loopRange.offset);
733  sizes.push_back(loopRange.size);
734  continue;
735  }
736  Value iv = ivs[materializedLoopNum++];
737  offsets.push_back(iv);
738  sizes.push_back(
739  getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
740  }
741  }
742 
// NOTE(review): the code below uses the outer `b` instead of the `rewriter`
// parameter. They refer to the same rewriter here, since `b` is what is passed
// to `generateLoopNest`, but using `rewriter` would be clearer.
743  // 4a. Clone the operation.
744  auto clonedOp = cast<PartialReductionOpInterface>(
745  cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
746 
747  // 4b. Tile the cloned operation.
// NOTE(review): the result of `tileToPartialReduction` is dereferenced below
// without a null check; confirm the interface never returns null.
748  parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs,
749  offsets, sizes, reductionDims);
750  // 4c. Delete the cloned operation.
751  b.eraseOp(clonedOp);
752 
753  tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
754  // 4d. Compute the offsets and sizes needed to insert the result of the
755  // tiled value back into destination before yielding the destination.
756  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
757  resultOffsets.emplace_back(std::move(outOffsets));
758 
759  SmallVector<OpFoldResult> outSizes;
760  for (size_t i = 0; i < offsets.size(); i++) {
761  outSizes.push_back(
762  tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
763  }
764  resultSizes.emplace_back(std::move(outSizes));
765  return success();
766  };
767 
768  // 5. Generate the tiled implementation using the destination tensors.
769  SmallVector<Value> destinationTensors =
770  llvm::map_to_vector(identityTensor.value()->getResults(),
771  [](OpResult res) -> Value { return res; });
772 
776  if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
777  destinationTensors, innerYieldTiledValuesFn,
778  loops)))
779  return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
780 
// NOTE(review): when every tile size is zero, `generateLoopNest` creates no
// loop, so (assuming `loops` starts empty) `loops.front()` below and
// `*loops.begin()` further down are undefined behavior. Consider rejecting
// that case or handling it explicitly.
781  SmallVector<Value> replacements = llvm::map_to_vector(
782  loops.front()->getResults(), [](OpResult r) -> Value { return r; });
783 
784  // 6. Apply the merge reduction to combine all the partial values.
785  b.setInsertionPointAfter(*loops.begin());
786  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
787  b.replaceOp(op, mergeOp->getResults());
788 
789  SCFReductionTilingResult results;
790  results.initialOp = *identityTensor;
791  results.loops = loops;
792  results.parallelTiledOp = parallelOp;
793  results.mergeOp = mergeOp;
794  return results;
795 }
796 
797 //===----------------------------------------------------------------------===//
798 // tileConsumerAndFuseProducersUsingSCF implementation.
799 //===----------------------------------------------------------------------===//
800 
801 /// Return the untiled producer whose slice is used in a tiled consumer. The
802 /// method traverses the tile loop nest (`loops`) if needed, and returns the
803 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
804 /// indicates that this is a destination operand of the consumer. If there was
805 /// no loop traversal needed, the second value of the returned tuple is empty.
806 static std::tuple<OpResult, std::optional<OpOperand *>>
809  std::optional<OpOperand *> destinationIterArg;
810  auto loopIt = loops.rbegin();
811  while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
812  auto loop = *loopIt;
813  if (iterArg.getOwner()->getParentOp() != loop)
814  break;
815  source = loop.getTiedLoopInit(iterArg);
816  loopIt++;
817  }
818  if (loopIt == loops.rend())
819  destinationIterArg = source;
820  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
821 }
822 
823 /// Implementation of fusing producer of a single slice by computing the
824 /// slice of the producer in-place.
825 std::optional<scf::SCFFuseProducerOfSliceResult>
827  RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
829  // 1. Get the producer of the source (potentially walking through
830  // `iter_args` of nested `scf.for`)
831  auto [fusableProducer, destinationInitArg] =
832  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
833  loops);
834  if (!fusableProducer)
835  return std::nullopt;
836  unsigned resultNumber = fusableProducer.getResultNumber();
837 
838  OpBuilder::InsertionGuard g(rewriter);
839  rewriter.setInsertionPoint(candidateSliceOp);
840 
841  // 2. Clone the fused producer
842  // 2a. Compute the destination operands to use for the cloned operation.
843  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
844  Operation *fusableProducerOp = fusableProducer.getOwner();
845  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
847  rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
848  origDestinationTensors)))
849  return std::nullopt;
850 
851  clonedOpDestinationTensors = origDestinationTensors;
852  if (destinationInitArg &&
853  isa<DestinationStyleOpInterface>(fusableProducerOp)) {
854  // 2b. If the producer is also destination style, then to maintain the
855  // destination passing style, update the destination of the producer to be
856  // the source of the slice.
857  clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
858  }
859  // 2c. Clone the fused producer.
860  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
861  rewriter, fusableProducerOp, clonedOpDestinationTensors);
862  // 2d. Update the source of the candidateSlice to be the cloned producer.
863  // Easier to just clone the slice with different source since replacements
864  // and DCE of cloned ops becomes easier
865  SmallVector<Value> candidateSliceOpOperands =
866  llvm::to_vector(candidateSliceOp->getOperands());
867  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
868  tensor::ExtractSliceOp clonedCandidateSliceOp =
869  mlir::clone(rewriter, candidateSliceOp,
870  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
871 
872  // 3. Generate the tiled implementation of the producer of the source
873  FailureOr<TilingResult> tileAndFuseResult =
875  rewriter, clonedCandidateSliceOp,
876  clonedProducerOp->getResult(resultNumber));
877  if (failed(tileAndFuseResult))
878  return std::nullopt;
879  // Note: Do not delete the candidateSliceOp, since its passed in from the
880  // caller.
881  rewriter.replaceAllUsesWith(candidateSliceOp,
882  tileAndFuseResult->tiledValues[0]);
883  rewriter.eraseOp(clonedCandidateSliceOp);
884  rewriter.eraseOp(clonedProducerOp);
885 
886  // 3. If the slice is for a destination operand, for example,
887  //
888  // ```mlir
889  // %0 = linalg.init
890  // %1 = linalg.fill .. outs(%0 : )
891  // %2 = scf.for .. iter_args(%arg0 = %1) {
892  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
893  // %4 = tensor.extract_slice %arg1 [..]
894  // .. = linalg.matmul .. outs(%4 : )
895  // }
896  // }
897  // ```
898  //
899  // the IR is currently
900  //
901  // ```
902  // %0 = linalg.init
903  // %1 = linalg.fill
904  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
905  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
906  // %4 = tensor.extract_slice %arg1[..]
907  // %5 = linalg.fill .. outs(%4 : )
908  // .. = linalg.matmul .. outs(%5 : )
909  // }
910  // }
911  // ```
912  //
913  // The untiled `linalg.fill` is still used as the `init_value` since it
914  // was originally a destination operand of the untiled `linalg.matmul`.
915  // When fusing an operand that is a destination operand, the iter_arg of
916  // the outer most loop should be changed to use the destination of the
917  // fused operation. With this the IR will be.
918  //
919  // ```
920  // %0 = linalg.init
921  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
922  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
923  // %3 = tensor.extract_slice %arg1[..]
924  // %4 = linalg.fill .. outs(%3 : )
925  // .. = linalg.matmul .. outs(%4 : )
926  // }
927  // }
928  // ```
929  if (destinationInitArg &&
930  isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
931  loops.front()
932  ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
933  .set(origDestinationTensors[resultNumber]);
934  }
935  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
936  tileAndFuseResult->tiledValues[0],
937  tileAndFuseResult->tiledOps};
938 }
939 
940 /// Reconstruct the fused producer from within the tiled-and-fused code.
942  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
943  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
945  if (loops.empty())
946  return success();
947 
948  OpResult fusableProducer = fusedProducerInfo.origProducer;
949  Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
951  rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
952  if (succeeded(initValue)) {
953 
954  YieldTiledValuesFn newYieldValuesFn =
955  [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
956  ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
959  -> LogicalResult {
960  OpBuilder::InsertionGuard g(innerRewriter);
961  if (auto tiledDestStyleOp =
962  tiledAndFusedProducer
963  .getDefiningOp<DestinationStyleOpInterface>()) {
964  rewriter.setInsertionPoint(tiledDestStyleOp);
965  Value newRegionArg = newRegionIterArgs.back();
966  auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
967  sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
968  sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
969  unsigned resultNumber = fusableProducer.getResultNumber();
970  rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
971  tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
972  });
973  }
974  Block *block = rewriter.getInsertionPoint()->getBlock();
975  rewriter.setInsertionPoint(block->getTerminator());
976  tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
977  tiledOffset.emplace_back(sliceOp.getMixedOffsets());
978  tiledSizes.emplace_back(sliceOp.getMixedSizes());
979  return success();
980  };
981 
982  return addInitOperandsToLoopNest(rewriter, loops,
983  SmallVector<Value>{initValue.value()},
984  newYieldValuesFn);
985  }
986  return success();
987 }
988 
989 /// Implementation of tile consumer and fuse producer greedily.
992  RewriterBase &rewriter, TilingInterface consumer,
994  // This transformation is only valid for ops that return values (i.e. not
995  // valid to use with operations that have memref operands).
996  if (!consumer->getNumResults()) {
997  return rewriter.notifyMatchFailure(
998  consumer, "invalid pattern for op with no results");
999  }
1000 
1001  // 1. First tile the consumer.
1002  SetVector<Operation *> fusedProducers, tiledAndFusedOps;
1003  llvm::SmallDenseMap<Value, size_t> origProducerToLoopResultNum;
1004 
1005  FailureOr<scf::SCFTilingResult> tilingResult =
1006  tileUsingSCF(rewriter, consumer, options.tilingOptions);
1007 
1008  if (failed(tilingResult))
1009  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
1010  for (auto *tiledOp : tilingResult->tiledOps)
1011  tiledAndFusedOps.insert(tiledOp);
1012 
1013  // If there are no loops generated, fusion is immaterial.
1014  auto &loops = tilingResult->loops;
1015  if (loops.empty()) {
1016  DenseMap<Value, Value> replacements;
1017  for (auto [origVal, replacement] :
1018  llvm::zip_equal(consumer->getResults(), tilingResult->replacements)) {
1019  replacements[origVal] = replacement;
1020  }
1021  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1022  replacements};
1023  }
1024 
1025  // To keep track of replacements for now just record the map from the original
1026  // untiled value to the result number of the for loop. Since the loop gets
1027  // potentially replaced during fusion, keeping the value directly wont work.
1028  DenseMap<Value, size_t> origValToResultNumber;
1029  for (auto [index, result] : llvm::enumerate(consumer->getResults())) {
1030  origValToResultNumber[result] = index;
1031  }
1032 
1033  // 2. Typically, the operands of the tiled operation are slices of the
1034  // operands of the untiled operation. These are expressed in IR using
1035  // `tensor.extract_slice` operations with source being the operands of the
1036  // untiled operation. Create a worklist of these `tensor.extract_slice`
1037  // operations. If the producers of the source of the `tensor.extract_slice`
1038  // can be tiled such that the tiled value is generated in-place, that
1039  // effectively tiles + fuses the operations.
1040  auto addCandidateSlices = [](Operation *fusedOp,
1041  std::deque<tensor::ExtractSliceOp> &candidates) {
1042  for (Value operand : fusedOp->getOperands())
1043  if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
1044  candidates.push_back(sliceOp);
1045  };
1046 
1047  std::deque<tensor::ExtractSliceOp> candidates;
1048  addCandidateSlices(tiledAndFusedOps.back(), candidates);
1049  OpBuilder::InsertionGuard g(rewriter);
1050  while (!candidates.empty()) {
1051  // Traverse the slices in BFS fashion.
1052  tensor::ExtractSliceOp candidateSliceOp = candidates.front();
1053  candidates.pop_front();
1054 
1055  // Find the original producer of the slice.
1056  auto [fusableProducer, destinationInitArg] =
1057  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
1058  loops);
1059  if (!fusableProducer)
1060  continue;
1061 
1062  auto [fuseSlice, yieldReplacement] = options.fusionControlFn(
1063  candidateSliceOp, fusableProducer, destinationInitArg.has_value());
1064  if (!fuseSlice)
1065  continue;
1066 
1067  // The operands of the fused producer might themselved be slices of
1068  // values produced by operations that implement the `TilingInterface`.
1069  // Add these operations to the worklist.
1070  std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
1071  tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops);
1072  if (!fusedResult)
1073  continue;
1074 
1075  if (yieldReplacement) {
1077  rewriter, candidateSliceOp, fusedResult.value(), loops))) {
1078  return rewriter.notifyMatchFailure(
1079  fusableProducer.getOwner(), "failed to replacement value for this "
1080  "oepration from within the tiled loop");
1081  }
1082  origValToResultNumber[fusableProducer] =
1083  loops.front()->getNumResults() - 1;
1084  }
1085 
1086  if (Operation *tiledAndFusedOp =
1087  fusedResult->tiledAndFusedProducer.getDefiningOp()) {
1088  fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
1089  tiledAndFusedOps.insert(tiledAndFusedOp);
1090  addCandidateSlices(tiledAndFusedOp, candidates);
1091  }
1092  }
1093 
1094  DenseMap<Value, Value> replacements;
1095  for (auto [origVal, resultNumber] : origValToResultNumber) {
1096  replacements[origVal] = loops.front()->getResult(resultNumber);
1097  }
1098 
1099  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
1100  replacements};
1101 }
1102 
1103 //===----------------------------------------------------------------------===//
1104 // lowerToLoopsUsingSCFForOp implementation.
1105 //===----------------------------------------------------------------------===//
1106 
1109  TilingInterface op) {
1110  // TODO: Handle cases where the op has results if needed.
1111  if (op->getNumResults() > 0) {
1112  return rewriter.notifyMatchFailure(
1113  op, "unable to lower to loops operations with return values");
1114  }
1115 
1116  SmallVector<Range> domain = op.getIterationDomain(rewriter);
1117  SmallVector<Value> ivs;
1119  Location loc = op.getLoc();
1120  for (auto loopRange : domain) {
1121  Value offsetVal =
1122  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
1123  Value sizeVal =
1124  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
1125  Value strideVal =
1126  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
1127  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
1128  strideVal, ValueRange{});
1129  loops.push_back(loop);
1130  ivs.push_back(loop.getInductionVar());
1131  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
1132  }
1133  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
1134  return failure();
1135  }
1136  return loops;
1137 }
static llvm::ManagedStatic< PassManagerOptions > options
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
std::function< LogicalResult(RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs, SmallVector< Value > &tiledValues, SmallVector< SmallVector< OpFoldResult > > &resultOffsets, SmallVector< SmallVector< OpFoldResult > > &resultSizes)> YieldTiledValuesFn
A function that allows returning additional yielded values during yieldTiledValuesAndReplace.
static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using the loop construct specified in options.
FailureOr< LoopLikeOpInterface > yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter, ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn)
Append the specified additional newInitOperands operands to the loops existing init operands (or simi...
static LogicalResult generateLoopNestUsingForOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.for operation.
static LogicalResult generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, ArrayRef< Attribute > mappingVector, ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn, SmallVector< LoopLikeOpInterface > &loops)
Generate the tile-loop nest using scf.forall operation.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< LoopLikeOpInterface > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static bool tileDividesIterationDomain(Range loopRange)
static Operation * cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter, Operation *op, ValueRange newDestArgs)
Clones the operation and updates the destination if the operation implements the DestinationStyleOpIn...
static LogicalResult addInitOperandsToLoopNest(RewriterBase &rewriter, MutableArrayRef< LoopLikeOpInterface > loops, ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn)
Method to add new init values to a loop nest.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, Value iv, OpFoldResult tileSize)
Returns the bounded tile size given the current iv, loopRange and tileSize, i.e., min(tileSize,...
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgListType getArguments()
Definition: Block.h:84
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:553
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:414
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:263
This is a value defined by a result of an operation.
Definition: Value.h:453
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:462
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:465
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:408
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_iterator result_end()
Definition: Operation.h:409
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1294
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SCFTilingResult > tileUsingSCF(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
LogicalResult yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< LoopLikeOpInterface > loops)
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< LoopLikeOpInterface > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Pattern to swap a tensor.extract_slice with its producer when the producer implements the TilingInte...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:51
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:70
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:105
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute value 0.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:349
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:753
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:363
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Container for result values of tiling.
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Transformation information returned after reduction tiling.
Operation * parallelTiledOp
The partial reduction tiled op generated.
Operation * mergeOp
The final reduction operation merging all the partial reductions.
SmallVector< LoopLikeOpInterface > loops
The loop operations that iterate over the tiles.
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > ts)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
Transformation information returned after tiling.