MLIR  17.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1 //===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/Support/Debug.h"
27 #include <optional>
28 
29 #define DEBUG_TYPE "tile-using-interface"
30 
31 using namespace mlir;
32 
35  assert(!tileSizeComputationFunction && "tile sizes already set");
36  SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
37  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
40  &op->getParentWithTrait<OpTrait::IsIsolatedFromAbove>()
41  ->getRegion(0)
42  .front());
43  return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
44  Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
45  return v;
46  }));
47  };
48  return *this;
49 }
50 
51 /// Helper method to adjust the interchange vector to match the iteration
52 /// domain.
55  size_t iterationDomainSize) {
56  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
57  if (filledVector.size() < iterationDomainSize) {
58  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
59  filledVector.append(range.begin(), range.end());
60  }
61  if (filledVector.size() > iterationDomainSize)
62  filledVector.resize(iterationDomainSize);
63  return filledVector;
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // tileUsingSCFForOp implementation.
68 //===----------------------------------------------------------------------===//
69 
70 // Check if `stride` evenly divides the trip count `size - offset`.
71 static bool tileDividesIterationDomain(Range loopRange) {
72  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
73  if (!offsetAsInt)
74  return false;
75  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
76  if (!sizeAsInt)
77  return false;
78  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
79  if (!strideAsInt)
80  return false;
81  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
82 }
83 
84 /// Returns the bounded tile size given the current `iv`, `loopRange` and
85 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
87  Range loopRange, Value iv,
88  Value tileSize) {
89  std::optional<int64_t> ts = getConstantIntValue(tileSize);
90  if (ts && ts.value() == 1)
91  return getAsOpFoldResult(tileSize);
92 
94  Range{loopRange.offset, loopRange.size, tileSize}))
95  return tileSize;
96 
97  // The tile size to use (to avoid out of bounds access) is minimum of
98  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
99  // loop.
100  AffineExpr s0, s1, d0;
101  bindDims(b.getContext(), d0);
102  bindSymbols(b.getContext(), s0, s1);
103  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
104  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
106  b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
107 }
108 
109 /// Generate an empty loop nest that represents the tiled loop nest shell.
110 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
111 /// - `tileSizeVals` is the tile sizes to use. A zero tile size leaves that
112 ///   dimension untiled (no loop is generated for it).
113 /// - `offsets` and `sizes` are filled with the multi-dimensional offset and
114 ///   size of the tile processed within the innermost loop.
117  ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
118  SmallVector<OpFoldResult> &offsets,
119  SmallVector<OpFoldResult> &sizes) {
120  assert(!loopRanges.empty() && "expected at least one loop range");
121  assert(loopRanges.size() == tileSizeVals.size() &&
122  "expected as many tile sizes as loop ranges");
123  OpBuilder::InsertionGuard guard(builder);
125  offsets.resize(loopRanges.size());
126  sizes.resize(loopRanges.size());
127 
128  for (auto loopRange : llvm::enumerate(loopRanges)) {
129  Value offset =
130  getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
131  Value size =
132  getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
133  Value tileSize = tileSizeVals[loopRange.index()];
134  // No loops if tile size is zero. Set offset and size to the loop
135  // offset and size.
136  if (matchPattern(tileSize, m_Zero())) {
137  offsets[loopRange.index()] = offset;
138  sizes[loopRange.index()] = size;
139  continue;
140  }
141 
142  auto loop = builder.create<scf::ForOp>(
143  loc, offset, size, tileSize, ValueRange{},
144  [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
145  ValueRange /*iterArgs*/) {
146  sizes[loopRange.index()] = getBoundedTileSize(
147  bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
148  builder.create<scf::YieldOp>(loc);
149  });
150  offsets[loopRange.index()] = loop.getInductionVar();
151  loops.push_back(loop);
152  builder.setInsertionPoint(loop.getBody()->getTerminator());
153  }
154  return loops;
155 }
156 
157 /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
158 /// construct the destructive update pattern that inserts the yielded
159 /// value into a destination tensor provided by `initValue` at offset
160 /// `tileOffsets` and size `tileSizes`. For example,
161 ///
162 /// ```mlir
163 /// scf.for %iv0 = ... {
164 /// %0 = tiled_op
165 /// }
166 /// ```
167 ///
168 /// is transformed to
169 ///
170 /// ```mlir
171 /// scf.for %iv0 = ... iter_args(%arg = %0) {
172 /// %1 = tensor.extract_slice %arg
173 /// %2 = tiled_op
174 /// %3 = tensor.insert_slice %2 into %arg
175 /// scf.yield %3
176 /// }
177 /// ```
178 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
179 static SmallVector<Value>
181  ValueRange yieldedValues,
182  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
183  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
185  NewYieldValueFn yieldValueFn =
186  [&](OpBuilder &b, Location loc,
188  SmallVector<Value> inserts;
189  for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
190  ArrayRef<OpFoldResult> tileOffsets =
191  tileOffsetsList[yieldedValue.index()];
192  ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
193  SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
194  b.getIndexAttr(1));
195  Value insert = b.create<tensor::InsertSliceOp>(
196  loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
197  tileOffsets, tileSizes, tileStrides);
198  inserts.push_back(insert);
199  }
200  return inserts;
201  };
202 
203  SmallVector<scf::ForOp> newLoops =
204  replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
205  /*replaceIterOperandsUsesInLoop =*/false);
206  for (const auto &loop : llvm::enumerate(loops)) {
207  rewriter.eraseOp(loop.value());
208  loops[loop.index()] = newLoops[loop.index()];
209  }
210  return llvm::to_vector(llvm::map_range(
211  loops.front().getResults().take_back(yieldedValues.size()),
212  [](OpResult r) -> Value { return r; }));
213 }
214 
215 /// If the tiled operation is destination passing style, update the
216 /// slice of the destination used (which refers to the untiled destination)
217 /// to use the corresponding region argument of the innermost loop.
218 ///
219 /// ```mlir
220 /// %0 =
221 /// scf.for %iv0 = ... iter_args(%arg = %0) {
222 /// %1 = tensor.extract_slice %0
223 /// %2 = tiled_op
224 /// %3 = tensor.insert_slice %2 into %arg
225 /// scf.yield %3
226 /// }
227 /// ```
228 ///
229 /// is transformed to
230 ///
231 /// ```mlir
232 /// scf.for %iv0 = ... iter_args(%arg = %0) {
233 /// %1 = tensor.extract_slice %arg
234 /// %2 = tiled_op
235 /// %3 = tensor.insert_slice %2 into %arg
236 /// scf.yield %3
237 /// }
238 /// ```
239 static void
241  ValueRange tiledOpDestinationValues,
242  ValueRange bbArgsList) {
243  for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
244  auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
245  if (!sliceOp)
246  continue;
247  sliceOp.setOperand(0, bbArgsList[destValue.index()]);
248  }
249 }
250 
251 /// Helper method to yield the values of the tiled op, as well as
252 /// update the destination operands of the tiled op, if it is
253 /// a destination passing style op.
254 static SmallVector<Value>
256  TilingResult tilingResult,
257  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
258  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
260  SmallVector<Value> replacements =
261  yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
262  tileOffsetsList, tileSizesList, loops);
263  for (auto tiledOp : tilingResult.tiledOps) {
264  if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
265  auto innerMostLoop = loops.back();
266  SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
267  updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
268  innerMostLoop.getRegionIterArgs());
269  }
270  }
271  return replacements;
272 }
273 
274 /// Implementation of tiling transformation of `op` that implements the
275 /// `TilingInterface` using `scf.for` to iterate over the tiles.
277 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
279  OpBuilder::InsertionGuard guard(rewriter);
280  rewriter.setInsertionPointAfter(op);
281 
282  if (!options.tileSizeComputationFunction) {
283  return rewriter.notifyMatchFailure(
284  op, "missing tile size computation function");
285  }
286 
287  // 1. Get the range of the loops that are represented by the operation.
288  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
289  size_t numLoops = iterationDomain.size();
290  if (numLoops == 0) {
291  return rewriter.notifyMatchFailure(
292  op, "unable to tile op with no iteration domain");
293  }
294 
295  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
296  // skips tiling a particular dimension. This convention is significantly
297  // simpler to handle instead of adjusting affine maps to account for missing
298  // dimensions.
299  SmallVector<Value> tileSizeVector =
300  options.tileSizeComputationFunction(rewriter, op);
301  if (tileSizeVector.size() < iterationDomain.size()) {
302  auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
303  tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
304  }
305 
306  scf::SCFTilingResult tilingResult;
307  SmallVector<OpFoldResult> offsets, sizes;
308  {
309  // If there is an interchange specified, permute the iteration domain and
310  // the tile sizes.
311  SmallVector<int64_t> interchangeVector;
312  if (!options.interchangeVector.empty()) {
313  interchangeVector = fillInterchangeVector(options.interchangeVector,
314  iterationDomain.size());
315  }
316  if (!interchangeVector.empty()) {
317  if (!isPermutationVector(interchangeVector)) {
318  return rewriter.notifyMatchFailure(
319  op, "invalid intechange vector, not a permutation of the entire "
320  "iteration space");
321  }
322 
323  applyPermutationToVector(iterationDomain, interchangeVector);
324  applyPermutationToVector(tileSizeVector, interchangeVector);
325  }
326 
327  // 3. Materialize an empty loop nest that iterates over the tiles. These
328  // loops for now do not return any values even if the original operation has
329  // results.
330  tilingResult.loops = generateTileLoopNest(
331  rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
332 
333  if (!interchangeVector.empty()) {
334  auto inversePermutation = invertPermutationVector(interchangeVector);
337  }
338  }
339 
340  LLVM_DEBUG({
341  if (!tilingResult.loops.empty()) {
342  llvm::dbgs() << "LoopNest shell :\n";
343  tilingResult.loops.front().dump();
344  llvm::dbgs() << "\n";
345  }
346  });
347 
348  // 4. Generate the tiled implementation within the inner most loop.
349  if (!tilingResult.loops.empty())
350  rewriter.setInsertionPoint(
351  tilingResult.loops.back().getBody()->getTerminator());
352  FailureOr<TilingResult> tiledImplementation =
353  op.getTiledImplementation(rewriter, offsets, sizes);
354  tilingResult.tiledOps.append(tiledImplementation->tiledOps);
355  if (op->getNumResults() == 0) {
356  // nothing more to do.
357  return tilingResult;
358  }
359 
360  // If loops are empty, the tiled op is used as the replacement for the untiled
361  // op.
362  if (tilingResult.loops.empty()) {
363  tilingResult.replacements = tiledImplementation->tiledValues;
364  return tilingResult;
365  }
366 
367  // 5. Yield all the results of the tiled operation. The surrounding loop
368  // nest is modified to insert a destructive update pattern to yield
369  // from the loop nest values to replace the untiled op with.
370  int64_t numResults = op->getNumResults();
371  SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
372  resultSizesList(numResults);
373  for (const auto &result : llvm::enumerate(op->getResults())) {
374  if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
375  sizes,
376  resultOffsetsList[result.index()],
377  resultSizesList[result.index()]))) {
378  return rewriter.notifyMatchFailure(
379  op, "failed to get slice of result produced");
380  }
381  }
382 
383  SmallVector<Value> destinationTensors;
384  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
385  destinationTensors)))
386  return rewriter.notifyMatchFailure(op, "failed to get destinations");
387 
388  tilingResult.replacements = yieldTiledValues(
389  rewriter, destinationTensors, tiledImplementation.value(),
390  resultOffsetsList, resultSizesList, tilingResult.loops);
391 
392  LLVM_DEBUG({
393  if (!tilingResult.loops.empty()) {
394  llvm::dbgs() << "After tiled implementation :\n";
395  tilingResult.loops.front().dump();
396  llvm::dbgs() << "\n";
397  }
398  });
399  return tilingResult;
400 }
401 
404  PartialReductionOpInterface op,
405  ArrayRef<OpFoldResult> tileSize) {
406  Location loc = op.getLoc();
407  // Ops implementing PartialReductionOpInterface are expected to implement
408  // TilingInterface.
409  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
410  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
411  SmallVector<Value> tileSizeVector =
412  getValueOrCreateConstantIndexOp(b, loc, tileSize);
413  if (tileSizeVector.size() < iterationDomain.size()) {
414  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
415  tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
416  }
417  if (op->getNumResults() != 1)
418  return b.notifyMatchFailure(
419  op, "don't support ops with multiple results for now");
421  tilingInterfaceOp.getLoopIteratorTypes();
422  int64_t numReductionDims = llvm::count(
423  tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
424  if (numReductionDims != 1)
425  return b.notifyMatchFailure(
426  op, "only support ops with one reduction dimension.");
427  int reductionDim;
428  for (auto [idx, iteratorType] :
429  llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
430  if (iteratorType == utils::IteratorType::reduction) {
431  reductionDim = idx;
432  break;
433  }
434  }
435  if (static_cast<size_t>(reductionDim) >= tileSize.size())
436  return b.notifyMatchFailure(op, "reduction dimension must be tiled");
437 
438  // 1. create the inital tensor value.
439  FailureOr<Operation *> identityTensor =
440  op.generateInitialTensorForPartialReduction(b, loc, tileSize,
441  reductionDim);
442  if (failed(identityTensor))
443  return b.notifyMatchFailure(op,
444  "cannot create a tensor of identity value.");
445  // 2. Create the nested loops.
446  SmallVector<OpFoldResult> offsets, sizes;
448  b, loc, iterationDomain, tileSizeVector, offsets, sizes);
449 
450  // 3. Generate the tiled implementation within the inner most loop.
451  b.setInsertionPoint(loops.back().getBody()->getTerminator());
452  Operation *parallelOp = op.tileToPartialReduction(
453  b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim);
454 
455  SmallVector<OpFoldResult> resultSizesList;
456  for (size_t i = 0; i < offsets.size(); i++)
457  resultSizesList.push_back(
458  b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
459  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
460  SmallVector<Value> replacements = yieldTiledValues(
461  b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
462  resultSizesList, loops);
463 
464  auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
465  auto innerMostLoop = loops.back();
466  SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
467  assert(destinationTensors.size() ==
468  innerMostLoop.getRegionIterArgs().size() &&
469  "unexpected number of outputs");
470  updateDestinationOperandsForTiledOp(b, destinationTensors,
471  innerMostLoop.getRegionIterArgs());
472 
473  // 4. Apply the merge reduction to combine all the partial values.
474  b.setInsertionPointAfter(*loops.begin());
475  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
476  b.replaceOp(op, mergeOp->getResults());
477 
478  SCFReductionTilingResult results;
479  results.initialOp = *identityTensor;
480  results.loops = std::move(loops);
481  results.parallelTiledOp = parallelOp;
482  results.mergeOp = mergeOp;
483  return results;
484 }
485 //===----------------------------------------------------------------------===//
486 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
487 //===----------------------------------------------------------------------===//
488 
489 /// Return the untiled producer whose slice is used in a tiled consumer. The
490 /// method traverses the tile loop nest (`loops`) if needed, and returns the
491 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
492 /// indicates that this is a destination operand of the consumer. If there was
493 /// no loop traversal needed, the second value of the returned tuple is empty.
494 static std::tuple<OpResult, std::optional<OpOperand *>>
496  ArrayRef<scf::ForOp> loops) {
497  std::optional<OpOperand *> destinationIterArg;
498  auto loopIt = loops.rbegin();
499  while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
500  scf::ForOp loop = *loopIt;
501  if (iterArg.getOwner()->getParentOp() != loop)
502  break;
503  source = &loop.getOpOperandForRegionIterArg(iterArg);
504  loopIt++;
505  }
506  if (loopIt == loops.rend())
507  destinationIterArg = source;
508  return {source->get().dyn_cast<OpResult>(), destinationIterArg};
509 }
510 
511 /// Implementation of fusing producer of a single slice by computing the
512 /// slice of the producer in-place.
513 std::optional<scf::SCFFuseProducerOfSliceResult>
515  tensor::ExtractSliceOp candidateSliceOp,
517  // 1. Get the producer of the source (potentially walking through
518  // `iter_args` of nested `scf.for`)
519  auto [fusableProducer, destinationIterArg] =
520  getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
521  loops);
522  if (!fusableProducer)
523  return std::nullopt;
524 
525  // 2. Generate the tiled implementation of the producer of the source
526  OpBuilder::InsertionGuard g(rewriter);
527  rewriter.setInsertionPoint(candidateSliceOp);
528  FailureOr<TilingResult> tileAndFuseResult =
529  tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
530  fusableProducer);
531  if (failed(tileAndFuseResult))
532  return std::nullopt;
533  rewriter.replaceAllUsesWith(candidateSliceOp,
534  tileAndFuseResult->tiledValues[0]);
535 
536  // 3. If the slice is for a destination operand, for example,
537  //
538  // ```mlir
539  // %0 = linalg.init
540  // %1 = linalg.fill .. outs(%0 : )
541  // %2 = scf.for .. iter_args(%arg0 = %1) {
542  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
543  // %4 = tensor.extract_slice %arg1 [..]
544  // .. = linalg.matmul .. outs(%4 : )
545  // }
546  // }
547  // ```
548  //
549  // the IR is currently
550  //
551  // ```
552  // %0 = linalg.init
553  // %1 = linalg.fill
554  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
555  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
556  // %4 = tensor.extract_slice %0 /*incorrect value */ [..]
557  // %5 = linalg.fill .. outs(%4 : )
558  // .. = linalg.matmul .. outs(%5 : )
559  // }
560  // }
561  // ```
562  //
563  // The untiled `linalg.fill` is still used as the `init_value` since it
564  // was originally a destination operand of the untiled `linalg.matmul`.
565  // When fusing an operand that is a destination operand.
566  // - Update the iter_arg of the outer most loop to use the destination
567  // of the untiled producer.
568  // - Update the destination of the slice of the tiled producer generated
569  // to use the same basic block argument as the slice that was used to
570  // generate inplace the tiled implementation of the producer.
571  // With this the IR will be.
572  //
573  // ```
574  // %0 = linalg.init
575  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
576  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
577  // %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
578  // %4 = linalg.fill .. outs(%3 : )
579  // .. = linalg.matmul .. outs(%4 : )
580  // }
581  // }
582  // ```
583  // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
584  // Update to use that when it does become available.
585  scf::ForOp outerMostLoop = loops.front();
586  std::optional<unsigned> iterArgNumber;
587  if (destinationIterArg) {
588  iterArgNumber =
589  outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
590  }
591  if (iterArgNumber) {
592  int64_t resultNumber = fusableProducer.getResultNumber();
593  if (auto dstOp =
594  dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
595  outerMostLoop.setIterArg(iterArgNumber.value(),
596  dstOp.getTiedOpOperand(fusableProducer)->get());
597  }
598  for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
599  auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
600  if (!dstOp)
601  continue;
602  scf::ForOp innerMostLoop = loops.back();
604  rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
605  innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
606  }
607  }
608  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
609  tileAndFuseResult->tiledValues[0],
610  tileAndFuseResult->tiledOps};
611 }
612 
613 /// Reconstruct the fused producer from within the tiled-and-fused code.
615  RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
616  scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
618  auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
619  fusedProducerInfo;
620  SmallVector<Value> initValues;
622  rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
623  if (succeeded(initValue)) {
624  SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
625  SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
626  SmallVector<Value> yieldedVals =
627  yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
628  resultOffsets, resultSizes, loops);
629  }
630  for (auto tileAndFusedOp : tileAndFusedOps) {
631  auto dstStyleProducer =
632  dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
633  if (!dstStyleProducer)
634  continue;
635  Value dstValue =
636  dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
637  ->get();
639  rewriter, dstValue, loops.back().getRegionIterArgs().back());
640  }
641 }
642 
643 /// Implementation of tile consumer and fuse producer greedily.
646  RewriterBase &rewriter, TilingInterface consumer,
648  // This transformation is only valid for ops that return values (i.e. not
649  // valid to use with operations that have memref operands).
650  if (!consumer->getNumResults()) {
651  return rewriter.notifyMatchFailure(
652  consumer, "invalid pattern for op with no results");
653  }
654 
655  // 1. First tile the consumer.
656  scf::SCFTileAndFuseResult tileAndFuseResult;
657  llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
658  {
659  FailureOr<scf::SCFTilingResult> tilingResult =
660  tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
661  if (failed(tilingResult))
662  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
663  for (auto *tiledOp : tilingResult->tiledOps)
664  tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
665  tileAndFuseResult.loops = std::move(tilingResult->loops);
666  for (const auto &result : llvm::enumerate(
667  llvm::zip(consumer->getResults(), tilingResult->replacements))) {
668  tileAndFuseResult.replacements[std::get<0>(result.value())] =
669  std::get<1>(result.value());
670  yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
671  result.index())] = result.index();
672  }
673  }
674 
675  // If there are no loops generated, fusion is immaterial.
676  if (tileAndFuseResult.loops.empty())
677  return tileAndFuseResult;
678 
679  // 2. Typically, the operands of the tiled operation are slices of the
680  // operands of the untiled operation. These are expressed in IR using
681  // `tensor.extract_slice` operations with source being the operands of the
682  // untiled operation. Create a worklist of these `tensor.extract_slice`
683  // operations. If the producers of the source of the `tensor.extract_slice`
684  // can be tiled such that the tiled value is generated in-place, that
685  // effectively tiles + fuses the operations.
686  auto addCandidateSlices = [](Operation *fusedOp,
687  std::deque<tensor::ExtractSliceOp> &candidates) {
688  for (Value operand : fusedOp->getOperands())
689  if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
690  candidates.push_back(sliceOp);
691  };
692 
693  std::deque<tensor::ExtractSliceOp> candidates;
694  addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
695  OpBuilder::InsertionGuard g(rewriter);
696  while (!candidates.empty()) {
697  // Traverse the slices in BFS fashion.
698  tensor::ExtractSliceOp candidateSliceOp = candidates.front();
699  candidates.pop_front();
700 
701  // The operands of the fused producer might themselved be slices of
702  // values produced by operations that implement the `TilingInterface`.
703  // Add these operations to the worklist.
704  std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
705  tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
706  tileAndFuseResult.loops);
707  if (!fusedProducer)
708  continue;
709 
710  if (Operation *tiledAndFusedOp =
711  fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
712  tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
713  addCandidateSlices(tiledAndFusedOp, candidates);
714  }
715  }
716  return tileAndFuseResult;
717 }
718 
719 //===----------------------------------------------------------------------===//
720 // lowerToLoopsUsingSCFForOp implementation.
721 //===----------------------------------------------------------------------===//
722 
725  TilingInterface op) {
726  // TODO: Handle cases where the op has results if needed.
727  if (op->getNumResults() > 0) {
728  return rewriter.notifyMatchFailure(
729  op, "unable to lower to loops operations with return values");
730  }
731 
732  SmallVector<Range> domain = op.getIterationDomain(rewriter);
733  SmallVector<Value> ivs;
735  Location loc = op.getLoc();
736  for (auto loopRange : domain) {
737  Value offsetVal =
738  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
739  Value sizeVal =
740  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
741  Value strideVal =
742  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
743  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
744  strideVal, ValueRange{});
745  loops.push_back(loop);
746  ivs.push_back(loop.getInductionVar());
747  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
748  }
749  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
750  return failure();
751  }
752  return loops;
753 }
static llvm::ManagedStatic< PassManagerOptions > options
static void updateDestinationOperandsForTiledOp(OpBuilder &builder, ValueRange tiledOpDestinationValues, ValueRange bbArgsList)
If the tiled operation is destination passing style, update the slice of the destination used (which ...
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, Value iv, Value tileSize)
Returns the bounded tile size given the current iv, loopRange and tileSize, i.e., min(tileSize,...
static SmallVector< Value > yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, ValueRange yieldedValues, ArrayRef< SmallVector< OpFoldResult >> tileOffsetsList, ArrayRef< SmallVector< OpFoldResult >> tileSizesList, MutableArrayRef< scf::ForOp > loops)
For a value to be yielded (yieldedValue) from within a loop nest loops, construct the destructive upd...
static SmallVector< scf::ForOp > generateTileLoopNest(OpBuilder &builder, Location loc, ArrayRef< Range > loopRanges, ArrayRef< Value > tileSizeVals, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Generate an empty loop nest that represents the tiled loop nest shell.
static std::tuple< OpResult, std::optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< scf::ForOp > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static bool tileDividesIterationDomain(Range loopRange)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:43
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class represents an argument of a Block.
Definition: Value.h:304
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:121
MLIRContext * getContext() const
Definition: Builders.h:55
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:329
This class helps build Operations.
Definition: Builders.h:202
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:412
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:379
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:501
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:393
This class represents a single result from folding an operation.
Definition: OpDefinition.h:235
This class represents an operand of an operation.
Definition: Value.h:255
This is a value defined by a result of an operation.
Definition: Value.h:442
This class provides the API for ops that are known to be isolated from above.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:386
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:357
result_range getResults()
Definition: Operation.h:394
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:597
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:558
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
U dyn_cast() const
Definition: Value.h:103
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:89
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
std::optional< SCFFuseProducerOfSliceResult > tileAndFuseProducerOfSlice(RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, MutableArrayRef< scf::ForOp > loops)
Implementation of fusing producer of a single slice by computing the slice of the producer in-place.
void yieldReplacementForFusedProducer(RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef< scf::ForOp > loops)
Reconstruct the fused producer from within the tiled-and-fused code.
FailureOr< SCFTilingResult > tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducerGreedilyUsingSCFForOp(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
FailureOr< TilingResult > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Pattern to swap a tensor.extract_slice with its producer when the producer implements the TilingInte...
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:64
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:99
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:322
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:329
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1280
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:667
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:299
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:343
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:57
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagating it up...
Definition: Utils.cpp:109
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBBArgs)> NewYieldValueFn
Replace the loop with newIterOperands added as new initialization values.
Definition: Utils.h:52
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Container for result values of tiling.
SmallVector< Value > tiledValues
SmallVector< Operation * > tiledOps
Fuse the producer of the source of candidateSliceOp by computing the required slice of the producer i...
Transformation information returned after reduction tiling.
Operation * parallelTiledOp
The partial reduction tiled op generated.
Operation * mergeOp
The final reduction operation merging all the partial reductions.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
llvm::DenseMap< Value, Value > replacements
The replacement values to use for the tiled and fused operations.
llvm::SetVector< Operation * > tiledAndFusedOps
List of tiled and fused operations generated.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
SCFTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
SmallVector< Value > replacements
Values to use as replacements for the untiled op.