MLIR  18.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
4 //===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/Support/Debug.h"
27 #include <optional>
28 
29 #define DEBUG_TYPE "tile-using-interface"
30 
31 using namespace mlir;
32 
35  assert(!tileSizeComputationFunction && "tile sizes already set");
36  auto tileSizes = llvm::to_vector(ts);
37  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
38  return tileSizes;
39  };
40  return *this;
41 }
42 
43 /// Helper method to adjust the interchange vector to match the iteration
44 /// domain.
47  size_t iterationDomainSize) {
48  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
49  if (filledVector.size() < iterationDomainSize) {
50  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
51  filledVector.append(range.begin(), range.end());
52  }
53  if (filledVector.size() > iterationDomainSize)
54  filledVector.resize(iterationDomainSize);
55  return filledVector;
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // tileUsingSCFForOp implementation.
60 //===----------------------------------------------------------------------===//
61 
62 // Check if `stride` evenly divides the trip count `size - offset`.
63 static bool tileDividesIterationDomain(Range loopRange) {
64  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
65  if (!offsetAsInt)
66  return false;
67  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
68  if (!sizeAsInt)
69  return false;
70  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
71  if (!strideAsInt)
72  return false;
73  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
74 }
75 
76 /// Returns the bounded tile size given the current `iv`, `loopRange` and
77 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
79  Range loopRange, Value iv,
80  Value tileSize) {
81  std::optional<int64_t> ts = getConstantIntValue(tileSize);
82  if (ts && ts.value() == 1)
83  return getAsOpFoldResult(tileSize);
84 
86  Range{loopRange.offset, loopRange.size, tileSize}))
87  return tileSize;
88 
89  // The tile size to use (to avoid out of bounds access) is minimum of
90  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
91  // loop.
92  AffineExpr s0, s1, d0;
93  bindDims(b.getContext(), d0);
94  bindSymbols(b.getContext(), s0, s1);
95  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
96  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
98  b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
99 }
100 
101 /// Generate an empty loop nest that represents the tiled loop nest shell.
102 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
103 /// - `tileSizes` is the tile sizes to use. A zero tile size leaves that
104 /// dimension untiled (no loop is generated for it).
105 /// - In `offsets` and `sizes`, return the multi-dimensional offset and size
106 /// of the tile processed within the innermost loop.
108  OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
110  SmallVector<OpFoldResult> &sizes) {
111  assert(!loopRanges.empty() && "expected at least one loop range");
112  assert(loopRanges.size() == tileSizes.size() &&
113  "expected as many tile sizes as loop ranges");
114  OpBuilder::InsertionGuard guard(builder);
116  offsets.resize(loopRanges.size());
117  sizes.resize(loopRanges.size());
118 
119  for (auto loopRange : llvm::enumerate(loopRanges)) {
120  Value offset =
121  getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
122  Value size =
123  getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
125  builder, loc, tileSizes[loopRange.index()]);
126  // No loops if tile size is zero. Set offset and size to the loop
127  // offset and size.
128  if (matchPattern(tileSize, m_Zero())) {
129  offsets[loopRange.index()] = offset;
130  sizes[loopRange.index()] = size;
131  continue;
132  }
133 
134  auto loop = builder.create<scf::ForOp>(
135  loc, offset, size, tileSize, ValueRange{},
136  [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
137  ValueRange /*iterArgs*/) {
138  sizes[loopRange.index()] = getBoundedTileSize(
139  bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
140  builder.create<scf::YieldOp>(loc);
141  });
142  offsets[loopRange.index()] = loop.getInductionVar();
143  loops.push_back(loop);
144  builder.setInsertionPoint(loop.getBody()->getTerminator());
145  }
146  return loops;
147 }
148 
149 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
150 /// construct the destructive update pattern that inserts the yielded
151 /// value into a destination tensor provided by `initValue` at offset
152 /// `tileOffsets` and size `tileSizes`. For example,
153 ///
154 /// ```mlir
155 /// scf.for %iv0 = ... {
156 /// %0 = tiled_op
157 /// }
158 /// ```
159 ///
160 /// is transformed to
161 ///
162 /// ```mlir
163 /// scf.for %iv0 = ... iter_args(%arg = %0) {
164 /// %1 = tensor.extract_slice %arg
165 /// %2 = tiled_op
166 /// %3 = tensor.insert_slice %2 into %arg
167 /// scf.yield %3
168 /// }
169 /// ```
170 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
171 static SmallVector<Value>
173  ValueRange yieldedValues,
174  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
175  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
177  NewYieldValueFn yieldValueFn =
178  [&](OpBuilder &b, Location loc,
180  SmallVector<Value> inserts;
181  for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
182  ArrayRef<OpFoldResult> tileOffsets =
183  tileOffsetsList[yieldedValue.index()];
184  ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
185  SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
186  b.getIndexAttr(1));
187  Value insert = b.create<tensor::InsertSliceOp>(
188  loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
189  tileOffsets, tileSizes, tileStrides);
190  inserts.push_back(insert);
191  }
192  return inserts;
193  };
194 
195  SmallVector<scf::ForOp> newLoops =
196  replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
197  /*replaceIterOperandsUsesInLoop =*/false);
198  for (const auto &loop : llvm::enumerate(loops)) {
199  rewriter.eraseOp(loop.value());
200  loops[loop.index()] = newLoops[loop.index()];
201  }
202  return llvm::to_vector(llvm::map_range(
203  loops.front().getResults().take_back(yieldedValues.size()),
204  [](OpResult r) -> Value { return r; }));
205 }
206 
207 /// If the tiled operation is destination passing style, update the
208 /// slice of the destination used (which refers to the untiled destination)
209 /// to use the corresponding region argument of the innermost loop.
210 ///
211 /// ```mlir
212 /// %0 =
213 /// scf.for %iv0 = ... iter_args(%arg = %0) {
214 /// %1 = tensor.extract_slice %0
215 /// %2 = tiled_op
216 /// %3 = tensor.insert_slice %2 into %arg
217 /// scf.yield %3
218 /// }
219 /// ```
220 ///
221 /// is transformed to
222 ///
223 /// ```mlir
224 /// scf.for %iv0 = ... iter_args(%arg = %0) {
225 /// %1 = tensor.extract_slice %arg
226 /// %2 = tiled_op
227 /// %3 = tensor.insert_slice %2 into %arg
228 /// scf.yield %3
229 /// }
230 /// ```
231 static void
233  ValueRange tiledOpDestinationValues,
234  ValueRange bbArgsList) {
235  for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
236  auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
237  if (!sliceOp)
238  continue;
239  sliceOp.setOperand(0, bbArgsList[destValue.index()]);
240  }
241 }
242 
243 /// Helper method to yield the values of the tiled op, as well as
244 /// update the destination operands of the tiled op, if it is
245 /// a destination passing style op.
246 static SmallVector<Value>
248  TilingResult tilingResult,
249  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
250  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
252  SmallVector<Value> replacements =
253  yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
254  tileOffsetsList, tileSizesList, loops);
255  for (auto tiledOp : tilingResult.tiledOps) {
256  if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
257  auto innerMostLoop = loops.back();
258  SmallVector<Value> tiledOpDestinationTensors =
259  llvm::to_vector(dstOp.getDpsInits());
260  updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
261  innerMostLoop.getRegionIterArgs());
262  }
263  }
264  return replacements;
265 }
266 
267 /// Implementation of tiling transformation of `op` that implements the
268 /// `TilingInterface` using `scf.for` to iterate over the tiles.
/// Fails (through `notifyMatchFailure`) in the following cases:
///  - no tile-size computation function is set;
///  - the iteration domain is empty;
///  - the interchange vector is not a permutation;
///  - a result tile position cannot be computed;
///  - the destination tensors cannot be obtained.
/// Returns right after generating the tiled implementation when the op has no
/// results, or when no loops were generated.
270 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
272  OpBuilder::InsertionGuard guard(rewriter);
273  rewriter.setInsertionPointAfter(op);
274 
275  if (!options.tileSizeComputationFunction) {
276  return rewriter.notifyMatchFailure(
277  op, "missing tile size computation function");
278  }
279 
280  // 1. Get the range of the loops that are represented by the operation.
281  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
282  size_t numLoops = iterationDomain.size();
283  if (numLoops == 0) {
284  return rewriter.notifyMatchFailure(
285  op, "unable to tile op with no iteration domain");
286  }
287 
288  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
289  // skips tiling a particular dimension. This convention is significantly
290  // simpler to handle instead of adjusting affine maps to account for missing
291  // dimensions.
292  SmallVector<OpFoldResult> tileSizeVector =
293  options.tileSizeComputationFunction(rewriter, op);
  // Missing trailing tile sizes are padded with 0 (untiled).
  // NOTE(review): a vector *longer* than `numLoops` is not truncated here,
  // while `generateTileLoopNest` asserts one tile size per loop range.
  // Consider truncating to `numLoops`.
