MLIR  16.0.0git
TileUsingInterface.cpp
Go to the documentation of this file.
1 //===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the tiling using TilingInterface.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
26 #include "llvm/Support/Debug.h"
27 
28 #define DEBUG_TYPE "tile-using-interface"
29 
30 using namespace mlir;
31 
// Installs a tile-size callback that materializes the given static sizes.
// The sizes in `ts` are copied into `tileSizes`, which the stored callback
// captures by value, so the callback does not depend on the lifetime of `ts`.
34  assert(!tileSizeComputationFunction && "tile sizes already set");
35  SmallVector<int64_t> tileSizes(ts.begin(), ts.end());
// Each size is later materialized as an index `arith.constant`.
// NOTE(review): the partially elided statement inside the lambda uses
// `op->getParentOfType<func::FuncOp>()`. That suggests the constants are
// created at the start of the enclosing func.func's entry block, which
// presumes `op` is nested in a func::FuncOp. Confirm.
36  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
39  &op->getParentOfType<func::FuncOp>().getBody().front());
40  return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
41  Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
42  return v;
43  }));
44  };
// Returning `*this` lets callers chain option setters.
45  return *this;
46 }
47 
48 /// Helper method to adjust the interchange vector to match the iteration
49 /// domain.
52  size_t iterationDomainSize) {
53  SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector);
54  if (filledVector.size() < iterationDomainSize) {
55  auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize);
56  filledVector.append(range.begin(), range.end());
57  }
58  if (filledVector.size() > iterationDomainSize)
59  filledVector.resize(iterationDomainSize);
60  return filledVector;
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // tileUsingSCFForOp implementation.
65 //===----------------------------------------------------------------------===//
66 
67 // Check if `stride` evenly divides the trip count `size - offset`.
68 static bool tileDividesIterationDomain(Range loopRange) {
69  Optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
70  if (!offsetAsInt)
71  return false;
72  Optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
73  if (!sizeAsInt)
74  return false;
75  Optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
76  if (!strideAsInt)
77  return false;
78  return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
79 }
80 
81 /// Returns the bounded tile size given the current `iv`, `loopRange` and
82 /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
/// Here `range.end()` is `loopRange.size`, which the tiled loops in this file
/// use as their upper bound. When the result is provably `tileSize`, it is
/// returned directly and no `affine.min` is built. That happens for a
/// constant tile size of 1, or for a tile size that evenly divides the
/// constant extent.
84  Range loopRange, Value iv,
85  Value tileSize) {
// A tile of size 1 starting at `iv` can never cross the upper bound.
// NOTE(review): `ts` is computed on a line not shown here. Presumably it is
// the constant value of `tileSize`; confirm.
87  if (ts && ts.value() == 1)
88  return getAsOpFoldResult(tileSize);
89 
// If the tile size evenly divides the constant extent there is no partial
// last tile, so `tileSize` is always in bounds.
91  Range{loopRange.offset, loopRange.size, tileSize}))
92  return tileSize;
93 
94  // The tile size to use (to avoid out of bounds access) is minimum of
95  // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
96  // loop.
// The map is (d0)[s0, s1] -> (s0, s1 - d0), bound as d0 = iv, s0 = tileSize,
// s1 = loopRange.size (the upper bound).
97  AffineExpr s0, s1, d0;
98  bindDims(b.getContext(), d0);
99  bindSymbols(b.getContext(), s0, s1);
100  AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
101  Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
103  b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
104 }
105 
106 /// Generate an empty loop nest that represents the tiled loop nest shell.
107 /// - `loopRanges` gives the untiled iteration space. Each range's `offset` is
///   the loop lower bound and its `size` is the upper bound. The range's
///   `stride` is not read; the tile size is used as the loop step.
108 /// - `tileSizeVals` is the tile sizes to use. A constant zero marks an untiled
///   dimension, and no loop is generated for it.
109 /// - In `offsets` and `sizes` return the multi-dimensional offset and size of
110 ///   the tile processed within the inner most loop (for untiled dimensions,
111 ///   the loop range's own offset and size).
/// Returns the generated loops, outermost first. There may be fewer loops
/// than `loopRanges`, and none at all if every tile size is zero.
114  ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals,
115  SmallVector<OpFoldResult> &offsets,
116  SmallVector<OpFoldResult> &sizes) {
117  assert(!loopRanges.empty() && "expected at least one loop range");
118  assert(loopRanges.size() == tileSizeVals.size() &&
119  "expected as many tile sizes as loop ranges");
// Restore the caller's insertion point on return.
120  OpBuilder::InsertionGuard guard(builder);
122  offsets.resize(loopRanges.size());
123  sizes.resize(loopRanges.size());
124 
125  for (auto loopRange : llvm::enumerate(loopRanges)) {
126  Value offset =
127  getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
128  Value size =
129  getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
130  Value tileSize = tileSizeVals[loopRange.index()];
131  // No loops if tile size is zero. Set offset and size to the loop
132  // offset and size.
133  if (matchPattern(tileSize, m_Zero())) {
134  offsets[loopRange.index()] = offset;
135  sizes[loopRange.index()] = size;
136  continue;
137  }
138 
// Loop from `offset` to `size`, stepping by `tileSize`.
// NOTE(review): using `size` as the upper bound matches the range's extent
// only when `offset` is zero. Confirm this is guaranteed for the iteration
// domains passed here.
139  auto loop = builder.create<scf::ForOp>(
140  loc, offset, size, tileSize, ValueRange{},
141  [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
142  ValueRange /*iterArgs*/) {
143  sizes[loopRange.index()] = getBoundedTileSize(
144  bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize);
// NOTE(review): the terminator is created through the outer `builder`.
// Presumably this is the same object as `bodyBuilder` in this callback;
// using `bodyBuilder` would make that explicit.
145  builder.create<scf::YieldOp>(loc);
146  });
147  offsets[loopRange.index()] = loop.getInductionVar();
148  loops.push_back(loop);
// Nest the next loop (or the tiled body) before this loop's terminator.
149  builder.setInsertionPoint(loop.getBody()->getTerminator());
150  }
151  return loops;
152 }
153 
154 /// For each value to be yielded (`yieldedValues[i]`) from within a loop nest
155 /// `loops`, construct the destructive update pattern that inserts the yielded
156 /// value into the destination tensor threaded from `initValues[i]`, at offsets
157 /// `tileOffsetsList[i]` and sizes `tileSizesList[i]` (unit strides). For example,
158 ///
159 /// ```mlir
160 /// scf.for %iv0 = ... {
161 /// %0 = tiled_op
162 /// }
163 /// ```
164 ///
165 /// is transformed to
166 ///
167 /// ```mlir
168 /// scf.for %iv0 = ... iter_args(%arg = %0) {
169 /// %1 = tensor.extract_slice %arg
170 /// %2 = tiled_op
171 /// %3 = tensor.insert_slice %2 into %arg
172 /// scf.yield %3
173 /// }
174 /// ```
175 /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
///
/// This function creates only the `iter_args`, the `tensor.insert_slice` and
/// the `scf.yield`. Callers separately redirect a `tensor.extract_slice` of
/// the destination to read `%arg`.
/// On return, `loops` refers to the newly created loops and the original
/// loops are erased. The result is the trailing `yieldedValues.size()`
/// results of the new outermost loop. No failure path is visible in this
/// function.
178  ValueRange yieldedValues,
179  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
180  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
182  NewYieldValueFn yieldValueFn =
183  [&](OpBuilder &b, Location loc,
185  SmallVector<Value> inserts;
186  for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
187  ArrayRef<OpFoldResult> tileOffsets =
188  tileOffsetsList[yieldedValue.index()];
189  ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
190  SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
191  b.getIndexAttr(1));
192  Value insert = b.create<tensor::InsertSliceOp>(
193  loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
194  tileOffsets, tileSizes, tileStrides);
195  inserts.push_back(insert);
196  }
197  return inserts;
198  };
199 
// Rebuild the nest with the extra iter_args and results.
// NOTE(review): judging by its name, `replaceIterOperandsUsesInLoop = false`
// presumably leaves uses of `initValues` inside the loops untouched. Confirm.
200  SmallVector<scf::ForOp> newLoops =
201  replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
202  /*replaceIterOperandsUsesInLoop =*/false);
// Erase the old loops and update `loops` in place to refer to the new ones.
203  for (const auto &loop : llvm::enumerate(loops)) {
204  rewriter.eraseOp(loop.value());
205  loops[loop.index()] = newLoops[loop.index()];
206  }
207  return llvm::to_vector(llvm::map_range(
208  loops.front().getResults().take_back(yieldedValues.size()),
209  [](OpResult r) -> Value { return r; }));
210 }
211 
212 /// If the tiled operation is destination passing style, update the
213 /// slice of the destination used (which refers to the untiled destination)
214 /// to use the corresponding region argument of the innermost loop.
215 ///
216 /// ```mlir
217 /// %0 =
218 /// scf.for %iv0 = ... iter_args(%arg = %0) {
219 /// %1 = tensor.extract_slice %0
220 /// %2 = tiled_op
221 /// %3 = tensor.insert_slice %2 into %arg
222 /// scf.yield %3
223 /// }
224 /// ```
225 ///
226 /// is transformed to
227 ///
228 /// ```mlir
229 /// scf.for %iv0 = ... iter_args(%arg = %0) {
230 /// %1 = tensor.extract_slice %arg
231 /// %2 = tiled_op
232 /// %3 = tensor.insert_slice %2 into %arg
233 /// scf.yield %3
234 /// }
235 /// ```
236 static void
238  ValueRange tiledOpDestinationValues,
239  ValueRange bbArgsList) {
240  for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
241  auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
242  if (!sliceOp)
243  continue;
244  sliceOp.setOperand(0, bbArgsList[destValue.index()]);
245  }
246 }
247 
248 /// Implementation of tiling transformation of `op` that implements the
249 /// `TilingInterface` using `scf.for` to iterate over the tiles.
///
/// Steps:
/// 1. Compute the iteration domain and the tile sizes. Missing trailing tile
///    sizes default to 0, i.e. untiled.
/// 2. Optionally interchange the loops.
/// 3. Build the empty tile loop nest.
/// 4. Emit the tiled implementation in the innermost loop.
/// 5. Thread the op's results through the nest with `tensor.insert_slice`, so
///    that the outermost loop's results can serve as `replacements`.
/// If no loop is generated, the tiled op's results are the replacements. The
/// visible code does not erase or replace `op` itself.
251 mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
253  OpBuilder::InsertionGuard guard(rewriter);
254  rewriter.setInsertionPointAfter(op);
255 
256  if (!options.tileSizeComputationFunction) {
257  return rewriter.notifyMatchFailure(
258  op, "missing tile size computation function");
259  }
260 
261  // Get destination tensors.
262  SmallVector<Value> destinationTensors;
263  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
264  destinationTensors)))
265  return rewriter.notifyMatchFailure(op, "failed to get destinations");
266 
267  // 1. Get the range of the loops that are represented by the operation.
268  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
269  size_t numLoops = iterationDomain.size();
270  if (numLoops == 0) {
271  return rewriter.notifyMatchFailure(
272  op, "unable to tile op with no iteration domain");
273  }
274 
275  // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
276  // skips tiling a particular dimension. This convention is significantly
277  // simpler to handle instead of adjusting affine maps to account for missing
278  // dimensions.
279  SmallVector<Value> tileSizeVector =
280  options.tileSizeComputationFunction(rewriter, op);
// NOTE(review): only the "too few" case is handled. If the callback returns
// more tile sizes than loops, the extras are kept and the size assertion in
// generateTileLoopNest fires.
281  if (tileSizeVector.size() < iterationDomain.size()) {
282  auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0);
283  tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
284  }
285 
286  scf::SCFTilingResult tilingResult;
287  SmallVector<OpFoldResult> offsets, sizes;
288  {
289  // If there is an interchange specified, permute the iteration domain and
290  // the tile sizes.
291  SmallVector<int64_t> interchangeVector;
292  if (!options.interchangeVector.empty()) {
293  interchangeVector = fillInterchangeVector(options.interchangeVector,
294  iterationDomain.size());
295  }
296  if (!interchangeVector.empty()) {
// NOTE(review): the failure message below misspells "interchange" as
// "intechange". It is runtime text, so it is left unchanged here.
297  if (!isPermutationVector(interchangeVector)) {
298  return rewriter.notifyMatchFailure(
299  op, "invalid intechange vector, not a permutation of the entire "
300  "iteration space");
301  }
302 
303  applyPermutationToVector(iterationDomain, interchangeVector);
304  applyPermutationToVector(tileSizeVector, interchangeVector);
305  }
306 
307  // 3. Materialize an empty loop nest that iterates over the tiles. These
308  // loops for now do not return any values even if the original operation has
309  // results.
310  tilingResult.loops = generateTileLoopNest(
311  rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
312 
// Undo the interchange on `offsets`/`sizes` so they are expressed in the op's
// own loop order. The permutation calls are on lines not shown here;
// presumably they apply `inversePermutation`. Confirm.
313  if (!interchangeVector.empty()) {
314  auto inversePermutation = invertPermutationVector(interchangeVector);
317  }
318  }
319 
320  LLVM_DEBUG({
321  if (!tilingResult.loops.empty()) {
322  llvm::dbgs() << "LoopNest shell :\n";
323  tilingResult.loops.front().dump();
324  llvm::dbgs() << "\n";
325  }
326  });
327 
328  // 4. Generate the tiled implementation within the inner most loop.
329  if (!tilingResult.loops.empty())
330  rewriter.setInsertionPoint(
331  tilingResult.loops.back().getBody()->getTerminator());
332  SmallVector<Operation *> tiledImplementation =
333  op.getTiledImplementation(rewriter, offsets, sizes);
334  tilingResult.tiledOps.append(tiledImplementation);
335  if (op->getNumResults() == 0) {
336  // nothing more to do.
337  return tilingResult;
338  }
339 
340  // If loops are empty, the tiled op is used as the replacement for the untiled
341  // op.
// NOTE(review): this (and `tiledOps.back()` below) assumes
// getTiledImplementation returned at least one op.
342  if (tilingResult.loops.empty()) {
343  tilingResult.replacements = llvm::to_vector(
344  llvm::map_range(tiledImplementation[0]->getResults(),
345  [](OpResult result) -> Value { return result; }));
346  return tilingResult;
347  }
348 
349  // 5. Yield all the results of the tiled operation. The surrounding loop
350  // nest is modified to insert a destructive update pattern to yield
351  // from the loop nest values to replace the untiled op with.
352  int64_t numResults = op->getNumResults();
353  SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
354  resultSizesList(numResults);
355  for (const auto &result : llvm::enumerate(op->getResults())) {
356  if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
357  sizes,
358  resultOffsetsList[result.index()],
359  resultSizesList[result.index()]))) {
360  return rewriter.notifyMatchFailure(
361  op, "failed to get slice of result produced");
362  }
363  }
364 
366  rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(),
367  resultOffsetsList, resultSizesList, tilingResult.loops);
368  if (failed(replacementOr))
369  return rewriter.notifyMatchFailure(op, "failed to yield replacement");
370 
// For destination-style tiled ops, make the destination slices read the
// innermost loop's iter_args instead of the untiled destinations.
371  if (auto dstOp =
372  dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOps.back())) {
373  auto innerMostLoop = tilingResult.loops.back();
// NOTE(review): this `destinationTensors` shadows the outer variable. Here it
// holds the tiled op's DPS init operands, not the untiled destinations.
374  SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
375  assert(destinationTensors.size() ==
376  innerMostLoop.getRegionIterArgs().size() &&
377  "unexpected number of outputs");
378  updateDestinationOperandsForTiledOp(rewriter, destinationTensors,
379  innerMostLoop.getRegionIterArgs());
380  }
381 
382  tilingResult.replacements = replacementOr.value();
383 
384  LLVM_DEBUG({
385  if (!tilingResult.loops.empty()) {
386  llvm::dbgs() << "After tiled implementation :\n";
387  tilingResult.loops.front().dump();
388  llvm::dbgs() << "\n";
389  }
390  });
391  return tilingResult;
392 }
393 
396  PartialReductionOpInterface op,
397  ArrayRef<OpFoldResult> tileSize) {
// Tiles the single reduction dimension of `op` in four stages:
// 1. Build an initial (identity) tensor.
// 2. Build a loop nest.
// 3. Inside the nest, compute partial results with `tileToPartialReduction`,
//    threading them through the nest's iter_args.
// 4. After the outermost loop, combine them with `mergeReductions`, which
//    replaces `op`.
398  Location loc = op.getLoc();
399  // Ops implementing PartialReductionOpInterface are expected to implement
400  // TilingInterface.
401  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
402  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
403  SmallVector<Value> tileSizeVector =
404  getValueOrCreateConstantIndexOp(b, loc, tileSize);
405  if (tileSizeVector.size() < iterationDomain.size()) {
406  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
407  tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero);
408  }
// NOTE(review): the checks below run after constants may already have been
// created above, so a match failure can leave those ops behind in the IR.
409  if (op->getNumResults() != 1)
410  return b.notifyMatchFailure(
411  op, "don't support ops with multiple results for now");
413  tilingInterfaceOp.getLoopIteratorTypes();
414  int64_t numReductionDims = llvm::count(
415  tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
416  if (numReductionDims != 1)
417  return b.notifyMatchFailure(
418  op, "only support ops with one reduction dimension.");
// Exactly one reduction dimension exists (checked above), so the loop below
// always assigns `reductionDim`.
419  int reductionDim;
420  for (auto &[idx, iteratorType] :
421  llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
422  if (iteratorType == utils::IteratorType::reduction) {
423  reductionDim = idx;
424  break;
425  }
426  }
427  // 1. Create the initial tensor value.
428  FailureOr<Operation *> identityTensor =
429  op.generateInitialTensorForPartialReduction(b, loc, tileSize,
430  reductionDim);
431  if (failed(identityTensor))
432  return b.notifyMatchFailure(op,
433  "cannot create a tensor of identity value.");
434  // 2. Create the nested loops.
435  SmallVector<OpFoldResult> offsets, sizes;
437  b, loc, iterationDomain, tileSizeVector, offsets, sizes);
438 
439  // 3. Generate the tiled implementation within the inner most loop.
// NOTE(review): if every tile size is zero, no loop is generated, and
// `loops.back()` here (and below) is invalid. There is no guard for that
// case.
440  b.setInsertionPoint(loops.back().getBody()->getTerminator());
441  Operation *parallelOp =
442  op.tileToPartialReduction(b, loc, identityTensor.value()->getResults(),
443  offsets, sizes, reductionDim);
444 
// The partial result is inserted at offset 0 with its full shape, taking one
// `tensor.dim` per loop dimension. This assumes the partial result's rank
// equals the number of loops.
445  SmallVector<OpFoldResult> resultSizesList;
446  for (size_t i = 0; i < offsets.size(); i++)
447  resultSizesList.push_back(
448  b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i));
449  SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
451  b, identityTensor.value()->getResults(), parallelOp->getResults(),
452  outOffsets, resultSizesList, loops);
453  if (failed(replacementOr))
454  return b.notifyMatchFailure(op, "failed to yield replacement");
455 
// Make the parallel op's destination slices read the innermost iter_args.
456  auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
457  auto innerMostLoop = loops.back();
458  SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
459  assert(destinationTensors.size() ==
460  innerMostLoop.getRegionIterArgs().size() &&
461  "unexpected number of outputs");
462  updateDestinationOperandsForTiledOp(b, destinationTensors,
463  innerMostLoop.getRegionIterArgs());
464 
465  // 4. Apply the merge reduction to combine all the partial values.
466  b.setInsertionPointAfter(*loops.begin());
467  Operation *mergeOp =
468  op.mergeReductions(b, loc, replacementOr.value(), reductionDim);
469  b.replaceOp(op, mergeOp->getResults());
470 
471  SCFReductionTilingResult results;
472  results.initialOp = identityTensor.value();
473  results.loops = std::move(loops);
474  results.parallelTiledOp = parallelOp;
475  results.mergeOp = mergeOp;
476  return results;
477 }
478 //===----------------------------------------------------------------------===//
479 // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
480 //===----------------------------------------------------------------------===//
481 
482 /// Return the untiled producer whose slice is used in a tiled consumer. The
483 /// method traverses the tile loop nest (`loops`) if needed, and returns the
484 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
485 /// indicates that this is a destination operand of the consumer. If there was
486 /// no loop traversal needed, the second value of the returned tuple is empty.
487 static std::tuple<OpResult, Optional<OpOperand *>>
489  ArrayRef<scf::ForOp> loops) {
490  Optional<OpOperand *> destinationIterArg;
491  auto loopIt = loops.rbegin();
492  while (auto iterArg = source->get().dyn_cast<BlockArgument>()) {
493  scf::ForOp loop = *loopIt;
494  if (iterArg.getOwner()->getParentOp() != loop)
495  break;
496  source = &loop.getOpOperandForRegionIterArg(iterArg);
497  loopIt++;
498  }
499  if (loopIt == loops.rend())
500  destinationIterArg = source;
501  return {source->get().dyn_cast<OpResult>(), destinationIterArg};
502 }
503 
504 /// Implementation of tile consumer and fuse producer greedily.
///
/// Tiles `consumer` with `options.tilingOptions`. It then repeatedly takes,
/// breadth first, `tensor.extract_slice` operands of the tiled/fused ops. For
/// each slice whose source is produced by an op result (possibly reached
/// through the loops' `iter_args`), the slice is replaced with a tiled version
/// of that producer, computed in place. For slices of destination operands,
/// the outermost loop's init value and the fused producer's destination are
/// fixed up so the destination keeps flowing through the `iter_args`.
/// Producers that cannot be tiled are skipped. The visible failure paths are
/// a `consumer` without results and a failure to tile `consumer`.
507  RewriterBase &rewriter, TilingInterface consumer,
509  // This transformation is only valid for ops that return values (i.e. not
510  // valid to use with operations that have memref operands).
511  if (!consumer->getNumResults()) {
512  return rewriter.notifyMatchFailure(
513  consumer, "invalid pattern for op with no results");
514  }
515 
516  // 1. First tile the consumer.
517  scf::SCFTileAndFuseResult tileAndFuseResult;
// NOTE(review): this map is filled below but is not read anywhere in this
// function.
518  llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
519  {
520  FailureOr<scf::SCFTilingResult> tilingResult =
521  tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
522  if (failed(tilingResult))
523  return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
524  for (auto *tiledOp : tilingResult->tiledOps)
525  tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
526  tileAndFuseResult.loops = std::move(tilingResult->loops);
527  for (const auto &result : llvm::enumerate(
528  llvm::zip(consumer->getResults(), tilingResult->replacements))) {
529  tileAndFuseResult.replacements[std::get<0>(result.value())] =
530  std::get<1>(result.value());
531  yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
532  result.index())] = result.index();
533  }
534  }
535 
536  // If there are no loops generated, fusion is immaterial.
537  if (tileAndFuseResult.loops.empty())
538  return tileAndFuseResult;
539 
540  // 2. Typically, the operands of the tiled operation are slices of the
541  // operands of the untiled operation. These are expressed in IR using
542  // `tensor.extract_slice` operations with source being the operands of the
543  // untiled operation. Create a worklist of these `tensor.extract_slice`
544  // operations. If the producers of the source of the `tensor.extract_slice`
545  // can be tiled such that the tiled value is generated in-place, that
546  // effectively tiles + fuses the operations.
547  auto addCandidateSlices = [](Operation *fusedOp,
548  std::deque<tensor::ExtractSliceOp> &candidates) {
549  for (Value operand : fusedOp->getOperands())
550  if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
551  candidates.push_back(sliceOp);
552  };
553 
554  std::deque<tensor::ExtractSliceOp> candidates;
555  addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
// The guard restores the rewriter's insertion point, which the loop below
// moves to each candidate slice.
556  OpBuilder::InsertionGuard g(rewriter);
557  while (!candidates.empty()) {
558  // 2a. Traverse the slices in BFS fashion.
559  tensor::ExtractSliceOp candidateSliceOp = candidates.front();
560  candidates.pop_front();
561 
562  // 2b. Get the producer of the source (potentially walking through
563  // `iter_args` of nested `scf.for`)
564  auto [fusableProducer, destinationIterArg] =
565  getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
566  tileAndFuseResult.loops);
567  if (!fusableProducer)
568  continue;
569 
570  // 2c. Generate the tiled implementation of the producer of the source
571  rewriter.setInsertionPoint(candidateSliceOp);
572  FailureOr<Value> fusedProducerValue =
573  tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
574  fusableProducer);
575  if (failed(fusedProducerValue))
576  continue;
577  rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value());
578 
579  // 2d. The operands of the fused producer might themselves be slices of
580  // values produced by operations that implement the `TilingInterface`.
581  // Add these operations to the worklist.
582  Operation *fusedProducer = fusedProducerValue->getDefiningOp();
583  tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer);
584  addCandidateSlices(fusedProducer, candidates);
585 
586  // 2e. If the slice is for a destination operand, for example,
587  //
588  // ```mlir
589  // %0 = linalg.init
590  // %1 = linalg.fill .. outs(%0 : )
591  // %2 = scf.for .. iter_args(%arg0 = %1) {
592  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
593  // %4 = tensor.extract_slice %arg1 [..]
594  // .. = linalg.matmul .. outs(%4 : )
595  // }
596  // }
597  // ```
598  //
599  // the IR is currently
600  //
601  // ```
602  // %0 = linalg.init
603  // %1 = linalg.fill
604  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
605  // %3 = scf.for .. iter_args(%arg1 = %arg0) {
606  // %4 = tensor.extract_slice %0 /*incorrect value */ [..]
607  // %5 = linalg.fill .. outs(%4 : )
608  // .. = linalg.matmul .. outs(%5 : )
609  // }
610  // }
611  // ```
612  //
613  // The untiled `linalg.fill` is still used as the `init_value` since it
614  // was originally a destination operand of the untiled `linalg.matmul`.
615  // When fusing an operand that is a destination operand:
616  // - Update the iter_arg of the outer most loop to use the destination
617  // of the untiled producer.
618  // - Update the destination of the slice of the tiled producer generated
619  // to use the same basic block argument as the slice that was used to
620  // generate inplace the tiled implementation of the producer.
621  // With this the IR will be:
622  //
623  // ```
624  // %0 = linalg.init
625  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
626  // %2 = scf.for .. iter_args(%arg1 = %arg0) {
627  // %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
628  // %4 = linalg.fill .. outs(%3 : )
629  // .. = linalg.matmul .. outs(%4 : )
630  // }
631  // }
632  // ```
633  // TODO: This could be modeled better through `DestinationStyleOpInterface`.
634  // Update to use it when it becomes available.
635  scf::ForOp outerMostLoop = tileAndFuseResult.loops.front();
636  Optional<unsigned> iterArgNumber;
637  if (destinationIterArg) {
638  iterArgNumber = outerMostLoop.getIterArgNumberForOpOperand(
639  *destinationIterArg.value());
640  }
641  if (iterArgNumber) {
642  int64_t resultNumber = fusableProducer.getResultNumber();
643  if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(
644  fusableProducer.getOwner())) {
645  outerMostLoop.setIterArg(
646  iterArgNumber.value(),
647  dstOp.getTiedOpOperand(fusableProducer)->get());
648  }
649  if (auto dstOp = fusedProducerValue.value()
650  .getDefiningOp<DestinationStyleOpInterface>()) {
651  scf::ForOp innerMostLoop = tileAndFuseResult.loops.back();
653  rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
654  innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
655  }
656  }
657  }
658  return tileAndFuseResult;
659 }
660 
661 //===----------------------------------------------------------------------===//
662 // lowerToLoopsUsingSCFForOp implementation.
663 //===----------------------------------------------------------------------===//
664 
/// Lowers `op`, which must have no results, to a perfect nest of `scf.for`
/// loops, one per dimension of its iteration domain. Each loop runs from the
/// range's `offset` to its `size`, stepping by its `stride`. The scalar
/// implementation of `op` is emitted inside the innermost loop via
/// `generateScalarImplementation`. Returns the loops, outermost first.
667  TilingInterface op) {
668  // TODO: Handle cases where the op has results if needed.
669  if (op->getNumResults() > 0) {
670  return rewriter.notifyMatchFailure(
671  op, "unable to lower to loops operations with return values");
672  }
673 
674  SmallVector<Range> domain = op.getIterationDomain(rewriter);
675  SmallVector<Value> ivs;
677  Location loc = op.getLoc();
678  for (auto loopRange : domain) {
679  Value offsetVal =
680  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
681  Value sizeVal =
682  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
683  Value strideVal =
684  getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride);
685  auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal,
686  strideVal, ValueRange{});
687  loops.push_back(loop);
688  ivs.push_back(loop.getInductionVar());
// Nest the next loop (or the scalar body) before this loop's terminator.
689  rewriter.setInsertionPoint(loop.getBody()->getTerminator());
690  }
// NOTE(review): on failure, the loops built above are left in the IR.
// No InsertionGuard is visible in this function, so after return the
// rewriter's insertion point stays inside the innermost loop. Confirm this is
// intended.
691  if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
692  return failure();
693  }
694  return loops;
695 }
static llvm::ManagedStatic< PassManagerOptions > options
static void updateDestinationOperandsForTiledOp(OpBuilder &builder, ValueRange tiledOpDestinationValues, ValueRange bbArgsList)
If the tiled operation is destination passing style, update the slice of the destination used (which ...
static SmallVector< int64_t > fillInterchangeVector(ArrayRef< int64_t > interchangeVector, size_t iterationDomainSize)
Helper method to adjust the interchange vector to match the iteration domain.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, Range loopRange, Value iv, Value tileSize)
Returns the bounded tile size given the current iv, loopRange and tileSize, i.e., min(tileSize,...
static SmallVector< scf::ForOp > generateTileLoopNest(OpBuilder &builder, Location loc, ArrayRef< Range > loopRanges, ArrayRef< Value > tileSizeVals, SmallVector< OpFoldResult > &offsets, SmallVector< OpFoldResult > &sizes)
Generate an empty loop nest that represents the tiled loop nest shell.
static std::tuple< OpResult, Optional< OpOperand * > > getUntiledProducerFromSliceSource(OpOperand *source, ArrayRef< scf::ForOp > loops)
Return the untiled producer whose slice is used in a tiled consumer.
static FailureOr< SmallVector< Value > > yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, ValueRange yieldedValues, ArrayRef< SmallVector< OpFoldResult >> tileOffsetsList, ArrayRef< SmallVector< OpFoldResult >> tileSizesList, MutableArrayRef< scf::ForOp > loops)
For a value to be yielded (yieldedValue) from within a loop nest loops, construct the destructive upd...
static bool tileDividesIterationDomain(Range loopRange)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:42
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class represents an argument of a Block.
Definition: Value.h:296
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:109
MLIRContext * getContext() const
Definition: Builders.h:54
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:383
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:350
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:472
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:364
This class represents a single result from folding an operation.
Definition: OpDefinition.h:233
This class represents an operand of an operation.
Definition: Value.h:247
This is a value defined by a result of an operation.
Definition: Value.h:442
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
result_range getResults()
Definition: Operation.h:332
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:610
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:522
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
U dyn_cast() const
Definition: Value.h:95
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:89
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
FailureOr< scf::SCFReductionTilingResult > tileReductionUsingScf(PatternRewriter &b, PartialReductionOpInterface op, ArrayRef< OpFoldResult > tileSize)
Method to tile a reduction and generate a parallel op within a serial loop.
FailureOr< SmallVector< scf::ForOp > > lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op)
Method to lower an op that implements the TilingInterface to loops/scalars.
FailureOr< SCFTilingResult > tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, const SCFTilingOptions &options)
Method to tile an op that implements the TilingInterface using scf.for for iterating over the tiles.
FailureOr< SCFTileAndFuseResult > tileConsumerAndFuseProducerGreedilyUsingSCFForOp(RewriterBase &rewriter, TilingInterface consumer, const SCFTileAndFuseOptions &options)
Method to tile and fuse a sequence of operations, by tiling the consumer and fusing its producers.
FailureOr< Value > replaceExtractSliceWithTiledProducer(OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp)
Pattern to swap a tensor.extract_slice with its producer when the producer implements the TilingInte...
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:100
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:329
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:336
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1073
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Definition: AffineMap.cpp:669
Optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to compute the inverse of a permutation.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:306
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:343
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:53
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< scf::ForOp > replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef< scf::ForOp > loopNest, ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn, bool replaceIterOperandsUsesInLoop=true)
Update a perfectly nested loop nest to yield new values from the innermost loop and propagate them up through the loop nest.
Definition: Utils.cpp:109
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
Definition: IndexingUtils.h:71
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBBArgs)> NewYieldValueFn
Replace the loop with newIterOperands added as new initialization values.
Definition: Utils.h:51
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
Transformation information returned after reduction tiling.
Operation * parallelTiledOp
The partial reduction tiled op generated.
Operation * mergeOp
The final reduction operation merging all the partial reductions.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
Options used to control tile + fuse.
Transformation information returned after tile and fuse.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
llvm::DenseMap< Value, Value > replacements
The replacement values to use for the tiled and fused operations.
llvm::SetVector< Operation * > tiledAndFusedOps
List of tiled and fused operations generated.
Options to use to control tiling.
SCFTileSizeComputationFunction tileSizeComputationFunction
Computation function that returns the tile sizes for each operation.
SCFTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Transformation information returned after tiling.
SmallVector< Operation * > tiledOps
Tiled operations that are generated during tiling.
SmallVector< scf::ForOp > loops
The scf.for operations that iterate over the tiles.
SmallVector< Value > replacements
Values to use as replacements for the untiled op.