Tiling.cpp
1 //===- Tiling.cpp - Implementation of linalg Tiling -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Tiling pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
26 #include "mlir/IR/AffineExpr.h"
27 #include "mlir/IR/AffineMap.h"
28 #include "mlir/IR/BuiltinOps.h"
29 #include "mlir/IR/ValueRange.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/Support/CommandLine.h"
34 #include <utility>
35 
36 namespace mlir {
37 #define GEN_PASS_DEF_LINALGTILINGPASS
38 #include "mlir/Dialect/Linalg/Passes.h.inc"
39 } // namespace mlir
40 
41 using namespace mlir;
42 using namespace mlir::affine;
43 using namespace mlir::linalg;
44 using namespace mlir::scf;
45 
46 #define DEBUG_TYPE "linalg-tiling"
47 
48 std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
 49 mlir::linalg::makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
 50  ArrayRef<OpFoldResult> allShapeSizes,
51  ArrayRef<OpFoldResult> allTileSizes) {
52  assert(allTileSizes.size() == map.getNumResults());
53  // Apply `map` to get shape sizes in loop order.
54  SmallVector<OpFoldResult> shapeSizes =
55  makeComposedFoldedMultiResultAffineApply(b, loc, map, allShapeSizes);
56  SmallVector<OpFoldResult> tileSizes(allTileSizes.begin(), allTileSizes.end());
57 
58  // Traverse the tile sizes, which are in loop order, erase zeros everywhere.
59  LoopIndexToRangeIndexMap loopIndexToRangeIndex;
60  for (int idx = 0, e = tileSizes.size(), zerosCount = 0; idx < e; ++idx) {
61  if (getConstantIntValue(tileSizes[idx - zerosCount]) ==
62  static_cast<int64_t>(0)) {
63  shapeSizes.erase(shapeSizes.begin() + idx - zerosCount);
64  tileSizes.erase(tileSizes.begin() + idx - zerosCount);
65  ++zerosCount;
66  continue;
67  }
68  loopIndexToRangeIndex[idx] = idx - zerosCount;
69  }
70 
71  // Create a new range with the applied tile sizes.
 72  SmallVector<Range, 4> res;
 73  for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx)
74  res.push_back(Range{b.getIndexAttr(0), shapeSizes[idx], tileSizes[idx]});
75  return std::make_tuple(res, loopIndexToRangeIndex);
76 }
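As an illustration of the pruning above (values assumed, not from the source): with loop-order shape sizes {128, 64, 32} and tile sizes {4, 0, 8}, the zero entry for loop 1 is erased, `loopIndexToRangeIndex` becomes {0 -> 0, 2 -> 1}, and the returned ranges are (offset 0, size 128, stride 4) and (offset 0, size 32, stride 8).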
77 
 78 void mlir::linalg::transformIndexOps(
 79  RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
80  const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
81  SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
82  for (auto en : enumerate(allIvs)) {
83  auto rangeIndex = loopIndexToRangeIndex.find(en.index());
84  if (rangeIndex == loopIndexToRangeIndex.end())
85  continue;
86  en.value() = ivs[rangeIndex->second];
87  }
88  offsetIndices(b, op, getAsOpFoldResult(allIvs));
89 }
90 
91 /// Asserts that the given index-typed value is strictly positive. If the value
92 /// is an attribute, asserts at compile time, otherwise emits an assertion
93 /// checked at runtime.
 94 static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
 95  OpFoldResult value) {
96  if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
97  assert(cast<IntegerAttr>(attr).getValue().isStrictlyPositive() &&
98  "expected strictly positive tile size and divisor");
99  return;
100  }
101 
102  Value zero = b.create<arith::ConstantIndexOp>(0);
103  Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt,
104  value.get<Value>(), zero);
105  b.create<cf::AssertOp>(
106  condition,
107  b.getStringAttr("expected strictly positive tile size and divisor"));
108 }
109 
 110 FailureOr<StaticMultiSizeSpecification>
 111 mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
112  int64_t targetSize, int64_t divisor) {
113  assert(!op.hasDynamicShape() &&
114  "cannot compute static multi-tile sizes for an op with dynamic shape");
 115  assert(targetSize > 0 && "target size must be positive");
 116  assert(divisor > 0 && "divisor must be positive");
117  assert(dimension < op.getNumLoops() && "dimension overflow");
118 
 119  StaticMultiSizeSpecification spec;
 120  int64_t tripCount = op.getStaticLoopRanges()[dimension];
121  int64_t a = tripCount / divisor;
122  int64_t t = (targetSize + divisor - 1) / divisor;
123  int64_t totalTripCount = (a + t - 1) / t;
124  spec.lowTileSize = (a / totalTripCount) * divisor;
125  spec.highTileSize = spec.lowTileSize + divisor;
126  spec.highTripCount = a % totalTripCount;
127  spec.lowTripCount = totalTripCount - spec.highTripCount;
128  if (spec.lowTileSize * spec.lowTripCount +
129  spec.highTileSize * spec.highTripCount !=
130  tripCount) {
131  return failure();
132  }
133  return spec;
134 }
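A hypothetical worked example (numbers assumed): for tripCount = 15, targetSize = 8 and divisor = 1, the computation yields a = 15, t = 8, totalTripCount = ceil(15 / 8) = 2, lowTileSize = floor(15 / 2) * 1 = 7, highTileSize = 8, highTripCount = 15 mod 2 = 1 and lowTripCount = 1; the final check 7 * 1 + 8 * 1 == 15 holds, so the specification is returned.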
135 
 136 FailureOr<MultiSizeSpecification>
 137 mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
 138  unsigned dimension, OpFoldResult targetSize,
139  OpFoldResult divisor, bool emitAssertions) {
140  // Bail out on dimension overflow.
141  if (dimension >= op.getNumLoops())
142  return failure();
143 
144  // The code below works only on values.
145  Location loc = op.getLoc();
146  ImplicitLocOpBuilder b(loc, builder);
147  if (emitAssertions) {
148  emitIsPositiveIndexAssertion(b, targetSize);
149  emitIsPositiveIndexAssertion(b, divisor);
150  }
151  Value targetSizeValue =
152  getValueOrCreateConstantIndexOp(builder, loc, targetSize);
153  Value divisorValue = getValueOrCreateConstantIndexOp(builder, loc, divisor);
154 
155  // Find the trip count of the iteration space dimension for which the tile
156  // sizes are computed.
157  SmallVector<OpFoldResult> allShapes =
158  op.createFlatListOfOperandDims(b, b.getLoc());
159  AffineMap shapesToLoops = op.getShapesToLoopsMap();
160  SmallVector<OpFoldResult> loopRanges =
161  makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
162  allShapes);
163  Value tripCount =
164  getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
165 
166  // Compute the tile sizes and the respective numbers of tiles.
 167  AffineExpr s0 = b.getAffineSymbolExpr(0);
 168  AffineExpr s1 = b.getAffineSymbolExpr(1);
 169  AffineExpr s2 = b.getAffineSymbolExpr(2);
 170  auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
171  return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
172  };
173  Value a = apply(s0.floorDiv(s1), {tripCount, divisorValue});
174  Value t = apply((s0 + s1 - 1).floorDiv(s1), {targetSizeValue, divisorValue});
175  Value d = apply((s0 + s1 - 1).floorDiv(s1), {a, t});
176  Value s = apply(s0.floorDiv(s1) * s2, {a, d, divisorValue});
177  Value v = apply(s0 % s1, {a, d});
178  Value u = apply(s0 - s1, {d, v});
179 
 180  MultiSizeSpecification spec;
 181  spec.lowTileSize = s;
182  spec.highTileSize = apply(s0 + s1, {s, divisorValue});
183  spec.lowTripCount = u;
184  spec.highTripCount = v;
185 
186  // If requested, emit the check that the tile sizes are computed correctly.
 187  // For example, for an iteration dimension of size 15 and a divisor of 8, it
 188  // is impossible to find two tile sizes, both divisible by 8, that fully
 189  // cover the original iteration space.
190  if (emitAssertions) {
191  AffineExpr s3 = builder.getAffineSymbolExpr(3);
192  Value coveredSize =
193  apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount,
194  spec.highTileSize, spec.highTripCount});
195  Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq,
196  coveredSize, tripCount);
197  b.create<cf::AssertOp>(
198  equals, builder.getStringAttr(
199  "could not compute dynamic multi-size tile shapes"));
200  }
201 
202  return spec;
203 }
204 
 205 /// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
 206 /// less than `iterationSize`.
 207 static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
 208  OpFoldResult numThreads,
209  OpFoldResult iterationSize) {
210  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
211  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
212  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
213  if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
214  return false;
215  return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
216 }
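For instance (values assumed): with tileSize = 4, numThreads = 4 and iterationSize = 16, the maximum tile offset is 4 * (4 - 1) = 12 < 16, so the in-bounds check can be omitted; with iterationSize = 10 it cannot, and if any of the three operands is not a constant the function conservatively returns false.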
217 
218 /// Build an `affine_max` of all the `vals`.
 219 static OpFoldResult buildMax(OpBuilder &b, Location loc,
 220  ArrayRef<OpFoldResult> vals) {
 221  return affine::makeComposedFoldedAffineMax(
 222  b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
223  vals);
224 }
225 
226 /// Build an `affine_min` of all the `vals`.
 227 static OpFoldResult buildMin(OpBuilder &b, Location loc,
 228  ArrayRef<OpFoldResult> vals) {
 229  return affine::makeComposedFoldedAffineMin(
 230  b, loc, AffineMap::getMultiDimIdentityMap(vals.size(), loc.getContext()),
231  vals);
232 }
233 
234 /// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given
235 /// number of threads.
 236 static void calculateTileOffsetsAndSizes(
 237  RewriterBase &b, Location loc, scf::ForallOp forallOp,
238  ArrayRef<OpFoldResult> numThreads, SmallVector<Range> loopRanges,
239  bool omitTileOffsetBoundsCheck,
240  std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
241  SmallVector<OpFoldResult> &tiledOffsets,
242  SmallVector<OpFoldResult> &tiledSizes) {
 243  OpBuilder::InsertionGuard g(b);
 244  b.setInsertionPointToStart(forallOp.getBody(0));
245 
246  ValueRange threadIds = forallOp.getInductionVars();
247  SmallVector<OpFoldResult> nonZeroNumThreads =
248  llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
249  return !isConstantIntValue(ofr, 0);
250  }));
251  int64_t nLoops = loopRanges.size();
252  tiledOffsets.reserve(nLoops);
253  tiledSizes.reserve(nLoops);
254  for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
255  bool overflow = loopIdx >= numThreads.size();
256  bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
257  // Degenerate case: take the whole domain.
258  if (overflow || isZero) {
259  tiledOffsets.push_back(loopRanges[loopIdx].offset);
260  tiledSizes.push_back(loopRanges[loopIdx].size);
261  continue;
262  }
263 
264  // Tiled case: compute the offset and size.
265  AffineExpr i, j, m, n, o;
266  bindDims(b.getContext(), i, j);
267  bindSymbols(b.getContext(), m, n, o);
268  OpFoldResult size = loopRanges[loopIdx].size;
269  OpFoldResult offset = loopRanges[loopIdx].offset;
270  OpFoldResult threadId = threadIds[threadIdIdx];
271  // Symbolic fixed max size per thread.
272  // TODO: floor + 0/1 depending on case for better load-balancing.
273  OpFoldResult tileSizePerThread =
274  nominalTileSizes.has_value()
275  ? (*nominalTileSizes)[loopIdx]
 276  : makeComposedFoldedAffineApply(
 277  b, loc, m.ceilDiv(n),
278  ArrayRef<OpFoldResult>{size, nonZeroNumThreads[threadIdIdx]});
279 
280  // Dynamic offset shifted by threadId * maxSizePerThread.
 281  OpFoldResult offsetPerThread = makeComposedFoldedAffineApply(
 282  b, loc, i + j * m, {offset, threadId, tileSizePerThread});
283  // Dynamic upper-bound depending on the threadId.
284  OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
285  b, loc, i + j * m - n,
286  {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
287  if (!isConstantIntValue(residualTileSize, 0)) {
288  OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
289  b, loc, -i + m, {offsetPerThread, size});
290  tileSizePerThread =
291  buildMin(b, loc, {sizeMinusOffsetPerThread, tileSizePerThread});
292  }
293 
294  tiledOffsets.push_back(offsetPerThread);
295  // TODO: if tileSizePerThread <= 0 early exit.
296  if (!omitTileOffsetBoundsCheck &&
297  !canOmitTileOffsetInBoundsCheck(tileSizePerThread,
298  nonZeroNumThreads[threadIdIdx], size))
299  tileSizePerThread =
300  buildMax(b, loc, {b.getIndexAttr(0), tileSizePerThread});
301 
302  tiledSizes.push_back(tileSizePerThread);
303  ++threadIdIdx;
304  }
305 }
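A hypothetical walk-through of the per-loop computation (numbers assumed): for a loop with offset 0, size 10 and 4 threads, tileSizePerThread = ceil(10 / 4) = 3; thread 3 gets offsetPerThread = 0 + 3 * 3 = 9, and since the residual 0 + 4 * 3 - 10 = 2 is non-zero its size is clamped to min(10 - 9, 3) = 1. The additional max(0, ...) clamp is only emitted when it cannot be proven that no thread starts past the end of the iteration space.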
306 
 307 /// Returns a vector of bools representing if, for each axis, `op` can be tiled
 308 /// without incurring a race condition, and thus whether it is thread-safe to
 309 /// tile it. This is checked by iterating over `numThreads` and ensuring that
 310 /// the corresponding iterator type is "parallel". If it is not, then we know
 311 /// that such a dimension is unsafe to tile.
 312 SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
 313  ArrayRef<OpFoldResult> numThreads) {
314  auto iterators = linalgOp.getIteratorTypesArray();
315  SmallVector<bool> safeToTile(numThreads.size(), true);
316 
317  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
318  if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
319  if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
320  safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
321  }
322  } else {
323  safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
324  }
325  }
326  return safeToTile;
327 }
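For example (iterator types assumed): for a matmul-like op with iterator types [parallel, parallel, reduction], numThreads = [2, 4, 2] yields safeToTile = [true, true, false]; numThreads = [2, 4, 1] keeps every axis marked safe because a static thread count of 1 cannot race, whereas a dynamic thread count on the reduction axis is conservatively marked unsafe.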
328 
329 /// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
330 /// tiling is specified by the number of tiles/threads `numThreads` and the
 331 /// optional nominal tile size `nominalTileSizes`. If `nominalTileSizes` is
332 /// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
333 /// numThreads[i])`. If non-empty, the `mapping` is added as an
 334 /// attribute to the resulting `scf.forall`. A zero tile size indicates
 335 /// that the dimension is not tiled, and can be thought of as tiling by the
 336 /// full size of data.
337 /// It is the user's responsibility to ensure that `numThreads` is a valid
 338 /// tiling specification (i.e. that it only tiles parallel dimensions, e.g. in the
339 /// Linalg case). If the dimension is not parallelizable, a warning is issued to
340 /// notify the user that the generated code is not safe to parallelize. If
341 /// `omitTileOffsetBoundsCheck` is true, then the function will assume that
 342 /// `tileSize[i] * (numThreads[i] - 1) <= dimSize[i]` holds.
 343 static FailureOr<ForallTilingResult> tileToForallOpImpl(
 344  RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
345  std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
346  std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
347  Location loc = op->getLoc();
 348  OpBuilder::InsertionGuard g(b);
 349 
350  SmallVector<Range> loopRanges = op.getIterationDomain(b);
351  if (loopRanges.empty())
352  return op->emitOpError("expected non-empty loop ranges");
353  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
354  if (llvm::any_of(loopRanges, hasStrideOne))
355  return op->emitOpError("only stride-1 supported atm");
356 
357  // Gather destination tensors.
358  SmallVector<Value> dest;
359  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
360  return op->emitOpError("failed to get destination tensors");
361 
362  SmallVector<OpFoldResult> nonZeroNumThreads =
363  llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
364  return !isConstantIntValue(ofr, 0);
365  }));
366  SmallVector<Value> materializedNonZeroNumThreads =
367  llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
368  return getValueOrCreateConstantIndexOp(b, loc, ofr);
369  }));
370 
371  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
372  if (linalgOp) {
373  // Check if tiling is thread safe and print a warning if not.
374  SmallVector<bool> tilingSafety =
375  safeToTileToForall(b.getContext(), linalgOp, numThreads);
376  for (size_t i = 0; i < tilingSafety.size(); i++)
377  if (!tilingSafety[i])
378  op.emitWarning() << "tiling is not thread safe at axis #" << i;
379  }
380 
381  // 1. Create the ForallOp. We don't use the lambda body-builder
382  // version because we require the use of RewriterBase in the body, so we
383  // manually move the insertion point to the body below.
384  scf::ForallOp forallOp = b.create<scf::ForallOp>(
385  loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
386 
387  // 2. Fill out the ForallOp body.
388  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
389  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
390  omitTileOffsetBoundsCheck, nominalTileSizes,
391  tiledOffsets, tiledSizes);
392 
393  // 3. Clone the tileable op and update its destination operands to use the
394  // output bbArgs of the ForallOp.
395  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
396  Operation *tiledOp = nullptr;
397  SmallVector<Value> tiledValues;
398  {
399  // 3.a. RAII guard, inserting within forallOp, before terminator.
 400  OpBuilder::InsertionGuard g(b);
 401  b.setInsertionPoint(forallOp.getTerminator());
402  Operation *clonedOp = b.clone(*op.getOperation());
403  auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
404  if (destinationStyleOp) {
405  for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
406  // Swap tensor inits with the corresponding block argument of the
407  // scf.forall op. Memref inits remain as is.
408  if (isa<TensorType>(outOperand.get().getType())) {
409  auto *it = llvm::find(dest, outOperand.get());
410  assert(it != dest.end() && "could not find destination tensor");
411  unsigned destNum = std::distance(dest.begin(), it);
412  outOperand.set(destBbArgs[destNum]);
413  }
414  }
415  }
416 
417  // 4. Tile the cloned op and delete the clone.
418  FailureOr<TilingResult> tilingResult =
419  cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
420  tiledSizes);
421  if (failed(tilingResult))
422  return clonedOp->emitError("Failed to tile op: ");
423  if (tilingResult->tiledOps.size() != 1) {
424  return clonedOp->emitError("expected a single produced tiled op, got ")
425  << tilingResult->tiledOps.size();
426  }
427 
428  b.eraseOp(clonedOp);
429  tiledOp = tilingResult->tiledOps.front();
430  tiledValues = tilingResult->tiledValues;
431  }
432 
433  // 5. Parallel insert back into the result tensor.
434  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
435  tiledValues, destBbArgs)) {
436  // 5.a. Partial subset information is inserted just before the terminator.
 437  OpBuilder::InsertionGuard g(b);
 438  b.setInsertionPoint(forallOp.getTerminator());
439 
440  SmallVector<OpFoldResult> resultOffsets, resultSizes;
441  if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
442  tiledSizes, resultOffsets,
443  resultSizes)))
444  return op->emitOpError("output offsets couldn't be calculated");
445  SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
446 
447  // 5.b. Parallel insertions are inserted at the end of the combining
448  // terminator.
449  b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
450  b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
451  std::get<2>(it), resultOffsets,
452  resultSizes, strides);
453  }
454  return ForallTilingResult{forallOp, tiledOp};
455 }
456 
 457 FailureOr<ForallTilingResult>
 458 linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
459  ArrayRef<OpFoldResult> numThreads,
460  std::optional<ArrayAttr> mapping) {
461  return tileToForallOpImpl(b, op, numThreads,
462  /*nominalTileSizes=*/std::nullopt, mapping,
463  /*omitTileOffsetBoundsCheck=*/false);
464 }
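A minimal usage sketch (the caller, op and thread counts below are assumptions, not taken from this file): tiling the first two loops across an 8x4 grid of threads, leaving the third loop untiled, might look like:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Interfaces/TilingInterface.h"
using namespace mlir;

// Hypothetical caller, e.g. inside a pattern or transform implementation.
static LogicalResult tileExample(RewriterBase &rewriter, TilingInterface op) {
  // A zero entry means "do not tile this loop".
  SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(8),
                                          rewriter.getIndexAttr(4),
                                          rewriter.getIndexAttr(0)};
  FailureOr<linalg::ForallTilingResult> result = linalg::tileToForallOp(
      rewriter, op, numThreads, /*mapping=*/std::nullopt);
  if (failed(result))
    return failure();
  // The original op is not erased by the transform; replace it with the
  // results of the generated scf.forall.
  rewriter.replaceOp(op, result->tileOp->getResults());
  return success();
}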
465 
 466 FailureOr<ForallTilingResult>
 467 linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
 468  ArrayRef<OpFoldResult> tileSizes,
469  std::optional<ArrayAttr> mapping) {
470  SmallVector<Range> loopRanges = op.getIterationDomain(b);
471  unsigned nLoops = loopRanges.size();
472  SmallVector<OpFoldResult> numThreads;
473  numThreads.reserve(nLoops);
474  AffineExpr s0, s1;
475  bindSymbols(b.getContext(), s0, s1);
476  AffineExpr divExpr = s0.ceilDiv(s1);
477  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
478  OpFoldResult numTiles = std::get<0>(it);
479  if (!isConstantIntValue(numTiles, 0))
 480  numTiles = makeComposedFoldedAffineApply(
 481  b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
482  numThreads.push_back(numTiles);
483  }
484  return tileToForallOpImpl(b, op, numThreads,
485  /*nominalTileSizes=*/tileSizes, mapping,
486  /*omitTileOffsetBoundsCheck=*/true);
487 }
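For instance (numbers assumed): a loop of size 100 tiled with tile size 32 gives numThreads = ceilDiv(100, 32) = 4, while a tile size of 0 leaves its numThreads entry at 0, i.e. that loop is not tiled.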
488 
489 template <typename LoopTy>
 490 static FailureOr<TiledLinalgOp>
 491 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
 492  const LinalgTilingOptions &options) {
 493  OpBuilder::InsertionGuard g(b);
 494 
495  auto nLoops = op.getNumLoops();
 496  // Initial tile sizes may be too big; only take the first nLoops.
497  tileSizes = tileSizes.take_front(nLoops);
498 
499  if (llvm::all_of(tileSizes, [](OpFoldResult ofr) {
500  return getConstantIntValue(ofr) == static_cast<int64_t>(0);
501  })) {
502  TiledLinalgOp tiledOp;
503  tiledOp.op = cast<LinalgOp>(b.clone(*op.getOperation()));
504  tiledOp.tensorResults.assign(tiledOp.op->result_begin(),
505  tiledOp.op->result_end());
506  return tiledOp;
507  }
508 
509  // 1. Build the tiled loop ranges.
510  SmallVector<OpFoldResult> allShapeSizes =
511  op.createFlatListOfOperandDims(b, op.getLoc());
512  AffineMap shapeSizesToLoopsMap = op.getShapesToLoopsMap();
513  if (!shapeSizesToLoopsMap)
514  return failure();
515 
516  auto [loopRanges, loopIndexToRangeIndex] = makeTiledLoopRanges(
517  b, op.getLoc(), shapeSizesToLoopsMap, allShapeSizes, tileSizes);
518 
 519  SmallVector<utils::IteratorType, 4> iteratorTypes;
 520  for (const auto &attr : enumerate(op.getIteratorTypesArray())) {
521  if (loopIndexToRangeIndex.count(attr.index()))
522  iteratorTypes.push_back(attr.value());
523  }
524  // If interchangeVector is empty, use the identity. Build the permutation map
525  // otherwise.
526  auto invPermutationMap =
527  AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
528  if (!options.interchangeVector.empty()) {
529  // Based on the pruned iterations (due to zero tile size), recompute the
530  // interchange vector.
531  SmallVector<unsigned, 4> interchangeVector;
532  interchangeVector.reserve(options.interchangeVector.size());
533  for (auto pos : options.interchangeVector) {
534  auto it = loopIndexToRangeIndex.find(pos);
535  if (it == loopIndexToRangeIndex.end())
536  continue;
537  interchangeVector.push_back(it->second);
538  }
 539  // The interchange vector is guaranteed to be a permutation, so
 540  // `inversePermutation` must succeed.
541  invPermutationMap = inversePermutation(
542  AffineMap::getPermutationMap(interchangeVector, b.getContext()));
543  assert(invPermutationMap);
544  SmallVector<int64_t> permutation(interchangeVector.begin(),
545  interchangeVector.end());
546  applyPermutationToVector(loopRanges, permutation);
547  applyPermutationToVector(iteratorTypes, permutation);
548  }
549 
550  // Handle distribution. Create a vector of the same size of loops that are to
551  // be tiled.
 552  SmallVector<linalg::ProcInfo> procInfo;
 553  if (options.distribution) {
554  procInfo.resize(
555  iteratorTypes.size(),
556  linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None});
 557  // Collect the loop ranges of the tiled loops that are parallel.
558  SmallVector<Range> parallelLoopRanges;
559  for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
560  if (!isParallelIterator(iteratorType.value()))
561  break;
562  parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
563  }
564  auto returnedProcInfo =
565  options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges);
566  unsigned procIdIdx = 0;
567  // Update the distribution information for the loops.
568  for (const auto &iteratorType : llvm::enumerate(iteratorTypes)) {
569  if (!isParallelIterator(iteratorType.value()))
570  break;
571  procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++];
572  }
573  }
574 
575  // 2. Create the tiled loops.
576  LinalgOp res = op;
577  SmallVector<Value, 4> ivs, tensorResults;
578  auto tiledLoopBodyBuilder =
579  [&](OpBuilder &builder, Location loc, ValueRange localIvs,
580  ValueRange operandValuesToUse) -> scf::ValueVector {
581  ivs.assign(localIvs.begin(), localIvs.end());
582 
583  // When an `interchangeVector` is present, it has been applied to the
584  // loop ranges and the iterator types. Apply its inverse to the
585  // resulting loop `ivs` to match the op definition.
586  SmallVector<Value, 4> interchangedIvs;
587  if (!options.interchangeVector.empty()) {
588  for (AffineExpr result : invPermutationMap.getResults())
589  interchangedIvs.push_back(
590  ivs[cast<AffineDimExpr>(result).getPosition()]);
591  } else {
592  interchangedIvs.assign(ivs.begin(), ivs.end());
593  }
594 
595  // Tile the `operandValuesToUse` that either match the `op` operands
596  // themselves or the tile loop arguments forwarding them.
597  assert(operandValuesToUse.size() ==
598  static_cast<size_t>(op->getNumOperands()) &&
599  "expect the number of operands and inputs and outputs to match");
600  SmallVector<Value> valuesToTile = operandValuesToUse;
601  SmallVector<OpFoldResult> sizeBounds =
602  makeComposedFoldedMultiResultAffineApply(b, loc, shapeSizesToLoopsMap,
603  allShapeSizes);
604  SmallVector<Value> tiledOperands = makeTiledShapes(
605  b, loc, op, valuesToTile, getAsOpFoldResult(interchangedIvs), tileSizes,
606  sizeBounds,
607  /*omitPartialTileCheck=*/false);
608 
609  SmallVector<Type> resultTensorTypes =
610  getTensorOutputTypes(op, tiledOperands);
611  res = clone(b, op, resultTensorTypes, tiledOperands);
612  tensorResults =
613  insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
614  return scf::ValueVector(tensorResults.begin(), tensorResults.end());
615  };
616  GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
617  tiledLoopBodyBuilder, procInfo);
618 
619  // 3. Transform IndexOp results w.r.t. the tiling.
620  transformIndexOps(b, res, ivs, loopIndexToRangeIndex);
621 
622  // 4. Gather the newly created loops and return them with the new op.
 623  SmallVector<Operation *, 8> loops;
 624  loops.reserve(ivs.size());
625  for (auto iv : ivs) {
626  if (isa<BlockArgument>(iv)) {
627  loops.push_back(cast<BlockArgument>(iv).getOwner()->getParentOp());
628  assert(loops.back() && "no owner found for induction variable!");
629  } else {
630  // TODO: Instead of doing this, try to recover the ops used instead of the
631  // loop.
632  loops.push_back(nullptr);
633  }
634  }
635 
636  // 5. Get the tensor results from the outermost loop if available. Otherwise
637  // use the previously captured `tensorResults`.
638  Operation *outermostLoop = nullptr;
639  for (Operation *loop : loops)
640  if ((outermostLoop = loop))
641  break;
642 
643  return TiledLinalgOp{
644  res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
645 }
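As a worked illustration of the interchange pruning in the function above (values assumed): with tile sizes [5, 0, 3] only loops 0 and 2 survive, `loopIndexToRangeIndex` is {0 -> 0, 2 -> 1}, and an `interchangeVector` of [2, 0, 1] is rewritten to [1, 0] before being inverted and applied to the pruned loop ranges and iterator types.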
646 
 647 FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
 648  RewriterBase &b, PartialReductionOpInterface op,
649  ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
650  std::optional<ArrayAttr> mapping) {
651  Location loc = op.getLoc();
 652  OpBuilder::InsertionGuard g(b);
 653 
654  // Ops implementing PartialReductionOpInterface are expected to implement
655  // TilingInterface.
656  // TODO: proper core mechanism to tie interfaces together.
657  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
658 
 659  // Ops implementing PartialReductionOpInterface are not necessarily expected
 660  // to implement DestinationStyleOpInterface, so bail out if the cast fails.
 661  // TODO: proper core mechanism to tie interfaces together.
 662  // TODO: this function requires a pair of interfaces.
663  auto destinationStyleOp =
664  dyn_cast<DestinationStyleOpInterface>(op.getOperation());
665  if (!destinationStyleOp)
666  return b.notifyMatchFailure(op, "not a destination style op");
667 
 668  // Actually this only works for Linalg ops atm.
669  auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
670  if (!linalgOp)
671  return b.notifyMatchFailure(op, "not a linalg op");
672 
673  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
674  if (op->getNumResults() != 1)
675  return b.notifyMatchFailure(
676  op, "don't support ops with multiple results for now");
677 
 678  SmallVector<utils::IteratorType> iteratorTypes =
 679  tilingInterfaceOp.getLoopIteratorTypes();
680  SmallVector<unsigned> redDims;
681  linalgOp.getReductionDims(redDims);
682  if (redDims.size() != 1)
683  return b.notifyMatchFailure(
684  op, "only support ops with one reduction dimension.");
685  if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
686  return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
687  "many elements as number of threads");
688  int reductionDim = static_cast<int>(redDims.front());
689 
690  if (redDims.front() >= numThreads.size())
691  return b.notifyMatchFailure(
692  op, "reduction dimension must be mapped to threads");
693 
 694  // 1. Create the initial tensor value.
695  FailureOr<Operation *> identityTensor =
696  op.generateInitialTensorForPartialReduction(b, loc, numThreads,
697  reductionDim);
698  if (failed(identityTensor))
699  return b.notifyMatchFailure(op,
700  "cannot create a tensor of identity value.");
701 
702  // Gather destination tensors.
703  SmallVector<Value> dest;
704  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
705  return b.notifyMatchFailure(op, "failed to get destination tensors");
706 
707  Operation *tiledOp = nullptr;
708 
709  SmallVector<OpFoldResult> nonZeroNumThreads =
710  llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
711  return !isConstantIntValue(ofr, 0);
712  }));
713  SmallVector<Value> materializedNonZeroNumThreads =
714  getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
715 
716  // 2. Create the ForallOp with an empty region.
717  scf::ForallOp forallOp = b.create<scf::ForallOp>(
718  loc, getAsOpFoldResult(materializedNonZeroNumThreads),
719  (*identityTensor)->getResults(), mapping);
720 
721  // 3. Calculate the tile offsets and sizes for the subsequent loop that will
722  // be nested under `forallOp`.
723  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
724  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, iterationDomain,
725  /*omitTileOffsetBoundsCheck =*/false,
726  /*nominalTileSizes=*/std::nullopt, tiledOffsets,
727  tiledSizes);
728 
729  // 4. Clone the tileable op and update its destination operands to use the
730  // output bbArgs of the ForallOp.
731  SmallVector<Value> tilingResults;
732  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
733  {
734  // 4.a. RAII guard, inserting within forallOp, before terminator.
 735  OpBuilder::InsertionGuard g(b);
 736  b.setInsertionPoint(forallOp.getTerminator());
737 
738  SmallVector<Value> tiledDpsInitOperands;
739  for (Value initOperand : destinationStyleOp.getDpsInits()) {
740  auto *it = llvm::find(dest, initOperand);
741  assert(it != dest.end() && "dest operand not found in dest");
742  unsigned destNum = std::distance(dest.begin(), it);
743  SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
744  SmallVector<OpFoldResult> outOffsets(numThreads.size(),
745  b.getIndexAttr(0));
746  SmallVector<OpFoldResult> sizes = tiledSizes;
747  sizes[reductionDim] = b.getIndexAttr(1);
748  outOffsets[reductionDim] = forallOp.getInductionVars().front();
749  // TODO: use SubsetExtractOpInterface once it is available.
750  tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
751  loc, cast<RankedTensorType>(initOperand.getType()),
752  destBbArgs[destNum], outOffsets, sizes, strides));
753  }
754 
755  // 4.b. Clone the op and update init operands.
 756  // We cannot use an IRMapping here because it can replace
757  // different OpOperands with the same value.
758  Operation *clonedOp = b.clone(*op.getOperation());
759  b.modifyOpInPlace(clonedOp, [&]() {
760  for (auto [initOperandPtr, tiledInitValue] : llvm::zip_equal(
761  cast<DestinationStyleOpInterface>(clonedOp).getDpsInitsMutable(),
762  tiledDpsInitOperands)) {
763  initOperandPtr.set(tiledInitValue);
764  }
765  });
766 
767  // 5. Tile the cloned op and delete the clone.
768  if (tileSizes.empty()) {
769  FailureOr<TilingResult> tilingResult =
770  cast<TilingInterface>(clonedOp).getTiledImplementation(
771  b, tiledOffsets, tiledSizes);
772  if (failed(tilingResult))
773  return clonedOp->emitError("Failed to tile op: ");
774  if (tilingResult->tiledOps.size() != 1) {
775  return clonedOp->emitError("expected a single produced tiled op, got ")
776  << tilingResult->tiledOps.size();
777  }
778  tiledOp = tilingResult->tiledOps.front();
779  tilingResults = tilingResult->tiledValues;
780  } else {
 781  LinalgTilingOptions options;
 782  FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
783  b, cast<LinalgOp>(clonedOp), tileSizes, options);
784  if (failed(maybeTiled))
785  return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
786 
787  SmallVector<Value> ids = forallOp.getInductionVars();
788  mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
789  materializedNonZeroNumThreads);
790  if (maybeTiled->loops.size() != 1) {
791  return clonedOp->emitError("expected a single produced loop");
792  }
793  tiledOp = maybeTiled->op;
794  tilingResults = maybeTiled->loops.front()->getResults();
795  }
796 
797  b.eraseOp(clonedOp);
798  }
799 
800  // 6. Insert the partial reductions back into a new tensor.
801  for (auto [index, result, bbArg] : llvm::zip(
802  llvm::seq<unsigned>(0, dest.size()), tilingResults, destBbArgs)) {
803  // 6.a. Partial subset information is inserted just before the terminator.
 804  OpBuilder::InsertionGuard g(b);
 805  b.setInsertionPoint(forallOp.getTerminator());
806 
807  SmallVector<OpFoldResult> resultOffsets, resultSizes;
808  if (failed(tilingInterfaceOp.getResultTilePosition(
809  b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
810  return op->emitOpError("output offsets couldn't be calculated");
811  SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
812  int64_t offIdx = 0;
813  int64_t sizeIdx = 0;
814  for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
815  if (i == reductionDim) {
816  resultOffsetsRank.push_back(forallOp.getInductionVars().front());
817  resultSizesRank.push_back(b.getIndexAttr(1));
818  continue;
819  }
820  resultOffsetsRank.push_back(resultOffsets[offIdx++]);
821  resultSizesRank.push_back(resultSizes[sizeIdx++]);
822  }
823  SmallVector<OpFoldResult> strides(resultSizesRank.size(),
824  b.getIndexAttr(1));
825 
826  // 6.b. Parallel insertions are inserted at the end of the combining
827  // terminator.
828  b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
829  b.create<tensor::ParallelInsertSliceOp>(
830  loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
831  }
832 
833  // 7. Merge the partial reductions.
834  b.setInsertionPointAfter(forallOp);
835  Operation *mergeOp =
836  op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
837  b.replaceOp(op, mergeOp->getResults());
838 
839  // 8. Return.
 840  ForallReductionTilingResult results;
 841  results.initialOp = *identityTensor;
842  results.loops = forallOp;
843  results.parallelTiledOp = tiledOp;
844  results.mergeOp = mergeOp;
845  return results;
846 }
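A hypothetical end-to-end picture (shapes assumed): reducing a tensor<64xf32> to a scalar with numThreads = [8] and no tile sizes creates a tensor<8xf32> filled with the identity value, an scf.forall over 8 threads in which each thread reduces its own 8-element slice into its private element of that tensor, and, after the forall, a mergeReductions step that combines the 8 partial results into the final value.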
847 
848 template <typename LoopTy>
 849 static FailureOr<TiledLinalgOp> tileLinalgOpImpl(
 850  RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
 851  OpBuilder::InsertionGuard g(b);
 852  b.setInsertionPoint(op);
853 
854  if (!options.tileSizeComputationFunction)
855  return failure();
856 
857  // Enforce the convention that "tiling by zero" skips tiling a particular
 858  // dimension. This convention is significantly simpler to handle than
 859  // adjusting affine maps to account for missing dimensions.
860  auto nLoops = op.getNumLoops();
861  SmallVector<OpFoldResult> tileSizeVector =
862  getAsOpFoldResult(options.tileSizeComputationFunction(b, op));
863  if (tileSizeVector.size() < nLoops) {
864  tileSizeVector.append(nLoops - tileSizeVector.size(), b.getIndexAttr(0));
865  }
866 
867  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
868 }
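For example (sizes assumed): for an op with three loops and a tileSizeComputationFunction returning {8}, the vector is padded to {8, 0, 0}, so only the outermost loop is tiled and the other two keep their full extent.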
869 
 870 FailureOr<TiledLinalgOp>
 871 mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
 872  const LinalgTilingOptions &options) {
873  switch (options.loopType) {
 874  case LinalgTilingLoopType::Loops:
 875  return tileLinalgOpImpl<scf::ForOp>(b, op, options);
876  case LinalgTilingLoopType::ParallelLoops:
877  return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
878  default:;
879  }
880  return failure();
881 }
882 
883 namespace {
884 /// Helper classes for type list expansion.
885 template <typename... OpTypes>
886 class CanonicalizationPatternList;
887 
888 template <>
889 class CanonicalizationPatternList<> {
890 public:
891  static void insert(RewritePatternSet &patterns) {}
892 };
893 
894 template <typename OpTy, typename... OpTypes>
895 class CanonicalizationPatternList<OpTy, OpTypes...> {
896 public:
897  static void insert(RewritePatternSet &patterns) {
898  OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
899  CanonicalizationPatternList<OpTypes...>::insert(patterns);
900  }
901 };
902 } // namespace
903 
 904 RewritePatternSet
 905 mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
 906  RewritePatternSet patterns(ctx);
 907  populateLinalgTilingCanonicalizationPatterns(patterns);
 908  return patterns;
909 }
910 
 911 void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
 912  RewritePatternSet &patterns) {
913  auto *ctx = patterns.getContext();
914  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
915  affine::AffineForOp::getCanonicalizationPatterns(patterns, ctx);
916  affine::AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
917  affine::AffineMaxOp::getCanonicalizationPatterns(patterns, ctx);
918  arith::ConstantIndexOp::getCanonicalizationPatterns(patterns, ctx);
919 
920  memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
921  memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
922 
923  scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
924  scf::ParallelOp::getCanonicalizationPatterns(patterns, ctx);
925 
926  tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
927  tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx);
928  tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx);
929  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
930  tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
931  ctx->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns(patterns);
932 
933  CanonicalizationPatternList<
934 #define GET_OP_LIST
935 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
936  >::insert(patterns);
937 }
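A minimal sketch of how these canonicalization patterns might be applied after tiling (the surrounding helper function is hypothetical, not part of this file):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

// Greedily apply the post-tiling canonicalizations to `root`
// (typically a function or module).
static LogicalResult canonicalizeAfterTiling(Operation *root) {
  RewritePatternSet patterns =
      linalg::getLinalgTilingCanonicalizationPatterns(root->getContext());
  // Alternatively, extend an existing pattern set with
  // linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}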