294  if (tileSizeVector.size() < iterationDomain.size()) {
295  auto zero = rewriter.getIndexAttr(0);
296  tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
297  }
298 
299  scf::SCFTilingResult tilingResult;
300  SmallVector<OpFoldResult> offsets, sizes;
301  {
302  // If there is an interchange specified, permute the iteration domain and
303  // the tile sizes.
304  SmallVector<int64_t> interchangeVector;
305  if (!options.interchangeVector.empty()) {
306  interchangeVector = fillInterchangeVector(options.interchangeVector,
307  iterationDomain.size());
308  }
309  if (!interchangeVector.empty()) {
310  if (!isPermutationVector(interchangeVector)) {
  // NOTE(review): "intechange" in the diagnostic below is a typo for
  // "interchange" (a runtime string, left unchanged here).
311  return rewriter.notifyMatchFailure(
312  op, "invalid intechange vector, not a permutation of the entire "
313  "iteration space");
314  }
315 
316  applyPermutationToVector(iterationDomain, interchangeVector);
317  applyPermutationToVector(tileSizeVector, interchangeVector);
318  }
319 
320  // 3. Materialize an empty loop nest that iterates over the tiles. These
321  // loops for now do not return any values even if the original operation has
322  // results.
323  tilingResult.loops = generateTileLoopNest(
324  rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
325 
326  if (!interchangeVector.empty()) {
327  auto inversePermutation = invertPermutationVector(interchangeVector);
330  }
331  }
332 
333  LLVM_DEBUG({
334  if (!tilingResult.loops.empty()) {
335  llvm::dbgs() << "LoopNest shell :\n";
336  tilingResult.loops.front().dump();
337  llvm::dbgs() << "\n";
338  }
339  });
340 
341  // 4. Generate the tiled implementation within the inner most loop.
342  if (!tilingResult.loops.empty())
343  rewriter.setInsertionPoint(
344  tilingResult.loops.back().getBody()->getTerminator());
345  FailureOr<TilingResult> tiledImplementation =
346  op.getTiledImplementation(rewriter, offsets, sizes);
  // NOTE(review): `tiledImplementation` is a FailureOr but is dereferenced
  // below (`->tiledOps`, `->tiledValues`, `.value()`) without a `failed()`
  // check. A failing `getTiledImplementation` would hit an assertion / UB
  // here. Consider returning `rewriter.notifyMatchFailure(op, ...)` first.
347  tilingResult.tiledOps.append(tiledImplementation->tiledOps);
348  if (op->getNumResults() == 0) {
349  // nothing more to do.
350  return tilingResult;
351  }
352 
353  // If loops are empty, the tiled op is used as the replacement for the untiled
354  // op.
355  if (tilingResult.loops.empty()) {
356  tilingResult.replacements = tiledImplementation->tiledValues;
357  return tilingResult;
358  }
359 
360  // 5. Yield all the results of the tiled operation. The surrounding loop
361  // nest is modified to insert a destructive update pattern to yield
362  // from the loop nest values to replace the untiled op with.
  // NOTE(review): the two failure returns below happen after the loop nest
  // and the tiled op have already been created, so the IR is left modified
  // when they trigger.
363  int64_t numResults = op->getNumResults();
364  SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
365  resultSizesList(numResults);
366  for (const auto &result : llvm::enumerate(op->getResults())) {
367  if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
368  sizes,
369  resultOffsetsList[result.index()],
370  resultSizesList[result.index()]))) {
371  return rewriter.notifyMatchFailure(
372  op, "failed to get slice of result produced");
373  }
374  }
375 
376  SmallVector<Value> destinationTensors;
377  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
378  destinationTensors)))
379  return rewriter.notifyMatchFailure(op, "failed to get destinations");
380 
381  tilingResult.replacements = yieldTiledValues(
382  rewriter, destinationTensors, tiledImplementation.value(),
383  resultOffsetsList, resultSizesList, tilingResult.loops);
384 
385  LLVM_DEBUG({
386  if (!tilingResult.loops.empty()) {
387  llvm::dbgs() << "After tiled implementation :\n";
388  tilingResult.loops.front().dump();
389  llvm::dbgs() << "\n";
390  }
391  });
392  return tilingResult;
393 }
394 
397  PartialReductionOpInterface op,
398  ArrayRef<OpFoldResult> tileSizes) {
  // Tiles the reduction `op` by `tileSizes`; missing trailing sizes are
  // padded with 0, i.e. untiled. The four numbered steps below are:
  //  1. create the initial (identity) tensor;
  //  2. build the loop nest;
  //  3. emit the partial reduction inside the innermost loop;
  //  4. merge the partial results after the nest, replacing `op` with the
  //     merge op's results.
399  Location loc = op.getLoc();
400  // Ops implementing PartialReductionOpInterface are expected to implement
401  // TilingInterface.
402  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
403  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
404  auto tileSizesVector = llvm::to_vector(tileSizes);
405  if (tileSizesVector.size() < iterationDomain.size()) {
406  auto zero = b.getIndexAttr(0);
407  tileSizesVector.append(iterationDomain.size() - tileSizesVector.size(),
408  zero);
409  }
  // NOTE(review): the check is `!= 1`, so ops with zero results are rejected
  // as well, although the message only mentions multiple results.
410  if (op->getNumResults() != 1)
411  return b.notifyMatchFailure(
412  op, "don't support ops with multiple results for now");
414  tilingInterfaceOp.getLoopIteratorTypes();
415 
  // Collect the positions of the reduction loops of the iteration domain.
416  SmallVector<int> reductionDims;
417  for (auto [idx, iteratorType] :
418  llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
419  if (iteratorType == utils::IteratorType::reduction)
420  reductionDims.push_back(idx);
421  }
422 
423  // 1. create the initial tensor value.
424  FailureOr<Operation *> identityTensor =
425  op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
426  reductionDims);
427  if (failed(identityTensor))
428  return b.notifyMatchFailure(op,
429  "cannot create a tensor of identity value.");
430  // 2. Create the nested loops.
431  SmallVector<OpFoldResult> offsets, sizes;
433  b, loc, iterationDomain, tileSizesVector, offsets, sizes);
434 
  // NOTE(review): the argument list above matches `generateTileLoopNest`,
  // which emits no loop for zero tile sizes. If every tile size is zero,
  // `loops` would be empty, making `loops.back()` and `*loops.begin()` below
  // UB. TODO: confirm and bail out early in that case.
