1 //===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Matchers.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Support/LLVM.h"
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <type_traits>
38 #include <utility>
39 
40 #define DEBUG_TYPE "linalg-transforms"
41 
42 using namespace mlir;
43 using namespace mlir::linalg;
44 
45 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
46 #define DBGSNL() (llvm::dbgs() << "\n")
47 
48 //===----------------------------------------------------------------------===//
49 // Transformations exposed as functional-style API calls.
50 //===----------------------------------------------------------------------===//
51 
52 //===----------------------------------------------------------------------===//
53 // peelLoop transformation.
54 //===----------------------------------------------------------------------===//
55 
56 /// Try to peel and canonicalize loop `op` and return the new result.
57 /// Also applies affine_min/max bounds simplification on the fly where relevant.
58 // TODO: Add support for scf.parallel and affine.for loops.
59 SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
60  Operation *op) {
61  return llvm::TypeSwitch<Operation *, SmallVector<Value>>(op)
62  .Case<scf::ForOp>([&](scf::ForOp forOp) {
63  scf::ForOp partialIteration;
64  if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
65  partialIteration)))
66  return partialIteration->getResults();
67  assert(!partialIteration && "expected that loop was not peeled");
68  return forOp->getResults();
69  })
70  .Default([&](Operation *op) { return op->getResults(); });
71 }
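
// Illustrative sketch (not part of the upstream file): peeling an scf.for
// whose trip count is not a multiple of the step splits off the last partial
// iteration, e.g. roughly:
//
//   scf.for %iv = %c0 to %c10 step %c4 { ... }
//
// becomes
//
//   scf.for %iv = %c0 to %c8 step %c4 { ... }   // main loop, full iterations
//   scf.for %iv = %c8 to %c10 step %c4 { ... }  // peeled partial iteration
//
// with affine.min/affine.max bounds inside the main loop simplified
// accordingly.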
72 
73 /// Peel `loops` and apply affine_min/max bounds simplification on the fly
74 /// where relevant.
75 void mlir::linalg::peelLoops(RewriterBase &rewriter,
76  ArrayRef<scf::ForOp> loops) {
77  for (auto loopOp : loops)
78  peelLoop(rewriter, loopOp);
79 }
80 
81 //===----------------------------------------------------------------------===//
82 // pack transformation.
83 //===----------------------------------------------------------------------===//
84 
85 #ifndef NDEBUG
86 /// Return true if `map` has 0 or 1 result function of AffineDimExpr(dim).
87 static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
88  bool found = false;
89  for (AffineExpr e : map.getResults()) {
90  if (!e.isFunctionOfDim(dim))
91  continue;
92  if (found)
93  return false;
94  found = true;
95  }
96  return true;
97 }
98 #endif // NDEBUG
99 
100 /// Return the index of the first result of `map` that is a function of
101 /// AffineDimExpr(dim), or std::nullopt if there is none.
102 static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
103  int64_t dim) {
104  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
105  AffineExpr expr = map.getResult(i);
106  if (!expr.isFunctionOfDim(dim))
107  continue;
108  return i;
109  }
110  return std::nullopt;
111 }
112 
113 /// Perform one step of packing of a LinalgOp's metadata along `dim` into the
114 /// `newDim` at `iteratorTypes.size()` by:
115 /// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
116 /// 2. Appending a `newDim` to the domain of every indexing map.
117 /// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
118 /// by potentially adding a `newDim` result to `map`.
119 /// The preserved invariant is that `iteratorTypes.size()` is always equal to
120 /// `map.getNumDims()` for every map in `indexingMaps`.
121 ///
122 /// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
123 /// Return a vector that records the optional packing for each operand.
124 /// Return failure if the packed indexing cannot be represented with a LinalgOp.
125 ///
126 /// Further details:
127 /// ================
128 /// The current implementation of packing (i.e. data tiling) consists of
129 /// rewriting a linearized strip-mined form into a higher-dimensional access.
130 /// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
131 /// `I` into `4 * i + ii`, where `0 <= ii < 4`.
132 /// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
133 ///
134 /// This rewrite into higher dimensional access is not possible for general
135 /// AffineExpr in Linalg atm; it is restricted to an AffineDimExpr:
136 /// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
137 /// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
138 /// The rewrite of the access would be a form not representable in Linalg:
139 /// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
140 /// Note however that as `J` and `ii` iterate, the accesses do not have a
141 /// particular alignment, so packing does not achieve alignment in this case.
142 ///
143 /// In the future, we may want to consider a mixed-form that allows some
144 /// alignment in the presence of multiple accesses:
145 /// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
146 /// And would rewrite accesses as:
147 /// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
148 static FailureOr<SmallVector<std::optional<int64_t>>>
149 packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
150  SmallVectorImpl<utils::IteratorType> &iteratorTypes,
151  int64_t dim) {
152  int64_t newDim = iteratorTypes.size();
153  iteratorTypes.push_back(iteratorTypes[dim]);
154 
155  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
156  indexingMaps.size(), std::nullopt);
157  SmallVector<AffineMap> newMaps;
158  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
159  ++operandIdx) {
160  AffineMap map = indexingMaps[operandIdx];
161 
162  // Add the `newDim` to the map in all cases.
163  assert(map.getNumDims() == newDim && "num dims invariant violation");
164  map = map.shiftDims(1, newDim);
165 
166  // Get the at-most-1 index of the result that is a function of `dim`.
167  // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
168  // logically chunks dimension `dim` into `K * dim + newDim`, where the
169  // packing factor `K` is specified separately.
170  assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
171  "num results invariant violation");
172  auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
173  if (!maybeOperandDimensionToPack.has_value()) {
174  newMaps.push_back(map);
175  continue;
176  }
177 
178  // We can only pack AffineDimExpr atm.
179  if (!map.getResult(maybeOperandDimensionToPack.value())
180  .isa<AffineDimExpr>())
181  return failure();
182 
183  // Add `newDim` to the results of the map.
184  map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
185  map.getNumResults());
186  newMaps.push_back(map);
187 
188  // Record that operand `operandIdx` is packed.
189  packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
190  }
191  indexingMaps = newMaps;
192 
193  return packedDimPerIndexingMap;
194 }
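
// Worked example (sketch, not part of the upstream file): packing the
// reduction dimension d2 of a matmul-like op with indexing maps
//   (d0, d2), (d2, d1), (d0, d1)
// appends a new domain dimension d3 and rewrites the maps as
//   (d0, d2, d3), (d2, d1, d3), (d0, d1)
// i.e. both operands indexed by d2 gain a trailing packed dimension while the
// result map is unchanged; packedDimPerIndexingMap records
// {1, 0, std::nullopt}, the position of the result that was a function of d2
// for each operand.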
195 
196 namespace {
197 
198 /// Helper struct to encode packing along one dimension of a LinalgOp.
199 struct PackedOperandsDim {
200  OpFoldResult packedSize;
201  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
202 };
203 
204 /// Helper struct to encode packing along all dimensions of a LinalgOp.
205 struct PackedOperandsDimList {
206  void push_back(PackedOperandsDim &&packedOperandsDims) {
207  spec.emplace_back(packedOperandsDims);
208  }
209  /// Return all the dims that have been packed for operand @ `operandPos`.
210  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
211  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
212  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
213 
214 private:
215  SmallVector<PackedOperandsDim> spec;
216 };
217 
218 } // namespace
219 
220 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
221  tensor::PackOp packOp) {
222  // 1. Filter out NYI cases.
223  auto packedTensorType =
224  cast<RankedTensorType>(packOp->getResultTypes().front());
225  if (llvm::any_of(packOp.getStaticInnerTiles(),
226  [](int64_t size) { return ShapedType::isDynamic(size); })) {
227  return rewriter.notifyMatchFailure(
228  packOp,
229  "non-static shape NYI, needs a more powerful tensor.expand_shape op");
230  }
231 
232  Location loc = packOp->getLoc();
233  OpBuilder::InsertionGuard g(rewriter);
234  rewriter.setInsertionPoint(packOp);
235 
236  // 2. Compute the permutation vector to shuffle packed shape into the shape
237  // before any outer or inner permutations have been applied. The permutation
238  // can be obtained from two permutations:
239  // a) Compute the permutation vector to move the last `numPackedDims` into
240  // the `innerPosDims` of a shape of rank `packedRank`.
241  // b) Compute the permutation vector to move outer dims if the pack op
242  // has outer_dims_perm.
243  // Apply (b) permutation on (a) permutation to get the final permutation.
244  int64_t numPackedDims = packOp.getInnerDimsPos().size();
245  int64_t packedRank = packedTensorType.getRank();
246  auto lastDims = llvm::to_vector(
247  llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
248  PackingMetadata packingMetadata = computePackingMetadata(
249  packedTensorType.getRank(), packOp.getInnerDimsPos());
250  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
251  packedRank, lastDims, packingMetadata.insertPositions);
252 
253  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
254  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
255  if (!outerPerm.empty())
256  applyPermutationToVector(outerPos, outerPerm);
257  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
258  packedRank, packingMetadata.outerPositions, outerPos);
259 
260  SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
261  applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
262 
263  // 3. Compute the stripMinedShape: this is the packed shape before any outer
264  // or inner permutations have been applied.
265  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
266  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
267 
268  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
269  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
270  rewriter.getIndexAttr(0));
271  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
272  rewriter.getIndexAttr(0));
273  for (auto [pos, innerSize] :
274  llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
275  int outerPos =
276  packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
277  OpFoldResult origSize =
278  tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
279  OpFoldResult outerSize =
280  tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
281  AffineExpr s0, d0, d1;
282  bindDims(rewriter.getContext(), d0, d1);
283  bindSymbols(rewriter.getContext(), s0);
284  auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
285  highs[pos] = affine::makeComposedFoldedAffineApply(
286  rewriter, loc, map, {outerSize, origSize, innerSize});
287  }
288  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
289  RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
290  packingMetadata.reassociations);
291  Value paddingValue = packOp.getPaddingValue();
292  if (!paddingValue) {
293  paddingValue = rewriter.create<arith::ConstantOp>(
294  loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
295  }
296  auto padOp =
297  rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
298  highs, paddingValue, /*nofold=*/false);
299 
300  LLVM_DEBUG(
301  DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
302  DBGS() << "insertPositions: ");
303  DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
304  DBGS() << "outerPositions: ");
305  DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
306  DBGS() << "packedShape: ");
307  DBGSNL();
308  llvm::interleaveComma(outerPositionPerm, DBGS() << "outerPositionPerm: ");
309  DBGSNL(); llvm::interleaveComma(innerPositionsPerm,
310  DBGS() << "innerPositionsPerm: ");
311  DBGSNL();
312  llvm::interleaveComma(packedToStripMinedShapePerm,
313  DBGS() << "packedToStripMinedShapePerm: ");
314  DBGSNL(); llvm::interleaveComma(
315  packingMetadata.reassociations, DBGS() << "reassociations: ",
316  [&](ReassociationIndices ri) {
317  llvm::interleaveComma(ri, llvm::dbgs() << "|");
318  });
319  DBGSNL();
320  llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
321  DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
322 
323  if (packOp.isLikePad()) {
324  // Pack ops which operate as simple pads may not produce legal
325  // tensor.insert_slice operations when the packed type does not rank reduce
326  // to the padded type.
327  SliceVerificationResult rankReduces =
328  isRankReducedType(packedTensorType, padOp.getResultType());
329 
330  if (rankReduces == SliceVerificationResult::Success) {
331  // This pack is just a plain pad.
332  // Just insert the pad in the higher ranked tensor.
333  auto emptyOp =
334  rewriter.create<tensor::EmptyOp>(loc, packedTensorType, ValueRange{});
335  // Offsets.
336  SmallVector<OpFoldResult> zeros(packedRank, rewriter.getIndexAttr(0));
337  // Strides.
338  SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
339  SmallVector<OpFoldResult> sizes =
340  tensor::getMixedSizes(rewriter, loc, packOp.getDest());
341 
342  auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
343  loc, /*source=*/padOp, /*dest=*/emptyOp,
344  /*offsets=*/zeros, sizes,
345  /*strides=*/ones);
346 
347  LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
348 
349  rewriter.replaceOp(packOp, insertSliceOp->getResults());
350 
351  return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
352  /*transposeOp=*/nullptr};
353  }
354  }
355  // 5. Expand from the padded result to the stripMinedShape.
356  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
357  loc,
358  RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
359  padOp.getResult(), packingMetadata.reassociations);
360 
361  // 6. Transpose stripMinedShape to packedShape.
362  SmallVector<int64_t> transpPerm =
363  invertPermutationVector(packedToStripMinedShapePerm);
364  auto transposeOp = rewriter.create<linalg::TransposeOp>(
365  loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
366 
367  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
368  DBGS() << "reshape op: " << reshapeOp; DBGSNL();
369  llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
370  DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
371 
372  // 7. Replace packOp by transposeOp.
373  rewriter.replaceOp(packOp, transposeOp->getResults());
374 
375  return LowerPackResult{padOp, reshapeOp, transposeOp};
376 }
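
// Rough sketch of the lowering (illustrative only, assuming static sizes):
//
//   %packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %dest : tensor<129x47xf32> -> tensor<17x3x8x16xf32>
//
// becomes approximately:
//
//   %padded = tensor.pad %src low[0, 0] high[7, 1] { ... }
//       : tensor<129x47xf32> to tensor<136x48xf32>
//   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
//       : tensor<136x48xf32> into tensor<17x8x3x16xf32>
//   %transposed = linalg.transpose ins(%expanded) outs(%dest)
//       permutation = [0, 2, 1, 3]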
377 
378 FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
379  tensor::UnPackOp unPackOp) {
380  // 1. Filter out NYI cases.
381  if (!unPackOp.getOuterDimsPerm().empty())
382  return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
383 
384  RankedTensorType packedTensorType = unPackOp.getSourceType();
385  if (!packedTensorType.hasStaticShape()) {
386  return rewriter.notifyMatchFailure(
387  unPackOp,
388  "non-static shape NYI, needs a more powerful tensor.expand_shape op");
389  }
390 
391  Location loc = unPackOp->getLoc();
392  OpBuilder::InsertionGuard g(rewriter);
393  rewriter.setInsertionPoint(unPackOp);
394 
395  int64_t packedRank = packedTensorType.getRank();
396 
397  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
398  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
399  if (unPackOp.isLikeUnPad()) {
400  // This unpack is just a plain unpad.
401  // Just extract the slice from the higher ranked tensor.
402  ArrayRef<int64_t> destShape = destTensorType.getShape();
403  // The inner dimensions stay the same as the destination tensor, but the
404  // outer ones are additional 1s.
405  SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
406  sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
407 
408  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
409  loc, destTensorType, unPackOp.getSource(),
410  SmallVector<OpFoldResult>(packedRank, zero), sizes,
411  SmallVector<OpFoldResult>(packedRank, one));
412 
413  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
414 
415  return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
416  /*reshapeOp=*/nullptr, extractSliceOp};
417  }
418  // 2. Compute the permutation vector to move the last `numPackedDims` into
419  // the `innerPosDims` of a shape of rank `packedRank`.
420  int64_t numPackedDims = unPackOp.getInnerDimsPos().size();
421  auto lastDims = llvm::to_vector(
422  llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
423  PackingMetadata packingMetadata =
424  computePackingMetadata(packedRank, unPackOp.getInnerDimsPos());
425  SmallVector<int64_t> lastDimsToInsertPositionsPerm = computePermutationVector(
426  packedRank, lastDims, packingMetadata.insertPositions);
427 
428  // 3. Compute the stripMinedShape: this is the packed shape without outer and
429  // inner permutations.
430  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
431  applyPermutationToVector(stripMinedShape, lastDimsToInsertPositionsPerm);
432 
433  // 4. Transpose packedShape to stripMinedShape.
434  RankedTensorType stripMinedTensorType =
435  RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
436  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
437  stripMinedTensorType, packingMetadata.reassociations);
438  auto emptyOp =
439  rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
440  auto transposeOp = rewriter.create<linalg::TransposeOp>(
441  loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
442 
443  LLVM_DEBUG(
444  DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
445  DBGS() << "insertPositions: ");
446  DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
447  DBGS() << "packedShape: ");
448  DBGSNL();
449  llvm::interleaveComma(lastDimsToInsertPositionsPerm,
450  DBGS() << "lastDimsToInsertPositionsPerm: ");
451  DBGSNL(); llvm::interleaveComma(
452  packingMetadata.reassociations, DBGS() << "reassociations: ",
453  [&](ReassociationIndices ri) {
454  llvm::interleaveComma(ri, llvm::dbgs() << "|");
455  });
456  DBGSNL();
457  llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
458  DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
459 
460  // 5. Collapse from the stripMinedShape to the padded result.
461  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
462  loc, collapsedType, transposeOp->getResult(0),
463  packingMetadata.reassociations);
464 
465  // 6. ExtractSlice.
466  int64_t destRank = destTensorType.getRank();
467  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
468  loc, destTensorType, reshapeOp->getResult(0),
469  SmallVector<OpFoldResult>(destRank, zero),
470  tensor::getMixedSizes(rewriter, loc, unPackOp->getResult(0)),
471  SmallVector<OpFoldResult>(destRank, one));
472 
473  // 7. Inject a copy to preserve DPS.
474  auto copyOp = rewriter.create<linalg::CopyOp>(
475  loc, extractSliceOp->getResult(0), unPackOp.getDest());
476 
477  // 8. Replace unPackOp by extractSliceOp.
478  rewriter.replaceOp(unPackOp, copyOp->getResults());
479 
480  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
481 }
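
// Rough sketch of the lowering (illustrative only, assuming static sizes and
// no outer_dims_perm):
//
//   %res = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %dest : tensor<17x3x8x16xf32> -> tensor<129x47xf32>
//
// becomes approximately:
//
//   %empty = tensor.empty() : tensor<17x8x3x16xf32>
//   %transposed = linalg.transpose ins(%src) outs(%empty)
//       permutation = [0, 2, 1, 3]
//   %collapsed = tensor.collapse_shape %transposed [[0, 1], [2, 3]]
//       : tensor<17x8x3x16xf32> into tensor<136x48xf32>
//   %slice = tensor.extract_slice %collapsed[0, 0] [129, 47] [1, 1]
//   %copy = linalg.copy ins(%slice) outs(%dest)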
482 
483 SmallVector<int64_t>
484 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
485  SmallVector<int64_t> res;
486  for (int64_t i = 0, e = spec.size(); i < e; ++i) {
487  if (!spec[i].packedDimForEachOperand[operandPos].has_value())
488  continue;
489  res.push_back(spec[i].packedDimForEachOperand[operandPos].value());
490  }
491  return res;
492 }
493 
494 SmallVector<OpFoldResult>
495 PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
496  SmallVector<OpFoldResult> res;
497  for (int64_t i = 0, e = spec.size(); i < e; ++i) {
498  if (!spec[i].packedDimForEachOperand[operandPos].has_value())
499  continue;
500  res.push_back(spec[i].packedSize);
501  }
502  return res;
503 }
504 
505 /// Implement packing of a single LinalgOp by performing packing by
506 /// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
507 /// Return the packed Linalg op on success, failure otherwise.
508 FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
509  linalg::LinalgOp linalgOp,
510  ArrayRef<OpFoldResult> packedSizes) {
511  if (packedSizes.size() != linalgOp.getNumLoops()) {
512  return rewriter.notifyMatchFailure(linalgOp,
513  "incorrect number of pack sizes");
514  }
515 
516  Location loc = linalgOp->getLoc();
517  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
518  SmallVector<utils::IteratorType> iteratorTypes =
519  linalgOp.getIteratorTypesArray();
520  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
521  llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
522  llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
523  DBGSNL(););
524 
525  SmallVector<tensor::PackOp> packOps;
526  SmallVector<tensor::UnPackOp> unPackOps;
527  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
528  PackedOperandsDimList listOfPackedOperandsDim;
529  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
530  std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
531  // Skip tile sizes explicitly set to 0.
532  if (maybeConstant.has_value() && maybeConstant.value() == 0)
533  continue;
534 
535  PackedOperandsDim packedOperandsDims;
536  packedOperandsDims.packedSize = packedSizes[i];
537  FailureOr<SmallVector<std::optional<int64_t>>>
538  maybePackedDimForEachOperand =
539  packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
540  if (failed(maybePackedDimForEachOperand))
541  return failure();
542  packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
543  listOfPackedOperandsDim.push_back(std::move(packedOperandsDims));
544 
545  LLVM_DEBUG(
546  DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
547  << "\n";
548  llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
549  llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
550  llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
551  DBGS() << "packedDimForEachOperand: ");
552  DBGSNL(););
553  }
554 
555  // Step 2. Propagate packing to all LinalgOp operands.
556  SmallVector<Value> inputsAndInits, results;
557  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
558  linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
559  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
560  for (const auto &operandsList : {inputOperands, initOperands}) {
561  for (OpOperand *opOperand : operandsList) {
562  int64_t pos = opOperand->getOperandNumber();
563  Value operand = opOperand->get();
564  SmallVector<int64_t> innerPos =
565  listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
566  SmallVector<OpFoldResult> innerPackSizes =
567  listOfPackedOperandsDim.extractPackSizesForOperand(pos);
568  LLVM_DEBUG(
569  DBGS() << "operand: " << operand << "\n";
570  llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
571  llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
572  DBGSNL(););
573  if (innerPackSizes.empty()) {
574  inputsAndInits.push_back(operand);
575  continue;
576  }
577  Value dest = tensor::PackOp::createDestinationTensor(
578  rewriter, loc, operand, innerPackSizes, innerPos,
579  /*outerDimsPerm=*/{});
580  ShapedType operandType = operand.getType().cast<ShapedType>();
581  bool areConstantTiles =
582  llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
583  return getConstantIntValue(tile).has_value();
584  });
585  if (areConstantTiles && operandType.hasStaticShape() &&
586  !tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos,
587  innerPackSizes)) {
588  packOps.push_back(rewriter.create<tensor::PackOp>(
589  loc, operand, dest, innerPos, innerPackSizes));
590  } else {
591  // TODO: value of the padding attribute should be determined by
592  // consumers.
593  auto zeroAttr =
594  rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
595  Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
596  packOps.push_back(rewriter.create<tensor::PackOp>(
597  loc, operand, dest, innerPos, innerPackSizes, zero));
598  }
599  inputsAndInits.push_back(packOps.back());
600  }
601  }
602 
603  // Step 3. Build the packed op, use the type of `inits` as result types.
604  ValueRange inputs =
605  ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
606  ValueRange inits =
607  ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
608  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
609  linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
610  iteratorTypes);
611  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
612 
613  // Step 4. Propagate packing to all the op results.
614  for (OpResult result : packedLinalgOp->getResults()) {
615  int64_t resultNum = result.getResultNumber();
616  tensor::PackOp maybePackedInit =
617  inits[resultNum].getDefiningOp<tensor::PackOp>();
618  if (!maybePackedInit) {
619  results.push_back(result);
620  continue;
621  }
622  // Build the symmetrical UnPackOp to the existing PackOp.
623  unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
624  packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
625  maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
626  results.push_back(unPackOps.back());
627  }
628 
629  // Step 5. Replace `linalgOp`.
630  rewriter.replaceOp(linalgOp, results);
631 
632  // Return packedLinalgOp.
633  return PackResult{packOps,
634  cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
635  unPackOps};
636 }
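
// Rough sketch (illustrative only): packing a plain linalg.matmul with
// packedSizes = [8, 16, 32] for (m, n, k) produces approximately:
//
//   %A_p = tensor.pack %A inner_dims_pos = [0, 1] inner_tiles = [8, 32] ...
//   %B_p = tensor.pack %B inner_dims_pos = [1, 0] inner_tiles = [16, 32] ...
//   %C_p = tensor.pack %C inner_dims_pos = [0, 1] inner_tiles = [8, 16] ...
//   %R_p = linalg.generic ... ins(%A_p, %B_p) outs(%C_p)
//          // 6-D iteration space: (m, n, k, mm, nn, kk)
//   %R   = tensor.unpack %R_p inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//          into %C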
637 
638 //===----------------------------------------------------------------------===//
639 // packTranspose transformation.
640 //===----------------------------------------------------------------------===//
641 
642 /// Return a copy of `tensorType` after permutation by `permutationVector`.
643 // Note: Should be a new method in MemRef/RankedTensor/VectorType::Builder
644 // but this would introduce a dependence on Dialect in IR.
645 // TODO: Restructure.
646 static RankedTensorType permuteShape(RankedTensorType tensorType,
647  ArrayRef<int64_t> permutationVector) {
648  SmallVector<int64_t> shape(tensorType.getShape());
649  applyPermutationToVector(shape, permutationVector);
650  return RankedTensorType::Builder(tensorType).setShape(shape);
651 }
652 
653 /// Return a new GenericOp obtained by transposing opOperand by the permutation
654 /// vector:
655 /// - the corresponding indexing map is transposed by `permutation`
656 /// - the corresponding operand value is replaced by `transposedValue`
657 /// `linalgOp` is replaced by the return op in the process.
658 /// Asserts that `transposedValue` is of the proper transposed ShapedType.
659 static LinalgOp transposeOneLinalgOperandAndReplace(
660  RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
661  ArrayRef<int64_t> permutation, Value transposedValue) {
662  // Sanity check the operand.
663  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
664 
665  // Sanity check of the expected transposed tensor type.
666  auto tensorType = permuteShape(
667  cast<RankedTensorType>(opOperand.get().getType()), permutation);
668  (void)tensorType;
669  assert(tensorType == transposedValue.getType() &&
670  "expected transposed tensor type to match");
671 
672  // Compute the transposed indexing map.
673  // Sigh unsigned pollution.
674  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
675  llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
676  AffineMap permutationMap =
677  AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
678  AffineMap transposedMap =
679  permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
680 
681  // Set the transposed indexing map in the proper position.
682  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
683  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
684  // Set the transposedValue in the proper operand position.
685  SmallVector<Value> operands = linalgOp->getOperands();
686  operands[opOperand.getOperandNumber()] = transposedValue;
687 
688  ValueRange operandsRef(operands);
689  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
690  /*location=*/linalgOp->getLoc(),
691  /*resultTensorTypes=*/
692  operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
693  /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
694  /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
695  /*indexingMaps=*/indexingMaps,
696  /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
697  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
698  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
699 
700  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
701 }
702 
703 FailureOr<PackTransposeResult>
704 linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
705  linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
706  ArrayRef<int64_t> outerPerm,
707  ArrayRef<int64_t> innerPerm) {
708  Location loc = linalgOp.getLoc();
709 
710  // Step 1. Transpose packOp.
711  rewriter.setInsertionPoint(packOp);
712  tensor::PackOp transposedPackOp =
713  packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
714 
715  if (!packOp.getResult().hasOneUse())
716  return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
717 
718  OpOperand &packUse = *packOp->getUses().begin();
719  if (packUse.getOwner() != linalgOp) {
720  return rewriter.notifyMatchFailure(
721  linalgOp, "not a single use by the LinalgOp target");
722  }
723  if (maybeUnPackOp &&
724  (!linalgOp.isDpsInit(&packUse) ||
725  maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
726  return rewriter.notifyMatchFailure(linalgOp,
727  "not produced by the LinalgOp target");
728  }
729 
730  // Step 2. Transpose linalgOp.
731  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
732  // identity. Don't rely on it.
733  int64_t numLeadingDims = packOp.getSourceRank();
734  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
735  // Step 2.a. Compute the permutation on the whole operand.
736  // The leading part just reuses the outerPerm.
737  SmallVector<int64_t> permutation(outerPerm);
738  if (permutation.empty())
739  llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
740  // Trailing part needs to reindex positions by `numLeadingDims`.
741  if (innerPerm.empty()) {
742  llvm::append_range(
743  permutation,
744  llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
745  } else {
746  llvm::append_range(permutation,
747  llvm::map_range(innerPerm, [&](int64_t pos) {
748  return numLeadingDims + pos;
749  }));
750  }
751  if (!isPermutationVector(permutation))
752  return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
753 
754  // Step 2.b. Save the transposedPackUse operand number in case we need to
755  // get the tied OpResult after `linalgOp` has been replaced.
756  int64_t packUseOperandNumber = packUse.getOperandNumber();
757  // Step 2.c. Actually perform the transposition.
758  rewriter.setInsertionPoint(linalgOp);
759  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
760  rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
761 
762  // Step 3. Maybe transpose unPackOp.
763  tensor::UnPackOp transposedUnPackOp;
764  if (maybeUnPackOp) {
765  OpOperand &opOperand =
766  transposedLinalgOp->getOpOperand(packUseOperandNumber);
767  OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
768  rewriter.setInsertionPoint(maybeUnPackOp);
769  transposedUnPackOp = maybeUnPackOp.createTransposedClone(
770  rewriter, loc, transposedResult, innerPerm, outerPerm);
771 
772  rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
773  }
774 
775  // Step 4. Finally, replace packOp now that we don't need it anymore.
776  rewriter.replaceOp(packOp, transposedPackOp->getResults());
777 
778  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
779  transposedUnPackOp};
780 }
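
// Hypothetical C++ usage sketch (illustrative; assumes packOp has two inner
// tiles and that the PackTransposeResult fields are named as in the return
// statement above): swap the two innermost packed dims of a
// pack/compute/unpack chain.
//
//   FailureOr<PackTransposeResult> res = linalg::packTranspose(
//       rewriter, packOp, packedLinalgOp, unPackOp,
//       /*outerPerm=*/{}, /*innerPerm=*/{1, 0});
//   if (failed(res))
//     return failure();
//   tensor::PackOp newPack = res->transposedPackOp;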
781 
782 //===----------------------------------------------------------------------===//
783 // packMatmulGreedily transformation.
784 //===----------------------------------------------------------------------===//
785 
786 /// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
787 /// and n are proper parallel dimensions and k is a proper reduction
788 /// dimension. Packing occurs by rewriting the op as a linalg.generic and
789 /// calling linalg::pack by `mnkPackedSizes`. The order of the packed
790 /// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
791 /// to reorder {m, n, k} into one of the 6 possible forms. The outer
792 /// dimensions of the operands are not permuted at this time; this is left for
793 /// future work.
794 FailureOr<PackResult>
795 linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
796  ArrayRef<OpFoldResult> mnkPackedSizes,
797  ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
798  ArrayRef<int64_t> mnkOrder) {
799  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
800  assert((mnkPaddedSizesNextMultipleOf.empty() ||
801  mnkPaddedSizesNextMultipleOf.size() == 3) &&
802  "num of packing sizes next multiple should be empty or of size 3");
803  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
804  assert(isPermutationVector(mnkOrder) && "expected a permutation");
805 
806  int64_t numLoops = linalgOp.getNumLoops();
807  if (numLoops <= 2) {
808  LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
809  << numLoops << "\nin: " << linalgOp << "\n");
810  return rewriter.notifyMatchFailure(
811  linalgOp, "need 3+ loops to find a matmul to pack");
812  }
813 
814  // Locally adjust the desired iterator position of mnk and packing sizes.
815  int64_t numPackedDims = mnkPackedSizes.size();
816  SmallVector<int64_t> mmnnkkPos(numPackedDims);
817  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
818  mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
819  SmallVector<OpFoldResult> packedSizes(numPackedDims);
820  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
821  packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
822  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
823  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
824  paddedSizesNextMultipleOf[mnkOrder[i]] =
825  mnkPaddedSizesNextMultipleOf.empty() ? 0
826  : mnkPaddedSizesNextMultipleOf[i];
827  }
828 
829  // 1. Infer dims that are important for matmul.
830  FailureOr<ContractionDimensions> maybeDimensions =
831  inferContractionDims(linalgOp);
832  if (failed(maybeDimensions)) {
833  LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
834  << "\n");
835  return rewriter.notifyMatchFailure(linalgOp,
836  "couldn't infer matmul iterators");
837  }
838 
839  // 2. Normalize linalgOp to a kmn-matmul-like form with [red, par, par] as
840  // the most minor iterators. In cases with multiple options for m, n, k,
841  // bias towards the most minor embedding.
842  // If we wanted a different normalization order, this is where a heuristic
843  // would have to be plugged in.
844  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
845  kPos = maybeDimensions->k.back();
846  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
847  DBGS() << "Start packing generic op greedily with (m@" << mPos
848  << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
849  << "\n";);
850 
851  // 2.a. Rewrite as a generic.
852  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
853  if (!genericOp) {
854  FailureOr<GenericOp> generalizeResult =
855  generalizeNamedOp(rewriter, linalgOp);
856  assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
857  genericOp = *generalizeResult;
858  }
859 
860  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
861  // iterators. Note that this only normalizes the iteration order and does
862  // not change the indexings of any operand.
863  SmallVector<int64_t> permutation =
864  computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
865  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
866  // Sign .. unsigned pollution.
867  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
868  FailureOr<GenericOp> interchangeResult =
869  interchangeGenericOp(rewriter, genericOp, unsignedPerm);
870  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
871  genericOp = *interchangeResult;
872  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
873 
874  // At this point, the op iterators are normalized to {leading, k, m, n}.
875  // The layouts induced by packing will always be:
876  // - LHS{leading_lhs, kk, mm}
877  // - RHS{leading_rhs, kk, nn}
878  // - RES{leading_res, mm, nn}
879  // If we wanted to change the packed order, we would reorder (k, m, n) to
880  // something else above.
881  //
882  // Additional permutations of the outer dims of the operands (i.e.
883  // leading_lhs, leading_rhs and leading_res) could follow by computing the
884  // desired outerPerm for each operand.
885  // This is left for future work.
886 
887  // TODO: this creates too much IR, go use reifyResultShapes.
888  SmallVector<Range, 4> loopRanges =
889  cast<LinalgOp>(genericOp.getOperation())
890  .createLoopRanges(rewriter, genericOp.getLoc());
891 
892  // Add leading zeros to match numLoops; we only pack the last 3 dimensions
893  // post interchange.
894  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
895  DBGS() << "paddedSizesNextMultipleOf: ");
896  DBGSNL(););
897  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
898  [](Range r) { llvm::dbgs() << r.size; });
899  DBGSNL(););
900  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
901  rewriter.getIndexAttr(0));
902  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
903  if (paddedSizesNextMultipleOf[i] == 0) {
904  adjustedPackedSizes.push_back(packedSizes[i]);
905  continue;
906  }
907  AffineExpr d0, s0;
908  bindDims(rewriter.getContext(), d0);
909  bindSymbols(rewriter.getContext(), s0);
910  adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
911  rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
912  {loopRanges[adjustedPackedSizes.size()].size,
913  rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
914  }
915  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
916  DBGS() << "adjustedPackedSizes: ");
917  DBGSNL(););
918 
919  // TODO: If we wanted to give the genericOp a name after packing, after
920  // calling `pack` would be a good time. One would still need to check that
921  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
922  // also allow degenerate matmul cases (i.e. matvec, dot).
923  return pack(rewriter, genericOp, adjustedPackedSizes);
924 }
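
// Worked example of the padding arithmetic above (illustrative): with a loop
// range of 130 for the m dimension and a next-multiple-of of 64 for m, the
// adjusted packed size becomes ceilDiv(130, 64) * 64 = 192, i.e. the full
// loop range rounded up to the next multiple; entries equal to 0 keep the
// originally requested packed size.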
925 
926 //===----------------------------------------------------------------------===//
927 // Transformations exposed as rewrite patterns.
928 //===----------------------------------------------------------------------===//
929 
930 LinalgTilingOptions &
931 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
932  assert(!tileSizeComputationFunction && "tile sizes already set");
933  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
934  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
935  OpBuilder::InsertionGuard guard(b);
936  b.setInsertionPointToStart(
937  &op->getParentOfType<func::FuncOp>().getBody().front());
938  return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
939  Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
940  return v;
941  }));
942  };
943  return *this;
944 }
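
// Hypothetical usage sketch (illustrative; not part of the upstream file):
//
//   LinalgTilingOptions options;
//   options.setTileSizes({8, 16, 32});
//   // Each invocation of the stored function materializes the constants
//   // 8, 16 and 32 at the start of the enclosing function.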
945 
946 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
947  memref::CopyOp copyOp, PatternRewriter &rewriter) const {
948  return vectorizeCopy(rewriter, copyOp);
949 }
950 
951 /// Fill `dest` using a FillOp with the constant padding value if possible.
952 /// Otherwise, generate a tensor::GenerateOp.
953 Value GeneralizePadOpPattern::createFillOrGenerateOp(
954  RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
955  const SmallVector<Value> &dynSizes) const {
956  auto padValue = padOp.getConstantPaddingValue();
957  if (padValue)
958  return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
959 
960  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
961  auto generateOp = rewriter.create<tensor::GenerateOp>(
962  padOp.getLoc(), padOp.getResultType(), dynSizes);
963  // Copy region to new op.
964  IRMapping bvm;
965  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
966  return generateOp;
967 }
968 
969 LogicalResult
970 GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
971  PatternRewriter &rewriter) const {
972  // Given an OpFoldResult, return an index-typed value.
973  auto getIdxValue = [&](OpFoldResult ofr) {
974  if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
975  return val;
976  return rewriter
977  .create<arith::ConstantIndexOp>(
978  padOp.getLoc(), cast<IntegerAttr>(ofr.get<Attribute>()).getInt())
979  .getResult();
980  };
981 
982  auto resultType = padOp.getResultType();
983  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
984  SmallVector<Value> dynSizes;
985  SmallVector<int64_t> staticSizes;
986  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
987  if (resultType.isDynamicDim(dim)) {
988  auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
989  padOp.getSource(), dim));
990  // Add low and high padding value.
991  auto plusLow = rewriter.createOrFold<arith::AddIOp>(
992  padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
993  auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
994  padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
995  dynSizes.push_back(plusHigh);
996  }
997  staticSizes.push_back(resultType.getDimSize(dim));
998  }
999 
1000  // Init tensor and fill it with padding.
1001  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1002  padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
1003  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
1004 
1005  // Try to optimize the copy of the source.
1006  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
1007  return success();
1008 
1009  // The copy of the source could not be optimized. Generate an InsertSliceOp
1010  // instead for copying the PadOp source.
1011  auto sourceType = padOp.getSourceType();
1012  // Compute size of source of tensor::PadOp.
1013  SmallVector<OpFoldResult> srcSizes =
1014  tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
1015  // Strides of InsertSliceOp are all 1.
1016  SmallVector<OpFoldResult> strides(sourceType.getRank(),
1017  rewriter.getIndexAttr(1));
1018  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
1019  padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
1020  strides);
1021 
1022  return success();
1023 }
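
// Rough sketch of the generalization (illustrative; value names invented):
//
//   %0 = tensor.pad %src low[%l0, %l1] high[%h0, %h1] {
//     ^bb0(...): tensor.yield %cst : f32
//   } : tensor<?x?xf32> to tensor<?x?xf32>
//
// becomes approximately:
//
//   %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>)
//   %0 = tensor.insert_slice %src into %fill[%l0, %l1] [%s0, %s1] [1, 1]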
1024 
1025 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
1026  tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
1027  if (!sliceOp.hasUnitStride())
1028  return failure();
1029 
1030  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
1031  if (!padOp)
1032  return failure();
1033 
1034  bool zeroSliceGuard = true;
1035  if (controlFn) {
1036  if (std::optional<bool> control = controlFn(sliceOp))
1037  zeroSliceGuard = *control;
1038  else
1039  return failure();
1040  }
1041 
1042  FailureOr<TilingResult> tilingResult =
1043  tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
1044  sliceOp.getMixedSizes(), zeroSliceGuard);
1045  if (failed(tilingResult))
1046  return failure();
1047  // All shapes are static and the data source is actually used. Rewrite into
1048  // pad(extract_slice(x)).
1049  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
1050  return success();
1051 }
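
// Rough sketch (illustrative): instead of slicing the padded tensor, as in
//
//   %p = tensor.pad %src ... : tensor<10xf32> to tensor<16xf32>
//   %s = tensor.extract_slice %p[2] [8] [1] : tensor<16xf32> to tensor<8xf32>
//
// the pattern rewrites to pad(extract_slice(%src)), slicing only the used
// part of the source and re-padding the remainder.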
1052 
1053 /// Returns a tensor.pad op if padding value is set. Otherwise, returns the
1054 /// source directly. The method assumes that the `packOp` has static shapes.
1055 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1056  tensor::PackOp packOp) {
1057  Value input = packOp.getSource();
1058  if (!packOp.getPaddingValue()) {
1059  return input;
1060  }
1061 
1062  Location loc = packOp.getLoc();
1063  ShapedType inputType = packOp.getSourceType();
1064  int64_t inputRank = inputType.getRank();
1065  assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
1066  [](int64_t val) { return val == 1; }));
1067 
1068  SmallVector<int64_t> paddedShape;
1069  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
1070  packOp.getDimAndTileMapping();
1071  for (int64_t dim = 0; dim < inputRank; ++dim) {
1072  int64_t size = inputType.getDimSize(dim);
1073  if (!tileAndPosMapping.count(dim)) {
1074  paddedShape.push_back(size);
1075  continue;
1076  }
1077 
1078  // The size is less than or equal to tileSize because outer dims are all 1s.
1079  std::optional<int64_t> tileSize =
1080  getConstantIntValue(tileAndPosMapping.lookup(dim));
1081  assert(tileSize.has_value() && "dynamic inner tile size is not supported");
1082  paddedShape.push_back(tileSize.value());
1083  }
1084  auto resultType =
1085  RankedTensorType::get(paddedShape, inputType.getElementType());
1086  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
1087  /*nofold=*/false, loc, builder);
1088 }
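
// Illustrative example (sketch): for a pack with all-1 tiled outer dims, a
// source of type tensor<5x13xf32>, inner_tiles = [8, 16] and a padding value,
// this creates tensor.pad %src low[0, 0] high[3, 3] producing
// tensor<8x16xf32>, i.e. each dimension is padded up to its tile size.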
1089 
1090 // Normalizes a permutation on a higher rank space to its actual size, e.g.
1091 // perm = [1, 4, 2]
1092 // becomes
1093 // norm = [0, 2, 1]
1094 static SmallVector<int64_t>
1095 getPackUnpackNormalizedPerm(int64_t rank, ArrayRef<int64_t> perm) {
1096  constexpr int64_t kNonTiledMarker = -1;
1097  SmallVector<int64_t> vec(rank, kNonTiledMarker);
1098  for (auto [index, value] : llvm::enumerate(perm))
1099  vec[value] = index;
1100  SmallVector<int64_t> normalizedPerm = llvm::to_vector(llvm::make_filter_range(
1101  vec, [&](int64_t v) { return v != kNonTiledMarker; }));
1102  // This inverts the permutation in addition to normalizing so invert back.
1103  return invertPermutationVector(normalizedPerm);
1104 }
1105 
1106 // Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
1107 // assuming rank reduction of unit outer dims.
1108 static SmallVector<int64_t>
1109 getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1110  ArrayRef<int64_t> innerDimsPos,
1111  ArrayRef<int64_t> outerDimsPerm) {
1112  SmallVector<int64_t> rankReducedOuterDimsPerm;
1113  SmallVector<int64_t> outerDims;
1114  SmallVector<int64_t> innerDims;
1115  int64_t dim = 0;
1116  int64_t unpackedRank = shape.size();
1117  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1118  if (llvm::is_contained(innerDimsPos, i)) {
1119  innerDims.push_back(dim++);
1120  continue;
1121  }
1122  if (shape[i] == 1)
1123  continue;
1124  outerDims.push_back(dim++);
1125  if (!outerDimsPerm.empty())
1126  rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1127  }
1128 
1129  // Get the position of the inner dims after permutation.
1130  SmallVector<int64_t> innerPerm =
1131  getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
1132  applyPermutationToVector<int64_t>(innerDims, innerPerm);
1133 
1134  // Ditto for the outer dims.
1135  SmallVector<int64_t> perm = outerDims;
1136 
1137  rankReducedOuterDimsPerm =
1138  getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
1139  if (!rankReducedOuterDimsPerm.empty())
1140  applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1141 
1142  // The tile always ends up as the inner most dims after packing.
1143  perm.append(innerDims);
1144 
1145  return perm;
1146 }
1147 
1148 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
1149  tensor::PackOp packOp, PatternRewriter &rewriter) const {
1150  if (llvm::any_of(packOp.getMixedTiles(),
1151  [](OpFoldResult tile) { return tile.is<Value>(); })) {
1152  return rewriter.notifyMatchFailure(packOp,
1153  "require inner tile sizes being static");
1154  }
1155 
1156  // TODO: support the case that outer dimensions are not all 1s. A
1157  // tensor.expand_shape will be generated in this case.
1158  auto innerDimsPos = packOp.getInnerDimsPos();
1159  int64_t srcRank = packOp.getSourceRank();
1160  auto destShape = packOp.getDestType().getShape();
1161  if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
1162  return destShape[index] != 1;
1163  })) {
1164  return rewriter.notifyMatchFailure(
1165  packOp, "require the tiled outer dimensions of the result are all 1s");
1166  }
1167 
1168  // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
1169  // outer dims.
1170  Location loc = packOp.getLoc();
1171  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1172  auto inputShape = packOp.getSourceType().getShape();
1173  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1174  packOp.getDimAndTileMapping();
1175  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1176  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1177  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
1178  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
1179  SmallVector<OpFoldResult> readSizes;
1180  SmallVector<int64_t> readShape;
1181  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1182  if (dimAndTileMapping.count(i)) {
1183  readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
1184  .value_or(ShapedType::kDynamic));
1185  readSizes.push_back(dimAndTileMapping[i]);
1186  continue;
1187  }
1188  if (ShapedType::isDynamic(inputShape[i])) {
1189  readSizes.push_back(
1190  rewriter.create<tensor::DimOp>(loc, input, i).getResult());
1191  } else {
1192  readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
1193  }
1194  if (inputShape[i] != 1)
1195  readShape.push_back(inputShape[i]);
1196  }
1197 
1198  Type elemType = packOp.getSourceType().getElementType();
1199  auto readType = RankedTensorType::get(readShape, elemType);
1200 
1201  Value tile = rewriter.create<tensor::ExtractSliceOp>(
1202  loc, readType, input, readOffsets, readSizes, readStrides);
1203 
1204  // 2. Transpose the tile to match the inner tile order.
1205 
1206  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1207  inputShape, innerDimsPos, packOp.getOuterDimsPerm());
1208 
1209  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
1210  llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
1211 
1212  SmallVector<int64_t> transpShape = readShape;
1213  applyPermutationToVector<int64_t>(transpShape, perm);
1214 
1215  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
1216  auto transposedOp =
1217  rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
1218 
1219  // 3. Insert the inner tile to the destination.
1220  int64_t destRank = packOp.getDestRank();
1221  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1222  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1223  SmallVector<OpFoldResult> writeSizes =
1224  tensor::getMixedSizes(rewriter, loc, packOp.getDest());
1225 
1226  auto insert = rewriter.create<tensor::InsertSliceOp>(
1227  loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1228  writeSizes, writeStrides);
1229  rewriter.replaceOp(packOp, insert.getResult());
1230 
1231  return success();
1232 }
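
// Rough sketch of the generalization (illustrative, assuming an identity
// inner permutation):
//
//   %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %dest : tensor<8x16xf32> -> tensor<1x1x8x16xf32>
//
// becomes approximately:
//
//   %tile = tensor.extract_slice %src[0, 0] [8, 16] [1, 1]
//   %empty = tensor.empty() : tensor<8x16xf32>
//   %t = linalg.transpose ins(%tile) outs(%empty) permutation = [0, 1]
//   %ins = tensor.insert_slice %t into %dest[0, 0, 0, 0] [1, 1, 8, 16]
//       [1, 1, 1, 1]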
1233 
1234 LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
1235  tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1236  int64_t srcRank = unpackOp.getSourceRank();
1237  int64_t destRank = unpackOp.getDestRank();
1238  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1239  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1240  if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
1241  return srcShape[index] != 1;
1242  })) {
1243  return rewriter.notifyMatchFailure(
1244  unpackOp,
1245  "require the tiled outer dimensions of the result are all 1s");
1246  }
1247 
1248  // 1. Use rank-reduced tensor.extract_slice op to extract the tile.
1249  Location loc = unpackOp.getLoc();
1250  Value source = unpackOp.getSource();
1251  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1252  unpackOp.getDimAndTileMapping();
1253  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1254  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1255  SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
1256  SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
1257  SmallVector<OpFoldResult> readSizes;
1258  SmallVector<int64_t> readShape;
1259  SmallVector<Value> dynamicDims;
1260  for (auto i : llvm::seq<unsigned>(0, destRank)) {
1261  if (dimAndTileMapping.count(i)) {
1262  readSizes.push_back(oneIdxAttr);
1263  continue;
1264  }
1265 
1266  if (ShapedType::isDynamic(srcShape[i])) {
1267  Value dynamicDim =
1268  rewriter.create<tensor::DimOp>(loc, source, i).getResult();
1269  readSizes.push_back(dynamicDim);
1270  dynamicDims.push_back(dynamicDim);
1271  } else {
1272  readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1273  }
1274  if (srcShape[i] != 1)
1275  readShape.push_back(srcShape[i]);
1276  }
1277  auto mixedTiles = unpackOp.getMixedTiles();
1278  readSizes.append(mixedTiles.begin(), mixedTiles.end());
1279 
1280  // Explicitly create the type for extract_slice op because the inner tile
1281  // size could be 1. We want to represent the whole inner tile in this case.
1282  auto tileShape = srcShape.drop_front(destRank);
1283  // Append the inner tile shape to the permuted and rank-reduced outer shape.
1284  readShape.append(tileShape.begin(), tileShape.end());
1285  Type elemType = unpackOp.getSourceType().getElementType();
1286  auto readType = RankedTensorType::get(readShape, elemType);
1287  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
1288  loc, readType, unpackOp.getSource(), readOffsets, readSizes, readStrides);
1289 
1290  // 2. Transpose the tile to match the outer corresponding tile order.
1291  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1292  srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1293  // Unpack is a transition out of packed space so we invert the permutation.
1294  perm = invertPermutationVector(perm);
1295  SmallVector<int64_t> transpShape(readShape);
1296  applyPermutationToVector<int64_t>(transpShape, perm);
1297 
1298  Value empty =
1299  rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
1300  auto transposedOp =
1301  rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
1302 
1303  // 3. Handle incomplete tiles if needed. This truncates trailing data from the
1304  // transposed tile.
1305  int numLoops = transpShape.size();
1306  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
1307  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
1308  SmallVector<OpFoldResult> tileSizes;
1309  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1310  for (auto i : llvm::seq<unsigned>(0, destRank)) {
1311  if (dimAndTileMapping.count(i) || destShape[i] != 1)
1312  tileSizes.push_back(
1313  tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1314  }
1315 
1316  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
1317  loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1318 
1319  // 4. Insert the result to the destination tensor.
1320  SmallVector<OpFoldResult> writeSizes;
1321  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1322  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1323  for (int i = 0, idx = 0; i < destRank; ++i) {
1324  if (dimAndTileMapping.count(i) || destShape[i] != 1)
1325  writeSizes.push_back(tileSizes[idx++]);
1326  else
1327  writeSizes.push_back(oneIdxAttr);
1328  }
1329  auto insert = rewriter.create<tensor::InsertSliceOp>(
1330  loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1331  writeStrides);
1332  rewriter.replaceOp(unpackOp, insert.getResult());
1333 
1334  return success();
1335 }
1336 
1337 // The following are patterns for downscaling convolution ops with size-1
1338 // window dimensions.
1339 //
1340 // Note that we'd eventually want to write such transformations in a generic
1341 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
1342 // and then turning back to named ops. But for now it's fine to have a few
1343 // patterns matching special ops to get started.
1344 
1345 template <typename Conv2DOp, typename Conv1DOp>
1346 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
1347  returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
1348  if (convOp.hasBufferSemantics())
1349  return failure(); // To be implemented.
1350 
1351  Value input = convOp.getInputs().front();
1352  Value kernel = convOp.getInputs().back();
1353  Value output = convOp.getOutputs().front();
1354 
1355  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1356  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1357  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1358 
1359  auto kernelShape = kernelType.getShape();
1360  auto outputShape = outputType.getShape();
1361 
1362  // Get domain indices based on conv2D layout.
1363  auto [khIndex, kwIndex, ohIndex, owIndex] =
1364  TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
1365  convOp)
1366  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1367  return std::make_tuple(0, 1, 1, 2);
1368  })
1369  .Case([&](linalg::Conv2DNchwFchwOp op) {
1370  return std::make_tuple(2, 3, 2, 3);
1371  })
1372  .Case([&](linalg::PoolingNhwcSumOp op) {
1373  return std::make_tuple(0, 1, 1, 2);
1374  })
1375  .Case([&](linalg::PoolingNchwSumOp op) {
1376  return std::make_tuple(0, 1, 2, 3);
1377  })
1378  .Case([&](linalg::PoolingNhwcMaxOp op) {
1379  return std::make_tuple(0, 1, 1, 2);
1380  })
1381  .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1382  return std::make_tuple(0, 1, 1, 2);
1383  })
1384  .Case([&](linalg::PoolingNhwcMinOp op) {
1385  return std::make_tuple(0, 1, 1, 2);
1386  })
1387  .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1388  return std::make_tuple(0, 1, 1, 2);
1389  })
1390  .Case([&](linalg::PoolingNchwMaxOp op) {
1391  return std::make_tuple(0, 1, 2, 3);
1392  })
1393  .Default([&](Operation *op) {
1394  llvm_unreachable("unexpected conv2d/pool2d operation.");
1395  return std::make_tuple(0, 0, 0, 0);
1396  });
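 // The tuple is (khIndex, kwIndex, ohIndex, owIndex) into the kernel and
 // output shapes respectively; e.g. for Conv2DNhwcHwcfOp the HWCF kernel has
 // kh/kw at dims 0/1 and the NHWC output has oh/ow at dims 1/2.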
1397 
1398  // Only handle the case where at least one of the window dimensions is
1399  // of size 1. Other cases can rely on tiling to reduce to such cases.
1400  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1401  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1402  bool removeH = (khSize == 1 && ohSize == 1);
1403  bool removeW = (kwSize == 1 && owSize == 1);
1404  if (!removeH && !removeW)
1405  return failure();
1406 
1407  // Get new shapes and types for all operands by removing the size-1
1408  // dimension.
1409  using RTTBuilder = RankedTensorType::Builder;
1410  RankedTensorType newInputType =
1411  RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
1412  RankedTensorType newKernelType =
1413  RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1414  RankedTensorType newOutputType =
1415  RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
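 // dropDim erases a single dimension from the builder's shape, e.g. dropping
 // dim 1 of tensor<1x1x6x8xf32> yields tensor<1x6x8xf32> when the height
 // dimension is removed.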
1416 
1417  // Rank-reduce operands.
1418  Location loc = convOp.getLoc();
1419  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1420  rewriter, loc, input, newInputType);
1421  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1422  rewriter, loc, kernel, newKernelType);
1423  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1424  rewriter, loc, output, newOutputType);
1425 
1426  // Rank-reduce strides and dilations too.
1427  // TODO: dropDim 1-liner helper.
1428  auto strides =
1429  llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1430  strides.erase(strides.begin() + (removeH ? 0 : 1));
1431  auto stridesAttr = rewriter.getI64VectorAttr(strides);
1432 
1433  auto dilations =
1434  llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1435  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1436  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
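 // Strides and dilations are stored as [sh, sw]; erasing entry 0 (removing h)
 // or entry 1 (removing w) leaves the single value the 1-D op expects.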
1437 
1438  auto conv1DOp = rewriter.create<Conv1DOp>(
1439  loc, newOutputType, ValueRange{newInput, newKernel},
1440  ValueRange{newOutput}, stridesAttr, dilationsAttr);
1441 
1442  // Insert back.
1443  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1444  rewriter, loc, conv1DOp.getResult(0), output);
1445  rewriter.replaceOp(convOp, inserted);
1446 
1447  return conv1DOp;
1448 }
1449 
1450 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1451  Conv1DNwcWcfOp>;
1452 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1453  Conv1DNcwFcwOp>;
1454 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
1455  PoolingNwcSumOp>;
1456 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
1457  PoolingNcwSumOp>;
1458 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
1459  PoolingNwcMaxOp>;
1460 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1461  PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1462 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
1463  PoolingNwcMinOp>;
1464 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1465  PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1466 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
1467  PoolingNcwMaxOp>;
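// The explicit instantiations above are needed because returningMatchAndRewrite
// is defined in this .cpp file rather than in the header; each 2-D/1-D op
// pairing used elsewhere must be instantiated here so the linker can resolve it.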
1468 
1469 FailureOr<DepthwiseConv1DNwcWcOp>
1470 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1471  DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
1472  if (convOp.hasBufferSemantics())
1473  return failure(); // To be implemented.
1474 
1475  Value input = convOp.getInputs().front();
1476  Value kernel = convOp.getInputs().back();
1477  Value output = convOp.getOutputs().front();
1478 
1479  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1480  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1481  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1482 
1483  auto kernelShape = kernelType.getShape();
1484  auto outputShape = outputType.getShape();
1485 
1486  // Only handle the case where at least one of the window dimensions is
1487  // of size 1. Other cases can rely on tiling to reduce to such cases.
1488  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1489  int64_t ohSize = outputShape[1], owSize = outputShape[2];
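 // Depthwise NHWC layout: the HWC kernel has kh/kw at dims 0/1 and the NHWC
 // output has oh/ow at dims 1/2, hence the fixed indices above.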
1490  bool removeH = (khSize == 1 && ohSize == 1);
1491  bool removeW = (kwSize == 1 && owSize == 1);
1492  if (!removeH && !removeW)
1493  return failure();
1494 
1495  // Get new shapes and types for all operands by removing the size-1
1496  // dimension.
1497  using RTTBuilder = RankedTensorType::Builder;
1498  RankedTensorType newInputType =
1499  RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1500  RankedTensorType newKernelType =
1501  RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1502  RankedTensorType newOutputType =
1503  RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1504 
1505  // Rank-reduce operands.
1506  Location loc = convOp.getLoc();
1507  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1508  rewriter, loc, input, newInputType);
1509  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1510  rewriter, loc, kernel, newKernelType);
1511  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1512  rewriter, loc, output, newOutputType);
1513 
1514  // Rank-reduce strides and dilations too.
1515  // TODO: dropDim 1-liner helper.
1516  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1517  strides.erase(strides.begin() + (removeH ? 0 : 1));
1518  auto stridesAttr = rewriter.getI64VectorAttr(strides);
1519 
1520  auto dilations =
1521  llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1522  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1523  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1524 
1525  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1526  loc, newOutputType, ValueRange{newInput, newKernel},
1527  ValueRange{newOutput}, stridesAttr, dilationsAttr);
1528 
1529  // Insert back.
1530  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1531  rewriter, loc, conv1DOp.getResult(0), output);
1532  rewriter.replaceOp(convOp, inserted);
1533 
1534  return conv1DOp;
1535 }
1536 
1537 FailureOr<Conv1DOp>
1538 DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
1539  PatternRewriter &rewriter) const {
1540  if (convOp.hasBufferSemantics())
1541  return failure(); // To be implemented.
1542 
1543  Value input = convOp.getInputs().front();
1544  Value kernel = convOp.getInputs().back();
1545  Value output = convOp.getOutputs().front();
1546 
1547  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1548  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1549  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1550 
1551  auto kernelShape = kernelType.getShape();
1552  auto outputShape = outputType.getShape();
1553 
1554  // Only handle the case where at least one of the window dimensions is
1555  // of size 1. Other cases can rely on tiling to reduce to such cases.
1556  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1557  int64_t ohSize = outputShape[0], owSize = outputShape[1];
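 // linalg.conv_2d operates on unbatched, single-channel rank-2 operands, so
 // kernel and output shapes are indexed directly as [h, w]; it also carries no
 // stride/dilation attributes, which is why none are rank-reduced or passed to
 // the 1-D op below.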
1558  bool removeH = (khSize == 1 && ohSize == 1);
1559  bool removeW = (kwSize == 1 && owSize == 1);
1560  if (!removeH && !removeW)
1561  return failure();
1562 
1563  // Get new shapes and types for all operands by removing the size-1
1564  // dimension.
1565  using RTTBuilder = RankedTensorType::Builder;
1566  RankedTensorType newInputType =
1567  RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
1568  RankedTensorType newKernelType =
1569  RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1570  RankedTensorType newOutputType =
1571  RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
1572 
1573  // Rank-reduce operands.
1574  Location loc = convOp.getLoc();
1575  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1576  rewriter, loc, input, newInputType);
1577  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1578  rewriter, loc, kernel, newKernelType);
1579  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1580  rewriter, loc, output, newOutputType);
1581 
1582  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
1583  ValueRange{newInput, newKernel},
1584  ValueRange{newOutput});
1585 
1586  // Insert back.
1587  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1588  rewriter, loc, conv1DOp.getResult(0), output);
1589  rewriter.replaceOp(convOp, inserted);
1590 
1591  return conv1DOp;
1592 }
1593 
1594 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
1595  PatternBenefit benefit) {
1596  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
1597  Conv1DNwcWcfOp>,
1598  DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
1599  Conv1DNcwFcwOp>,
1600  DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
1601  patterns.getContext(), benefit);
1602  patterns.add<
1603  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
1604  DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
1605  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
1606  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
1607  PoolingNwcMaxUnsignedOp>,
1608  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
1609  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
1610  PoolingNwcMinUnsignedOp>,
1611  DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
1612  patterns.getContext(), benefit);
1613 }
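// A minimal usage sketch (assumed client code inside a pass, not part of this
// file):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();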
static RankedTensorType permuteShape(RankedTensorType tensorType, ArrayRef< int64_t > permutationVector)
Return a copy of tensorType after permutation by permutationVector.
Definition: Transforms.cpp:646
static SmallVector< int64_t > getPackUnpackRankReducedPerm(ArrayRef< int64_t > shape, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
static std::optional< int64_t > getFirstResultIndexFunctionOf(AffineMap map, int64_t dim)
Return the index of the first result of map that is a function of AffineDimExpr(dim),...
Definition: Transforms.cpp:102
static FailureOr< SmallVector< std::optional< int64_t > > > packLinalgMetadataOnce(SmallVectorImpl< AffineMap > &indexingMaps, SmallVectorImpl< utils::IteratorType > &iteratorTypes, int64_t dim)
Perform one step of packing of a LinalgOp's metadata along dim into the newDim at iteratorTypes....
Definition: Transforms.cpp:149
static LinalgOp transposeOneLinalgOperandAndReplace(RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand, ArrayRef< int64_t > permutation, Value transposedValue)
Return a new GenericOp obtained by transposing opOperand by the permutation vector:
Definition: Transforms.cpp:659
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, tensor::PackOp packOp)
Returns a tensor.pad op if padding value is set.
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim)
Return true if map has 0 or 1 result function of AffineDimExpr(dim).
Definition: Transforms.cpp:87
static SmallVector< int64_t > getPackUnpackNormalizedPerm(int rank, ArrayRef< int64_t > perm)
#define DBGSNL()
Definition: Transforms.cpp:46
#define DBGS()
Definition: Transforms.cpp:45
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Base type for affine expression.
Definition: AffineExpr.h:68
constexpr bool isa() const
Definition: AffineExpr.h:272
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:293
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:829
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:44
MLIRContext * getContext() const
Definition: AffineMap.cpp:284
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:241
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
Definition: AffineMap.h:289
unsigned getNumDims() const
Definition: AffineMap.cpp:337
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:350
unsigned getNumResults() const
Definition: AffineMap.cpp:345
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:354
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:225
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:495
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:353
MLIRContext * getContext() const
Definition: Builders.h:55
DenseIntElementsAttr getI64VectorAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:144
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:150
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:333
This class helps build Operations.
Definition: Builders.h:206
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:416
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:383
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:505
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
This class represents an operand of an operation.
Definition: Value.h:261
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:217
This is a value defined by a result of an operation.
Definition: Value.h:448
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
result_range getResults()
Definition: Operation.h:410
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:248
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:275
Builder & setShape(ArrayRef< int64_t > newShape)
Definition: BuiltinTypes.h:259
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:399
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:539
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:339
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:372
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:90
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1276
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
FailureOr< GenericOp > generalizeNamedOp(RewriterBase &rewriter, LinalgOp namedOp)
Create a GenericOp from the given named operation namedOp and replace namedOp.
FailureOr< LowerUnPackOpResult > lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp)
Rewrite unpack as empty + transpose + reshape + extract_slice.
Definition: Transforms.cpp:378
void peelLoops(RewriterBase &rewriter, ArrayRef< scf::ForOp > loops)
Peel 'loops' and apply affine_min/max bounds simplification on the fly where relevant.
Definition: Transforms.cpp:75
void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Linalg decompose convolutions patterns.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp)
Emit a suitable vector form for a Copy op with fully static shape.
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op.
Definition: Interchange.cpp:50
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
FailureOr< PackResult > packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, ArrayRef< OpFoldResult > mnkPackedSizes, ArrayRef< int64_t > mnkPaddedSizesNextMultipleOf, ArrayRef< int64_t > mnkOrder)
Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m and n are proper parallel d...
Definition: Transforms.cpp:795
FailureOr< PackResult > pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ArrayRef< OpFoldResult > packedSizes)
Implement packing of a single LinalgOp by packedSizes.
Definition: Transforms.cpp:508
FailureOr< LowerPackResult > lowerPack(RewriterBase &rewriter, tensor::PackOp packOp)
Rewrite pack as pad + reshape + transpose.
Definition: Transforms.cpp:220
SmallVector< Value > peelLoop(RewriterBase &rewriter, Operation *op)
Try to peel and canonicalize loop op and return the new result.
Definition: Transforms.cpp:59
FailureOr< PackTransposeResult > packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, ArrayRef< int64_t > outerPerm, ArrayRef< int64_t > innerPerm)
Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the transposed PackOp -> LinalgOp ...
Definition: Transforms.cpp:704
LogicalResult peelForLoopAndSimplifyBounds(RewriterBase &rewriter, ForOp forOp, scf::ForOp &partialIteration)
Rewrite a for loop with bounds/step that potentially do not divide evenly into a for loop where the s...
FailureOr< TilingResult > bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, bool generateZeroSliceGuard=true)
Bubbles up a slice of this pad by taking the slice first and then performing the padding.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Definition: TensorOps.cpp:2519
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
Definition: TensorOps.cpp:2171
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:50
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:60
PadOp createPadHighOp(RankedTensorType type, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder)
Definition: Utils.cpp:24
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:398
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:331
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:345
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:825
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter &rewriter) const override
Definition: Transforms.cpp:946
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) dimensions into 1-D depthwise c...
Definition: Transforms.h:1224
FailureOr< DepthwiseConv1DNwcWcOp > returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const
Rewrites 2-D convolution ops with size-1 window dimensions into 1-D convolution ops.
Definition: Transforms.h:1204
FailureOr< Conv1DOp > returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tensor::PackOp packOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override
Definition: Transforms.cpp:970
Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, Value dest, const SmallVector< Value > &dynSizes) const
Filling dest using FillOp constant padding value if possible.
Definition: Transforms.cpp:953
LinalgTilingOptions & setTileSizes(const SmallVector< Value, 4 > &ts)
Set the tileSizeComputationFunction to return the values ts.
Definition: Transforms.h:184
Struct to hold the result of a pack call.
Definition: Transforms.h:1083
Struct to hold the result of a packTranspose call.
Definition: Transforms.h:1095