435  // 3. Generate the tiled implementation within the inner most loop.
436  b.setInsertionPoint(loops.back().getBody()->getTerminator());
437  Operation *parallelOp = op.tileToPartialReduction(
438  b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
  // NOTE(review): `parallelOp` is dereferenced below without a null check.
  // Confirm that `tileToPartialReduction` never returns null.
439 
  // The sizes of result 0 of `parallelOp` are queried for `offsets.size()`
  // dimensions. This assumes the partial result has one dimension per loop of
  // the iteration domain. TODO confirm.
440  SmallVector<OpFoldResult> resultSizesList;
441  for (size_t i = 0; i < offsets.size(); i++)
442  resultSizesList.push_back(
443  tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
444  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
445  SmallVector<Value> replacements = yieldTiledValues(
446  b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
447  resultSizesList, loops);
448 
  // Make the partial op's destination slices read the innermost loop's
  // region iter_args instead of the initial tensor.
449  auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
450  auto innerMostLoop = loops.back();
451  SmallVector<Value> destinationTensors = llvm::to_vector(dstOp.getDpsInits());
452  assert(destinationTensors.size() ==
453  innerMostLoop.getRegionIterArgs().size() &&
454  "unexpected number of outputs");
455  updateDestinationOperandsForTiledOp(b, destinationTensors,
456  innerMostLoop.getRegionIterArgs());
457 
458  // 4. Apply the merge reduction to combine all the partial values.
459  b.setInsertionPointAfter(*loops.begin());
460  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
461  b.replaceOp(op, mergeOp->getResults());
462 
463  SCFReductionTilingResult results;
464  results.initialOp = *identityTensor;
465  results.loops = std::move(loops);
466  results.parallelTiledOp = parallelOp;
467  results.mergeOp = mergeOp;
468  return results;
469 }
470 //===----------------------------------------------------------------------===//
471 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
472 //===----------------------------------------------------------------------===//
473 
474 /// Return the untiled producer whose slice is used in a tiled consumer. The
475 /// method traverses the tile loop nest (`loops`) if needed, and returns the
476 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
477 /// indicates that this is a destination operand of the consumer. If there was
478 /// no loop traversal needed, the second value of the returned tuple is empty.
479 static std::tuple<OpResult, std::optional<OpOperand *>>
481  ArrayRef<scf::ForOp> loops) {
482  std::optional<OpOperand *> destinationIterArg;
483  auto loopIt = loops.rbegin();
484  while (auto iterArg = dyn_cast<BlockArgument>(source->get())) {
485  scf::ForOp loop = *loopIt;
486  if (iterArg.getOwner()->getParentOp() != loop)
487  break;
488  source = &loop.getOpOperandForRegionIterArg(iterArg);
489  loopIt++;
490  }
491  if (loopIt == loops.rend())
492  destinationIterArg = source;
493  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
494 }
495 
496 /// Implementation of fusing producer of a single slice by computing the
497 /// slice of the producer in-place.
498 std::optional<scf::SCFFuseProducerOfSliceResult>
500  tensor::ExtractSliceOp candidateSliceOp,
502  // 1. Get the producer of the source (potentially walking through
503  // `iter_args` of nested `scf.for`)
504  auto [fusableProducer, destinationInitArg] =
505  getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
506  loops);
507  if (!fusableProducer)
508  return std::nullopt;
509 
510  // 2. Generate the tiled implementation of the producer of the source
511  OpBuilder::InsertionGuard g(rewriter);
512  rewriter.setInsertionPoint(candidateSliceOp);
513  FailureOr<TilingResult> tileAndFuseResult =
514  tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
515  fusableProducer);
516  if (failed(tileAndFuseResult))
517  return std::nullopt;
518  rewriter.replaceAllUsesWith(candidateSliceOp,
519  tileAndFuseResult->tiledValues[0]);
520 
521  // 3. If the slice is for a destination operand, for example,
522  //
523  // ```mlir
524  // %0 = linalg.init
525  // %1 = linalg.fill .. outs(%0 : )
526  // %2 = scf.for .. iter_args(%arg0 = %1) {
527  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
528  // %4 = tensor.extract_slice %arg1 [..]
529  // .. = linalg.matmul .. outs(%4 : )
530  // }
531  // }
532  // ```
533  //
534  // the IR is currently
535  //
536  // ```
537  // %0 = linalg.init
538  // %1 = linalg.fill
539  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
540  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
541  // %4 = tensor.extract_slice %0 /*incorrect value */ [..]
542  // %5 = linalg.fill .. outs(%4 : )
543  // .. = linalg.matmul .. outs(%5 : )
544  // }
545  // }
546  // ```
547  //
548  // The untiled `linalg.fill` is still used as the `init_value` since it
549  // was originally a destination operand of the untiled `linalg.matmul`.
550  // When fusing an operand that is a destination operand.
551  // - Update the iter_arg of the outer most loop to use the destination
552  // of the untiled producer.
553  // - Update the destination of the slice of the tiled producer generated
554  // to use the same basic block argument as the slice that was used to
555  // generate inplace the tiled implementation of the producer.
556  // With this the IR will be.
557  //
558  // ```
559  // %0 = linalg.init
560  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
561  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
562  // %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
563  // %4 = linalg.fill .. outs(%3 : )
564  // .. = linalg.matmul .. outs(%4 : )
565  // }
566  // }
567  // ```
568  // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
569  // Update to use that when it does become available.
570  scf::ForOp outerMostLoop = loops.front();
571  if (destinationInitArg &&
572  (*destinationInitArg)->getOwner() == outerMostLoop) {
573  unsigned iterArgNumber =
574  outerMostLoop.getResultForOpOperand(**destinationInitArg)
575  .getResultNumber();
576  int64_t resultNumber = fusableProducer.getResultNumber();
577  if (auto dstOp =
578  dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
579  (*destinationInitArg)
580  ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
581  }
582  for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
583  auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
584  if (!dstOp)
585  continue;
586  scf::ForOp innerMostLoop = loops.back();
588  rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
589  innerMostLoop.getRegionIterArgs()[iterArgNumber]);
590  }
591  }
592  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
593  tileAndFuseResult->tiledValues[0],
594  tileAndFuseResult->tiledOps};
595 }
596 
597 /// Reconstruct the fused producer from within the tiled-and-fused code.
/// If `initValue` (computed from `fusableProducer`; presumably a destination
/// tensor for it, TODO confirm) is available, `yieldTiledValues` inserts
/// `fusedProducerValue` into it at the offsets/sizes of `sliceOp`. It then
/// yields the result as an additional value of the loop nest `loops`. When
/// `initValue` cannot be computed, no new value is yielded.
/// For each destination-style tiled-and-fused op, the trailing call is then
/// made with that op's DPS init for the producer's result number and the last
/// region iter_arg of the innermost loop.
599  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
600  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
602  auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
603  fusedProducerInfo;
  // NOTE(review): `initValues` is never used in this function.
604  SmallVector<Value> initValues;
606  rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
607  if (succeeded(initValue)) {
608  SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
609  SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
  // NOTE(review): `yieldedVals` (the new loop-nest results) is discarded, so
  // callers cannot obtain the replacement value from this function.
610  SmallVector<Value> yieldedVals =
611  yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
612  resultOffsets, resultSizes, loops);
613  }
  // NOTE(review): this loop runs even when `initValue` failed. In that case
  // no iter_arg was appended above, so `getRegionIterArgs().back()` is
  // whichever iter_arg was already last. It looks like this should only run
  // inside the `succeeded(initValue)` branch. TODO confirm.
  // `loops.back()` also requires `loops` to be non-empty.
614  for (auto tileAndFusedOp : tileAndFusedOps) {
615  auto dstStyleProducer =
616  dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
617  if (!dstStyleProducer)
618  continue;
619  Value dstValue =
620  dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
621  ->get();
623  rewriter, dstValue, loops.back().getRegionIterArgs().back());
624  }
625 }
626 
627 /// Implementation of tile consumer and fuse producer greedily.
630  RewriterBase &rewriter, TilingInterface consumer,
632  // This transformation is only valid for ops that return values (i.e. not
633  // valid to use with operations that have memref operands).
634  if (!consumer->getNumResults()) {
635  return rewriter.notifyMatchFailure(
636  consumer, "invalid pattern for op with no results");
637  }
638 
639  // 1. First tile the consumer.
640  scf::SCFTileAndFuseResult tileAndFuseResult;
641  llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
642  {
643  FailureOr<scf::SCFTilingResult> tilingResult =
644  tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
645  if (failed(tilingResult))
646  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
647  for (auto *tiledOp : tilingResult->tiledOps)
648  tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
649  tileAndFuseResult.loops = std::move(tilingResult->loops);
650  for (const auto &result : llvm::enumerate(
651  llvm::zip(consumer->getResults(), tilingResult->replacements))) {
652  tileAndFuseResult.replacements[std::get<0>(result.value())] =
653  std::get<1>(result.value());
654  yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
655  result.index())] = result.index();
656  }
657  }
658 
659  // If there are no loops generated, fusion is immaterial.
660  if (tileAndFuseResult.loops.empty())
661  return tileAndFuseResult;
662 
663  // 2. Typically, the operands of the tiled operation are slices of the
664  // operands of the untiled operation. These are expressed in IR using
665  // `tensor.extract_slice` operations with source being the operands of the
666  // untiled operation. Create a worklist of these `tensor.extract_slice`
667  // operations. If the producers of the source of the `tensor.extract_slice`
668  // can be tiled such that the tiled value is generated in-place, that
669  // effectively tiles + fuses the operations.
670  auto addCandidateSlices = [](Operation *fusedOp,
671  std::deque<tensor::ExtractSliceOp> &candidates) {
672  for (Value operand : fusedOp->getOperands())
673  if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
674  candidates.push_back(sliceOp);
675  };
676 
677  std::deque<tensor::ExtractSliceOp> candidates;
678  addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
679  OpBuilder::InsertionGuard g(rewriter);
680  while (!candidates.empty()) {
681  // Traverse the slices in BFS fashion.
682  tensor::ExtractSliceOp candidateSliceOp = candidates.front();
683  candidates.pop_front();
684 
685  // The operands of the fused producer might themselved be slices of
686  // values produced by operations that implement the `TilingInterface`.
687  // Add these operations to the worklist.
688  std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
689  tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
690  tileAndFuseResult.loops);
691  if (!fusedProducer)
692  continue;
693 
694  if (Operation *tiledAndFusedOp =
695  fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
696  tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
697  addCandidateSlices(tiledAndFusedOp, candidates);
698  }
699  }
700  return tileAndFuseResult;
701 }
702 
703 //===----------------------------------------------------------------------===//
704 // lowerToLoopsUsingSCFForOp implementation.
705 //===----------------------------------------------------------------------===//
706 
709  TilingInterface op) {
710  // TODO: Handle cases where the op has results if needed.
711  if (op->getNumResults() > 0) {
712  return rewriter.notifyMatchFailure(
713  op, "unable to lower to loops operations with return values");
714  }
715 
716  SmallVector<Range> domain = op.getIterationDomain(rewriter);
717  SmallVector<Value> ivs;
719  Location loc = op.getLoc();
720  for (auto loopRange : domain) {
721  Value offsetVal =
722  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
723  Value sizeVal =
724  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
725  Value strideVal =
726  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
727  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
728  strideVal, ValueRange{});
729  loops.push_back(loop);
730  ivs.push_back(loop.getInductionVar());
731  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
732  }
733  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
734  return failure();
735  }
736  return loops;
737 }
static llvm::ManagedStatic< PassManagerOptions > options
static void updateDestinationOperandsForTiledOp(OpBuilder &builder, ValueRange tiledOpDestinationValues, ValueRange bbArgsList)
If the tiled operation is destination passing style, update the slice of the destination used (which ...
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static SmallVector< scf::ForOp > generateTileLoopNest(OpBuilder &builder, Location loc, ArrayRef< Range > loopRanges, ArrayRef< OpFoldResult > tileSizes, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Generate an empty loop nest that represents the tiled loop nest shell.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, Value iv, Value tileSize)
Returns the bounded tile size given the current iv, loopRange and tileSize, i.e., min(tileSize,...
static SmallVector< Value > yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, ValueRange yieldedValues, ArrayRef< SmallVector< OpFoldResult >> tileOffsetsList, ArrayRef< SmallVector< OpFoldResult >> tileSizesList, MutableArrayRef< scf::ForOp > loops)
For a value to be yielded (yieldedValue) from within a loop nest loops, construct the destructive upd...
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< scf::ForOp > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static bool tileDividesIterationDomain(Range loopRange)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:44
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
MLIRContext * getContext() const
Definition: Builders.h:55
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:150
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:397
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
This class represents an operand of an operation.
Definition: Value.h:261
This is a value defined by a result of an operation.
Definition: Value.h:448
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:615
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1379
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< scf::ForOp > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
void yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< scf::ForOp > loops)
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTilingResult > tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducerGreedilyUsingSCFForOp(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Pattern to swap an tensor.extract_slice with its producer when the producer implements the TilingInte...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:49
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:68
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:103
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:331
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:680
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:378
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:345
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:40
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
Definition: Utils.cpp:109
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBBArgs)> NewYieldValueFn
Replace the loop with newIterOperands added as new initialization values.
Definition: Utils.h:52
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Container for result values of tiling.
SmallVector< Value > tiledValues
SmallVector< Operation * > tiledOps
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Transformation information returned after reduction tiling.
Operation * parallelTiledOp
The partial reduction tiled op generated.
Operation * mergeOp
The final reduction operation merging all the partial reductions.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
llvm::DenseMap< Value, Value > replacements
The replacement values to use for the tiled and fused operations.
llvm::SetVector< Operation * > tiledAndFusedOps
List of tiled and fused operations generated.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
SCFTilingOptions & setTileSizes(ArrayRef< OpFoldResult > ts)
Convenience function to set the tileSizeComputationFunction to a function that computes tile sizes at...
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
SmallVector< Value > replacements
Values to use as replacements for the untiled op.