MLIR  20.0.0git
Transforms.cpp
1 //===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/IR/Matchers.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Support/LLVM.h"
33 #include "llvm/ADT/ScopeExit.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <type_traits>
38 #include <utility>
39 
40 #define DEBUG_TYPE "linalg-transforms"
41 
42 using namespace mlir;
43 using namespace mlir::linalg;
44 
45 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
46 #define DBGSNL() (llvm::dbgs() << "\n")
47 
48 //===----------------------------------------------------------------------===//
49 // Transformations exposed as functional-style API calls.
50 //===----------------------------------------------------------------------===//
51 
52 //===----------------------------------------------------------------------===//
53 // peelLoop transformation.
54 //===----------------------------------------------------------------------===//
55 
56 /// Try to peel and canonicalize loop `op` and return the new result.
57 /// Also applies affine_min/max bounds simplification on the fly where relevant.
58 // TODO: Add support for scf.parallel and affine.for loops.
59 SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
60                                           Operation *op) {
61   return llvm::TypeSwitch<Operation *, SmallVector<Value>>(op)
62       .Case<scf::ForOp>([&](scf::ForOp forOp) {
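/// For illustration, a sketch (names and constants are placeholders): peeling
///   scf.for %iv = %c0 to %c17 step %c4 { ... }
/// yields a main loop covering [0, 16) with step 4 plus a partial iteration
/// covering [16, 17); affine.min/max bounds that become provably constant
/// inside either loop are simplified on the fly.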
63  scf::ForOp partialIteration;
64  if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
65  partialIteration)))
66  return partialIteration->getResults();
67  assert(!partialIteration && "expected that loop was not peeled");
68  return forOp->getResults();
69  })
70  .Default([&](Operation *op) { return op->getResults(); });
71 }
72 
73 /// Peel `loops` and apply affine_min/max bounds simplification on the fly
74 /// where relevant.
75 void mlir::linalg::peelLoops(RewriterBase &rewriter,
76                              ArrayRef<scf::ForOp> loops) {
77  for (auto loopOp : loops)
78  peelLoop(rewriter, loopOp);
79 }
80 
81 //===----------------------------------------------------------------------===//
82 // pack transformation.
83 //===----------------------------------------------------------------------===//
84 
85 #ifndef NDEBUG
86 /// Return true if `map` has 0 or 1 results that are a function of AffineDimExpr(dim).
87 static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
88  bool found = false;
89  for (AffineExpr e : map.getResults()) {
90  if (!e.isFunctionOfDim(dim))
91  continue;
92  if (found)
93  return false;
94  found = true;
95  }
96  return true;
97 }
98 #endif // NDEBUG
99 
100 /// Return the index of the first result of `map` that is a function of
101 /// AffineDimExpr(dim), or std::nullopt if there is none.
102 static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
103  int64_t dim) {
104  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
105  AffineExpr expr = map.getResult(i);
106  if (!expr.isFunctionOfDim(dim))
107  continue;
108  return i;
109  }
110  return std::nullopt;
111 }
112 
113 /// Perform one step of packing of a LinalgOp's metadata along `dim` into the
114 /// `newDim` at `iteratorTypes.size()` by:
115 /// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
116 /// 2. Appending a `newDim` to the domain of every indexing map.
117 /// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
118 /// by potentially adding a `newDim` result to `map`.
119 /// The preserved invariant is that `iteratorTypes.size()` is always equal to
120 /// `map.getNumDims()` for every map in `indexingMaps`.
121 ///
122 /// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
123 /// Return a vector that records the optional packing for each operand.
124 /// Return failure if the packed indexing cannot be represented with a LinalgOp.
125 ///
126 /// Further details:
127 /// ================
128 /// The current implementation of packing (i.e. data tiling) consists of
129 /// rewriting a linearized strip-mined form into a higher-dimensional access.
130 /// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
131 /// `I` into `4 * i + ii`, where `0 <= ii < 4`.
132 /// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
133 ///
134 /// This rewrite into a higher dimensional access is not possible for a general
135 /// AffineExpr in Linalg at the moment; it is restricted to an AffineDimExpr:
136 /// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
137 /// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
138 /// The rewrite of the access would be a form not representable in Linalg:
139 /// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
140 /// Note however that as `J` and `ii` iterate, the accesses do not have a
141 /// particular alignment, so packing does not achieve alignment in this case.
142 ///
143 /// In the future, we may want to consider a mixed-form that allows some
144 /// alignment in the presence of multiple accesses:
145 /// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
146 /// And would rewrite accesses as:
147 /// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
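/// For illustration, a sketch of one packing step on matmul-like indexing maps
/// (the concrete maps below are assumed, not taken from a caller): packing the
/// reduction dimension d2 of
///   (d0, d1, d2) -> (d0, d2)    // LHS
///   (d0, d1, d2) -> (d2, d1)    // RHS
///   (d0, d1, d2) -> (d0, d1)    // RES
/// appends a new dimension d3 and produces
///   (d0, d1, d2, d3) -> (d0, d2, d3)
///   (d0, d1, d2, d3) -> (d2, d1, d3)
///   (d0, d1, d2, d3) -> (d0, d1)
/// with iterator types extended to [parallel, parallel, reduction, reduction].
/// The RES map is untouched since none of its results is a function of d2, and
/// the returned per-operand packed dimensions are [1, 0, std::nullopt].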
148 static FailureOr<SmallVector<std::optional<int64_t>>>
149 packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
150                        SmallVectorImpl<utils::IteratorType> &iteratorTypes,
151                        int64_t dim) {
152  int64_t newDim = iteratorTypes.size();
153  iteratorTypes.push_back(iteratorTypes[dim]);
154 
155  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
156  indexingMaps.size(), std::nullopt);
157  SmallVector<AffineMap> newMaps;
158  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
159  ++operandIdx) {
160  AffineMap map = indexingMaps[operandIdx];
161 
162  // Add the `newDim` to map whatever the case.
163  assert(map.getNumDims() == newDim && "num dims invariant violation");
164  map = map.shiftDims(1, newDim);
165 
166  // Get the at-most-1 index of the result that is a function of `dim`.
167  // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
168  // logically chunks dimension `dim` into `K * dim + newDim`, where the
169  // packing factor `K` is specified separately.
170  assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
171  "num results invariant violation");
172  auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
173  if (!maybeOperandDimensionToPack.has_value()) {
174  newMaps.push_back(map);
175  continue;
176  }
177 
178  // We can only pack AffineDimExpr atm.
179  if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
180  return failure();
181 
182  // Add `newDim` to the results of the map.
183  map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
184  map.getNumResults());
185  newMaps.push_back(map);
186 
187   // Record that `operandIdx` is packed.
188  packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
189  }
190  indexingMaps = newMaps;
191 
192  return packedDimPerIndexingMap;
193 }
194 
195 namespace {
196 
197 /// Helper struct to encode packing along one dimension of a LinalgOp.
198 struct PackedOperandsDim {
199  OpFoldResult packedSize;
200  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
201 };
202 
203 /// Helper struct to encode packing along all dimensions of a LinalgOp.
204 struct PackedOperandsDimList {
205  void pushBack(PackedOperandsDim &&packedOperandsDims) {
206  spec.emplace_back(packedOperandsDims);
207  }
208  /// Return all the dims that have been packed for operand @ `operandPos`.
209  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
210  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
211  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
212 
213 private:
214   SmallVector<PackedOperandsDim> spec;
215 };
216 
217 } // namespace
218 
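// For illustration, a sketch of the lowering below on an assumed static
// example (%src, %dest and the shapes are placeholders):
//   %p = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//        into %dest : tensor<17x31xf32> -> tensor<3x2x8x16xf32>
// becomes, roughly:
//   %padded = tensor.pad %src low[0, 0] high[7, 1] {...}
//       : tensor<17x31xf32> to tensor<24x32xf32>
//   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
//       output_shape [3, 8, 2, 16] : tensor<24x32xf32> into tensor<3x8x2x16xf32>
//   %p = linalg.transpose ins(%expanded) outs(%dest) permutation = [0, 2, 1, 3]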
219 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
220  tensor::PackOp packOp,
221  bool lowerPadLikeWithInsertSlice) {
222  // 1. Filter out NYI cases.
223  auto packedTensorType =
224  cast<RankedTensorType>(packOp->getResultTypes().front());
225  if (llvm::any_of(packOp.getStaticInnerTiles(),
226  [](int64_t size) { return ShapedType::isDynamic(size); })) {
227  return rewriter.notifyMatchFailure(
228  packOp,
229  "non-static shape NYI, needs a more powerful tensor.expand_shape op");
230  }
231 
232  Location loc = packOp->getLoc();
233  OpBuilder::InsertionGuard g(rewriter);
234  rewriter.setInsertionPoint(packOp);
235 
236  // 2. Compute the permutation vector to shuffle packed shape into the shape
237  // before any outer or inner permutations have been applied.
238  PackingMetadata packingMetadata = computePackingMetadata(
239  packedTensorType.getRank(), packOp.getInnerDimsPos());
240   SmallVector<int64_t> packedToStripMinedShapePerm =
241       tensor::getPackInverseDestPerm(packOp);
242 
243  // 3. Compute the stripMinedShape: this is the packed shape before any outer
244  // or inner permutations have been applied.
245  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
246  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
247 
248  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
249  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
250  rewriter.getIndexAttr(0));
251  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
252  rewriter.getIndexAttr(0));
253  for (auto [pos, innerSize] :
254  llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
255  int outerPos =
256  packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
257  OpFoldResult origSize =
258  tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
259  OpFoldResult outerSize =
260  tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
261  AffineExpr s0, d0, d1;
262  bindDims(rewriter.getContext(), d0, d1);
263  bindSymbols(rewriter.getContext(), s0);
264  auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
265     highs[pos] = affine::makeComposedFoldedAffineApply(
266         rewriter, loc, map, {outerSize, origSize, innerSize});
267  }
268  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
269  RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
270  packingMetadata.reassociations);
271  Value paddingValue = packOp.getPaddingValue();
272  if (!paddingValue) {
273  paddingValue = rewriter.create<arith::ConstantOp>(
274  loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
275  }
276  auto padOp =
277  rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
278  highs, paddingValue, /*nofold=*/false);
279 
280  LLVM_DEBUG(
281  DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
282  DBGS() << "insertPositions: ");
283  DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
284  DBGS() << "outerPositions: ");
285  DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
286  DBGS() << "packedShape: ");
287  DBGSNL();
288  llvm::interleaveComma(packedToStripMinedShapePerm,
289  DBGS() << "packedToStripMinedShapePerm: ");
290  DBGSNL(); llvm::interleaveComma(
291  packingMetadata.reassociations, DBGS() << "reassociations: ",
292  [&](ReassociationIndices ri) {
293  llvm::interleaveComma(ri, llvm::dbgs() << "|");
294  });
295  DBGSNL();
296  llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
297  DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););
298 
299  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
300  // Pack ops which operate as simple pads may not produce legal
301  // tensor.insert_slice operations when the packed type does not rank reduce
302  // to the padded type.
303  SliceVerificationResult rankReduces =
304  isRankReducedType(packedTensorType, padOp.getResultType());
305 
306  if (rankReduces == SliceVerificationResult::Success) {
307  // This pack is just a plain pad.
308  // Just insert the pad in the higher ranked tensor.
309  // Offsets.
310  SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
311  rewriter.getIndexAttr(0));
312  // Strides.
313  SmallVector<OpFoldResult> ones(packOp.getDestRank(),
314  rewriter.getIndexAttr(1));
315     SmallVector<OpFoldResult> sizes =
316         tensor::getMixedSizes(rewriter, loc, packOp.getDest());
317 
318  auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
319  loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
320  /*offsets=*/zeros, sizes, /*strides=*/ones);
321 
322  LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););
323 
324  rewriter.replaceOp(packOp, insertSliceOp->getResults());
325 
326  return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
327  /*transposeOp=*/nullptr};
328  }
329  }
330 
331  // 5. Expand from the padded result to the stripMinedShape.
332  auto expandShapeResultType =
333  RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
334  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
335  loc, expandShapeResultType, padOp.getResult(),
336  packingMetadata.reassociations);
337 
338  // 6. Transpose stripMinedShape to packedShape.
339  SmallVector<int64_t> transpPerm =
340  invertPermutationVector(packedToStripMinedShapePerm);
341  auto transposeOp = rewriter.create<linalg::TransposeOp>(
342  loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
343 
344  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
345  DBGS() << "reshape op: " << reshapeOp; DBGSNL();
346  llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
347  DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););
348 
349  // 7. Replace packOp by transposeOp.
350  rewriter.replaceOp(packOp, transposeOp->getResults());
351 
352  return LowerPackResult{padOp, reshapeOp, transposeOp};
353 }
354 
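// For illustration, a sketch of the lowering below, mirroring the lowerPack
// example above (%src, %dest and the shapes are placeholders):
//   %u = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//        into %dest : tensor<3x2x8x16xf32> -> tensor<17x31xf32>
// becomes, roughly:
//   %empty = tensor.empty() : tensor<3x8x2x16xf32>
//   %tr = linalg.transpose ins(%src) outs(%empty) permutation = [0, 2, 1, 3]
//   %collapsed = tensor.collapse_shape %tr [[0, 1], [2, 3]]
//       : tensor<3x8x2x16xf32> into tensor<24x32xf32>
//   %slice = tensor.extract_slice %collapsed[0, 0] [17, 31] [1, 1]
//   %u = linalg.copy ins(%slice) outs(%dest)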
355 FailureOr<LowerUnPackOpResult>
356 linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
357  bool lowerUnpadLikeWithExtractSlice) {
358  Location loc = unPackOp->getLoc();
359  OpBuilder::InsertionGuard g(rewriter);
360  rewriter.setInsertionPoint(unPackOp);
361 
362  RankedTensorType packedTensorType = unPackOp.getSourceType();
363  int64_t packedRank = packedTensorType.getRank();
364 
365  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
366  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
367  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
368  // This unpack is just a plain unpad.
369  // Just extract the slice from the higher ranked tensor.
370  ArrayRef<int64_t> destShape = destTensorType.getShape();
371  // The inner dimensions stay the same as the destination tensor, but the
372  // outer ones are additional 1s.
373  SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
374  sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
375 
376  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
377  loc, destTensorType, unPackOp.getSource(),
378  SmallVector<OpFoldResult>(packedRank, zero), sizes,
379  SmallVector<OpFoldResult>(packedRank, one));
380 
381  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
382 
383  return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
384  /*reshapeOp=*/nullptr, extractSliceOp};
385  }
386 
387  // 1. Compute the permutation vector to shuffle packed shape into the shape
388  // before any outer or inner permutations have been applied.
389  PackingMetadata packingMetadata;
390  SmallVector<int64_t> packedToStripMinedShapePerm =
391  tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);
392 
393  // 2. Compute the stripMinedShape: this is the packed shape without outer and
394  // inner permutations.
395  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
396  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
397 
398  // 3. Transpose packedShape to stripMinedShape.
399  RankedTensorType stripMinedTensorType =
400  RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
401  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
402  stripMinedTensorType, packingMetadata.reassociations);
403 
404  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
405  // permutation.
406   SmallVector<OpFoldResult> dims =
407       tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
408  applyPermutationToVector(dims, packedToStripMinedShapePerm);
409  auto emptyOp = rewriter.create<tensor::EmptyOp>(
410  loc, dims, stripMinedTensorType.getElementType());
411  auto transposeOp = rewriter.create<linalg::TransposeOp>(
412  loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
413 
414  LLVM_DEBUG(
415  DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
416  DBGS() << "insertPositions: ");
417  DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
418  DBGS() << "packedShape: ");
419  DBGSNL();
420  llvm::interleaveComma(packedToStripMinedShapePerm,
421  DBGS() << "packedToStripMinedShapePerm: ");
422  DBGSNL(); llvm::interleaveComma(
423  packingMetadata.reassociations, DBGS() << "reassociations: ",
424  [&](ReassociationIndices ri) {
425  llvm::interleaveComma(ri, llvm::dbgs() << "|");
426  });
427  DBGSNL();
428  llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
429  DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););
430 
431  // 4. Collapse from the stripMinedShape to the padded result.
432  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
433  loc, collapsedType, transposeOp->getResult(0),
434  packingMetadata.reassociations);
435 
436  // 5. ExtractSlice.
437  int64_t destRank = destTensorType.getRank();
438  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
439  loc, destTensorType, reshapeOp->getResult(0),
440  SmallVector<OpFoldResult>(destRank, zero),
441  tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
442  SmallVector<OpFoldResult>(destRank, one));
443 
444  // 6. Inject a copy to preserve DPS.
445  auto copyOp = rewriter.create<linalg::CopyOp>(
446  loc, extractSliceOp->getResult(0), unPackOp.getDest());
447 
448  // 7. Replace unPackOp by copyOp.
449  rewriter.replaceOp(unPackOp, copyOp->getResults());
450 
451  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
452 }
453 
454 SmallVector<int64_t>
455 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
456   SmallVector<int64_t> res;
457   for (auto &i : spec) {
458  if (!i.packedDimForEachOperand[operandPos].has_value())
459  continue;
460  res.push_back(i.packedDimForEachOperand[operandPos].value());
461  }
462  return res;
463 }
464 
465 SmallVector<OpFoldResult>
466 PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
467   SmallVector<OpFoldResult> res;
468   for (auto &i : spec) {
469  if (!i.packedDimForEachOperand[operandPos].has_value())
470  continue;
471  res.push_back(i.packedSize);
472  }
473  return res;
474 }
475 
476 /// Implement packing of a single LinalgOp by `packedSizes`. There must be
477 /// one packedSizes entry per `linalgOp` iterator (i.e. one per loop).
478 /// Return the packed Linalg op on success, failure otherwise.
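/// For illustration, a sketch on linalg.matmul with assumed packedSizes
/// {2, 3, 4} for the (m, n, k) loops: each operand is packed along the loop
/// dimensions it indexes, roughly
///   LHS: tensor.pack ... inner_dims_pos = [0, 1] inner_tiles = [2, 4]
///   RHS: tensor.pack ... inner_dims_pos = [1, 0] inner_tiles = [3, 4]
///   RES: tensor.pack ... inner_dims_pos = [0, 1] inner_tiles = [2, 3]
/// and the computation becomes a linalg.generic over 6 loops with iterator
/// types [parallel, parallel, reduction, parallel, parallel, reduction] and
/// indexing maps
///   (d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)
///   (d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)
///   (d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)
/// followed by a tensor.unpack of the packed init back to the original shape.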
479 FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
480  linalg::LinalgOp linalgOp,
481  ArrayRef<OpFoldResult> packedSizes) {
482  if (packedSizes.size() != linalgOp.getNumLoops()) {
483  return rewriter.notifyMatchFailure(linalgOp,
484  "incorrect number of pack sizes");
485  }
486 
487  Location loc = linalgOp->getLoc();
488  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
489  SmallVector<utils::IteratorType> iteratorTypes =
490  linalgOp.getIteratorTypesArray();
491  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
492  llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
493  llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
494  DBGSNL(););
495 
496   SmallVector<tensor::PackOp> packOps;
497   SmallVector<tensor::UnPackOp> unPackOps;
498   // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
499  PackedOperandsDimList listOfPackedOperandsDim;
500  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
501  std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
502  // Skip tile sizes explicitly set to 0.
503  if (maybeConstant.has_value() && maybeConstant.value() == 0)
504  continue;
505 
506  PackedOperandsDim packedOperandsDims;
507  packedOperandsDims.packedSize = packedSizes[i];
508  FailureOr<SmallVector<std::optional<int64_t>>>
509  maybePackedDimForEachOperand =
510  packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
511  if (failed(maybePackedDimForEachOperand))
512  return failure();
513  packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
514  listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
515 
516  LLVM_DEBUG(
517  DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
518  << "\n";
519  llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
520  llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
521  llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
522  DBGS() << "packedDimForEachOperand: ");
523  DBGSNL(););
524  }
525 
526  // Step 2. Propagate packing to all LinalgOp operands.
527  SmallVector<Value> inputsAndInits, results;
528  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
529  linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
530  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
531  for (const auto &operandsList : {inputOperands, initOperands}) {
532  for (OpOperand *opOperand : operandsList) {
533  int64_t pos = opOperand->getOperandNumber();
534  Value operand = opOperand->get();
535  SmallVector<int64_t> innerPos =
536  listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
537  SmallVector<OpFoldResult> innerPackSizes =
538  listOfPackedOperandsDim.extractPackSizesForOperand(pos);
539  LLVM_DEBUG(
540  DBGS() << "operand: " << operand << "\n";
541  llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
542  llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
543  DBGSNL(););
544  if (innerPackSizes.empty()) {
545  inputsAndInits.push_back(operand);
546  continue;
547  }
548  Value dest = tensor::PackOp::createDestinationTensor(
549  rewriter, loc, operand, innerPackSizes, innerPos,
550  /*outerDimsPerm=*/{});
551  ShapedType operandType = cast<ShapedType>(operand.getType());
552  bool areConstantTiles =
553  llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
554  return getConstantIntValue(tile).has_value();
555  });
556  if (areConstantTiles && operandType.hasStaticShape() &&
557  !tensor::PackOp::requirePaddingValue(
558  operandType.getShape(), innerPos,
559  cast<ShapedType>(dest.getType()).getShape(), {},
560  innerPackSizes)) {
561  packOps.push_back(rewriter.create<tensor::PackOp>(
562  loc, operand, dest, innerPos, innerPackSizes));
563  } else {
564  // TODO: value of the padding attribute should be determined by
565  // consumers.
566  auto zeroAttr =
567  rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
568  Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
569  packOps.push_back(rewriter.create<tensor::PackOp>(
570  loc, operand, dest, innerPos, innerPackSizes, zero));
571  }
572  inputsAndInits.push_back(packOps.back());
573  }
574  }
575 
576  // Step 3. Build the packed op, use the type of `inits` as result types.
577  ValueRange inputs =
578  ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
579  ValueRange inits =
580  ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
581  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
582  linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
583  iteratorTypes);
584  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
585 
586  // Step 4. Propagate packing to all the op results.
587  for (OpResult result : packedLinalgOp->getResults()) {
588  int64_t resultNum = result.getResultNumber();
589  tensor::PackOp maybePackedInit =
590  inits[resultNum].getDefiningOp<tensor::PackOp>();
591  if (!maybePackedInit) {
592  results.push_back(result);
593  continue;
594  }
595  // Build the symmetrical UnPackOp to the existing PackOp.
596  unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
597  packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
598  maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
599  results.push_back(unPackOps.back());
600  }
601 
602  // Step 5. Replace `linalgOp`.
603  rewriter.replaceOp(linalgOp, results);
604 
605  // Return packedLinalgOp.
606  return PackResult{packOps,
607  cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
608  unPackOps};
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // packTranspose transformation.
613 //===----------------------------------------------------------------------===//
614 
615 /// Return a copy of `tensorType` after permutation by `permutationVector`.
616 // Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder
617 // but this would introduce a dependence on Dialect in IR.
618 // TODO: Restructure.
619 static RankedTensorType permuteShape(RankedTensorType tensorType,
620  ArrayRef<int64_t> permutationVector) {
621  SmallVector<int64_t> shape(tensorType.getShape());
622  applyPermutationToVector(shape, permutationVector);
623  return RankedTensorType::Builder(tensorType).setShape(shape);
624 }
625 
626 /// Return a new GenericOp obtained by transposing opOperand by the permutation
627 /// vector:
628 /// - the corresponding indexing map is transposed by `permutation`
629 /// - the corresponding operand value is replaced by `transposedValue`
630 /// `linalgOp` is replaced by the return op in the process.
631 /// Asserts that `transposedValue` is of the proper transposed ShapedType.
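/// For illustration, a sketch with assumed inputs: with permutation [1, 0] on
/// a rank-2 operand whose matching indexing map is (d0, d1, d2) -> (d0, d2),
/// the permutation map (d0, d1) -> (d1, d0) is composed on top of it and the
/// transposed indexing map becomes (d0, d1, d2) -> (d2, d0).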
632 static LinalgOp transposeOneLinalgOperandAndReplace(
633     RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
634  ArrayRef<int64_t> permutation, Value transposedValue) {
635  // Sanity check the operand.
636  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
637 
638  // Sanity check of the expected transposed tensor type.
639  auto tensorType = permuteShape(
640  cast<RankedTensorType>(opOperand.get().getType()), permutation);
641  (void)tensorType;
642  assert(tensorType == transposedValue.getType() &&
643  "expected tensor type mismatch");
644 
645  // Compute the transposed indexing map.
646  // Sigh unsigned pollution.
647  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
648  llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
649  AffineMap permutationMap =
650  AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
651  AffineMap transposedMap =
652  permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
653 
654  // Set the transposed indexing map in the proper position.
655  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
656  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
657  // Set the transposedValue in the proper operand position.
658  SmallVector<Value> operands = linalgOp->getOperands();
659  operands[opOperand.getOperandNumber()] = transposedValue;
660 
661  ValueRange operandsRef(operands);
662  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
663  /*location=*/linalgOp->getLoc(),
664  /*resultTensorTypes=*/
665  operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
666  /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
667  /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
668  /*indexingMaps=*/indexingMaps,
669  /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
670  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
671  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
672 
673  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
674 }
675 
676 FailureOr<PackTransposeResult>
677 linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
678  linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
679  ArrayRef<int64_t> outerPerm,
680  ArrayRef<int64_t> innerPerm) {
681  Location loc = linalgOp.getLoc();
682 
683  // Step 1. Transpose packOp.
684  rewriter.setInsertionPoint(packOp);
685  tensor::PackOp transposedPackOp =
686  packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
687 
688  if (!packOp.getResult().hasOneUse())
689  return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
690 
691  OpOperand &packUse = *packOp->getUses().begin();
692  if (packUse.getOwner() != linalgOp) {
693  return rewriter.notifyMatchFailure(
694  linalgOp, "not a single use by the LinalgOp target");
695  }
696  if (maybeUnPackOp &&
697  (!linalgOp.isDpsInit(&packUse) ||
698  maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
699  return rewriter.notifyMatchFailure(linalgOp,
700  "not produced by the LinalgOp target");
701  }
702 
703  // Step 2. Transpose linalgOp.
704  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
705  // identity. Don't rely on it.
706  int64_t numLeadingDims = packOp.getSourceRank();
707  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
708  // Step 2.a. Compute the permutation on the whole operand.
709  // Leading part just reuse the outerPerm.
710  SmallVector<int64_t> permutation(outerPerm);
711  if (permutation.empty())
712  llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
713  // Trailing part needs to reindex positions by `numLeadingDims`.
714  if (innerPerm.empty()) {
715  llvm::append_range(
716  permutation,
717  llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
718  } else {
719  llvm::append_range(permutation,
720  llvm::map_range(innerPerm, [&](int64_t pos) {
721  return numLeadingDims + pos;
722  }));
723  }
724  if (!isPermutationVector(permutation))
725  return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
726 
727  // Step 2.b. Save the transposedPackUse operand number in case we need to
728  // get the tied OpResult after `linalgOp` has been replaced.
729  int64_t packUseOperandNumber = packUse.getOperandNumber();
730  // Step 2.c. Actually perform the transposition.
731  rewriter.setInsertionPoint(linalgOp);
732  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
733  rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
734 
735  // Step 3. Maybe transpose unPackOp.
736  tensor::UnPackOp transposedUnPackOp;
737  if (maybeUnPackOp) {
738  OpOperand &opOperand =
739  transposedLinalgOp->getOpOperand(packUseOperandNumber);
740  OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
741  rewriter.setInsertionPoint(maybeUnPackOp);
742  transposedUnPackOp = maybeUnPackOp.createTransposedClone(
743  rewriter, loc, transposedResult, innerPerm, outerPerm);
744 
745  rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
746  }
747 
748  // Step 4. Finally, replace packOp now that we don't need it anymore.
749  rewriter.replaceOp(packOp, transposedPackOp->getResults());
750 
751  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
752  transposedUnPackOp};
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // packMatmulGreedily transformation.
757 //===----------------------------------------------------------------------===//
758 
759 /// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
760 /// and n are proper parallel dimensions and k is a proper reduction
761 /// dimension. Packing occurs by rewriting the op as a linalg.generic and
762 /// calling linalg::pack with `mnkPackedSizes`. The order of the packed
763 /// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
764 /// used to reorder {m, n, k} into one of the 6 possible forms. The outer
765 /// dimensions of the operands are not permuted at this time; this is left for
766 /// future work.
767 FailureOr<PackResult>
768 linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
769  ArrayRef<OpFoldResult> mnkPackedSizes,
770  ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
771  ArrayRef<int64_t> mnkOrder) {
772  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
773  assert((mnkPaddedSizesNextMultipleOf.empty() ||
774  mnkPaddedSizesNextMultipleOf.size() == 3) &&
775  "num of packing sizes next multiple should be empty or of size 3");
776  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
777  assert(isPermutationVector(mnkOrder) && "expected a permutation");
778 
779  int64_t numLoops = linalgOp.getNumLoops();
780  if (numLoops <= 2) {
781  LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
782  << numLoops << "\nin: " << linalgOp << "\n");
783  return rewriter.notifyMatchFailure(
784  linalgOp, "need 3+ loops to find a matmul to pack");
785  }
786 
787  // Locally adjust the desired iterator position of mnk and packing sizes.
788  int64_t numPackedDims = mnkPackedSizes.size();
789  SmallVector<int64_t> mmnnkkPos(numPackedDims);
790  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
791  mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
792  SmallVector<OpFoldResult> packedSizes(numPackedDims);
793  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
794  packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
795  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
796  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
797  paddedSizesNextMultipleOf[mnkOrder[i]] =
798  mnkPaddedSizesNextMultipleOf.empty() ? 0
799  : mnkPaddedSizesNextMultipleOf[i];
800  }
801 
802  // 1. Infer dims that are important for matmul.
803  FailureOr<ContractionDimensions> maybeDimensions =
804  inferContractionDims(linalgOp);
805  if (failed(maybeDimensions)) {
806  LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
807  << "\n");
808  return rewriter.notifyMatchFailure(linalgOp,
809  "couldn't infer matmul iterators");
810  }
811 
812   // 2. Normalize linalgOp to a kmn-matmul-like op with [red, par, par] most
813   // minor iterators. In cases with multiple options for m, n, k, bias towards
814   // the most minor embedding.
815   // If we wanted a different normalization order, this is where a heuristic
816   // would have to be plugged in.
817  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
818  kPos = maybeDimensions->k.back();
819  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
820  DBGS() << "Start packing generic op greedily with (m@" << mPos
821  << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
822  << "\n";);
823 
824  // 2.a. Rewrite as a generic.
825  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
826  if (!genericOp) {
827  FailureOr<GenericOp> generalizeResult =
828  generalizeNamedOp(rewriter, linalgOp);
829  assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
830  genericOp = *generalizeResult;
831  }
832 
833  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
834   // iterators. Note that this only normalizes the iteration order and does
835  // not change the indexings of any operand.
836  SmallVector<int64_t> permutation =
837  computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
838  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
839  // Sign .. unsigned pollution.
840  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
841  FailureOr<GenericOp> interchangeResult =
842  interchangeGenericOp(rewriter, genericOp, unsignedPerm);
843  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
844  genericOp = *interchangeResult;
845  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
846 
847  // At this point, the op iterators are normalized to {leading, k, m, n}.
848  // The layouts induced by packing will always be:
849  // - LHS{leading_lhs, kk, mm}
850  // - RHS{leading_rhs, kk, nn}
851  // - RES{leading_res, mm, nn}
852  // If we wanted to change the packed order, we would reorder (k, m, n) to
853  // something else above.
854  //
855  // Additional permutations of the outer dims of the operands (i.e.
856  // leading_lhs, leading_rhs and leading_res) could follow by computing the
857  // desired outerPerm for each operand.
858  // This is left for future work.
859 
860  // TODO: this creates too much IR, go use reifyResultShapes.
861  SmallVector<Range, 4> loopRanges =
862  cast<LinalgOp>(genericOp.getOperation())
863  .createLoopRanges(rewriter, genericOp.getLoc());
864 
865  // Add leading zeros to match numLoops, we only pack the last 3 dimensions
866  // post interchange.
867  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
868  DBGS() << "paddedSizesNextMultipleOf: ");
869  DBGSNL(););
870  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
871  [](Range r) { llvm::dbgs() << r.size; });
872  DBGSNL(););
873  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
874  rewriter.getIndexAttr(0));
875  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
876  if (paddedSizesNextMultipleOf[i] == 0) {
877  adjustedPackedSizes.push_back(packedSizes[i]);
878  continue;
879  }
880  AffineExpr d0, s0;
881  bindDims(rewriter.getContext(), d0);
882  bindSymbols(rewriter.getContext(), s0);
883  adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
884  rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
885  {loopRanges[adjustedPackedSizes.size()].size,
886  rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
887  }
888  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
889  DBGS() << "adjustedPackedSizes: ");
890  DBGSNL(););
891 
892  // TODO: If we wanted to give the genericOp a name after packing, after
893  // calling `pack` would be a good time. One would still need to check that
894  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
895  // also allow degenerate matmul cases (i.e. matvec, dot).
896  return pack(rewriter, genericOp, adjustedPackedSizes);
897 }
898 
899 //===----------------------------------------------------------------------===//
900 // Transformations exposed as rewrite patterns.
901 //===----------------------------------------------------------------------===//
902 
903 LinalgTilingOptions &
904 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
905   assert(!tileSizeComputationFunction && "tile sizes already set");
906  SmallVector<int64_t, 4> tileSizes(ts);
907  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
908  OpBuilder::InsertionGuard guard(b);
909     b.setInsertionPointToStart(
910         &op->getParentOfType<func::FuncOp>().getBody().front());
911  return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
912  Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
913  return v;
914  }));
915  };
916  return *this;
917 }
918 
919 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
920     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
921  return vectorizeCopy(rewriter, copyOp);
922 }
923 
924 /// Fill `dest` using a FillOp with the constant padding value if possible.
925 /// Otherwise, generate a tensor::GenerateOp.
927  RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
928  const SmallVector<Value> &dynSizes) const {
929  auto padValue = padOp.getConstantPaddingValue();
930  if (padValue)
931  return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
932 
933  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
934  auto generateOp = rewriter.create<tensor::GenerateOp>(
935  padOp.getLoc(), padOp.getResultType(), dynSizes);
936  // Copy region to new op.
937  IRMapping bvm;
938  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
939  return generateOp;
940 }
941 
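// For illustration, a sketch of the rewrite below on an assumed static
// example (%src and %cst are placeholders):
//   %0 = tensor.pad %src low[1, 2] high[2, 3] { tensor.yield %cst : f32 }
//        : tensor<4x5xf32> to tensor<7x10xf32>
// becomes, roughly:
//   %empty = tensor.empty() : tensor<7x10xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<7x10xf32>)
//   %0 = tensor.insert_slice %src into %fill[1, 2] [4, 5] [1, 1]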
942 LogicalResult
944  PatternRewriter &rewriter) const {
945  // Given an OpFoldResult, return an index-typed value.
946  auto getIdxValue = [&](OpFoldResult ofr) {
947  if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
948  return val;
949  return rewriter
950         .create<arith::ConstantIndexOp>(
951             padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
952  .getResult();
953  };
954 
955  auto resultType = padOp.getResultType();
956  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
957  SmallVector<Value> dynSizes;
958  SmallVector<int64_t> staticSizes;
959  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
960  if (resultType.isDynamicDim(dim)) {
961  auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
962  padOp.getSource(), dim));
963  // Add low and high padding value.
964  auto plusLow = rewriter.createOrFold<arith::AddIOp>(
965  padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
966  auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
967  padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
968  dynSizes.push_back(plusHigh);
969  }
970  staticSizes.push_back(resultType.getDimSize(dim));
971  }
972 
973  // Init tensor and fill it with padding.
974  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
975  padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
976  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
977 
978   // Generate an InsertSliceOp for copying the PadOp source.
979  auto sourceType = padOp.getSourceType();
980  // Compute size of source of tensor::PadOp.
981  SmallVector<OpFoldResult> srcSizes =
982  tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
983  // Strides of InsertSliceOp are all 1.
984  SmallVector<OpFoldResult> strides(sourceType.getRank(),
985  rewriter.getIndexAttr(1));
986  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
987  padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
988  strides);
989 
990  return success();
991 }
992 
993 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
994     tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
995  if (!sliceOp.hasUnitStride())
996  return failure();
997 
998  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
999  if (!padOp)
1000  return failure();
1001 
1002  bool zeroSliceGuard = true;
1003  if (controlFn) {
1004  if (std::optional<bool> control = controlFn(sliceOp))
1005  zeroSliceGuard = *control;
1006  else
1007  return failure();
1008  }
1009 
1010  FailureOr<TilingResult> tilingResult =
1011  tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
1012  sliceOp.getMixedSizes(), zeroSliceGuard);
1013  if (failed(tilingResult))
1014  return failure();
1015  // All shapes are static and the data source is actually used. Rewrite into
1016  // pad(extract_slice(x)).
1017  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
1018  return success();
1019 }
1020 
1021 /// If padding value is set, returns a tensor.pad Op for the source tensor,
1022 /// with the output shape matching the output of `packOp`. Otherwise, returns
1023 /// the source directly.
1024 ///
1025 /// This method assumes that all outer dims for this pack Op are 1.
1026 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1027                                            tensor::PackOp packOp) {
1028  Value input = packOp.getSource();
1029  if (!packOp.getPaddingValue()) {
1030  return input;
1031  }
1032 
1033  assert(llvm::all_of(packOp.getAllOuterDims(),
1034  [](int64_t val) { return val == 1; }) &&
1035  "some outer dims are != 1");
1036 
1037  Location loc = packOp.getLoc();
1038  ShapedType inputType = packOp.getSourceType();
1039  int64_t inputRank = inputType.getRank();
1040 
1041  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
1042  packOp.getDimAndTileMapping();
1043 
1044  // The sizes of dynamic tiles
1045  SmallVector<Value> dynamicTileSizes;
1046 
1047  // Collect dims for the padded shape.
1048  SmallVector<int64_t> paddedShape;
1049  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1050  // 1. Non-tiled outer dims.
1051  // These dims should be 1 and we simply preserve them.
1052  if (!tileAndPosMapping.count(dimIdx)) {
1053  int64_t inputDimSize = inputType.getDimSize(dimIdx);
1054  assert(inputDimSize == 1 &&
1055  "with all outer dims == 1, this non-tiled input dim should be 1!");
1056  paddedShape.push_back(inputDimSize);
1057  continue;
1058  }
1059 
1060  // 2. Tiled outer dims
1061  // As all outer dims == 1, it is safe to use the tile size for the padded
1062  // shape.
1063  OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1064 
1065  // 2.1 Static tile sizes
1066  std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
1067  if (cstTileSize.has_value()) {
1068  paddedShape.push_back(cstTileSize.value());
1069  continue;
1070  }
1071 
1072  // 2.2 Dynamic tile sizes
1073  paddedShape.push_back(ShapedType::kDynamic);
1074 
1075  // Get the value that holds the dynamic size.
1076  dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1077  }
1078  auto resultType =
1079  RankedTensorType::get(paddedShape, inputType.getElementType());
1080  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
1081  /*nofold=*/false, loc, builder,
1082  dynamicTileSizes);
1083 }
1084 
1085 // Normalizes a permutation on a higher rank space to its actual size, e.g.
1086 // perm = [1, 4, 2]
1087 // becomes
1088 // norm = [0, 2, 1]
1089 static SmallVector<int64_t>
1090 getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
1091   constexpr int64_t kNonTiledMarker = -1;
1092  SmallVector<int64_t> vec(rank, kNonTiledMarker);
1093  for (auto [index, value] : llvm::enumerate(perm))
1094  vec[value] = index;
1095  SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
1096  vec, [&](int64_t v) { return v != kNonTiledMarker; });
1097  // This inverts the permutation in addition to normalizing so invert back.
1098  return invertPermutationVector(normalizedPerm);
1099 }
1100 
1101 // Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
1102 // assuming rank reduction of unit outer dims.
1103 static SmallVector<int64_t>
1104 getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1105                              ArrayRef<int64_t> innerDimsPos,
1106  ArrayRef<int64_t> outerDimsPerm) {
1107  SmallVector<int64_t> rankReducedOuterDimsPerm;
1108  SmallVector<int64_t> outerDims;
1109  SmallVector<int64_t> innerDims;
1110  int64_t dim = 0;
1111  int64_t unpackedRank = shape.size();
1112  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1113  if (llvm::is_contained(innerDimsPos, i)) {
1114  innerDims.push_back(dim++);
1115  continue;
1116  }
1117  if (shape[i] == 1)
1118  continue;
1119  outerDims.push_back(dim++);
1120  if (!outerDimsPerm.empty())
1121  rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1122  }
1123 
1124  // Get the position of the inner dims after permutation.
1125  SmallVector<int64_t> innerPerm =
1126  getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
1127  applyPermutationToVector<int64_t>(innerDims, innerPerm);
1128 
1129  // Ditto for the outer dims.
1130  SmallVector<int64_t> perm = outerDims;
1131 
1132  rankReducedOuterDimsPerm =
1133  getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
1134  if (!rankReducedOuterDimsPerm.empty())
1135  applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1136 
1137  // The tile always ends up as the inner most dims after packing.
1138  perm.append(innerDims);
1139 
1140  return perm;
1141 }
1142 
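// For illustration, a sketch of the rewrite below on an assumed static
// example with all outer dims equal to 1 (%src, %dest, %cst are placeholders):
//   %0 = tensor.pack %src padding_value(%cst : f32)
//        inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//        into %dest : tensor<5x10xf32> -> tensor<1x1x8x16xf32>
// becomes, roughly:
//   %padded = tensor.pad %src low[0, 0] high[3, 6] {...}
//       : tensor<5x10xf32> to tensor<8x16xf32>
//   %empty = tensor.empty() : tensor<8x16xf32>
//   %tr = linalg.transpose ins(%padded) outs(%empty) permutation = [0, 1]
//       (an identity permutation here, since inner_dims_pos is already in order)
//   %0 = tensor.insert_slice %tr into %dest[0, 0, 0, 0] [1, 1, 8, 16]
//        [1, 1, 1, 1]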
1143 LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1144     tensor::PackOp packOp, PatternRewriter &rewriter) const {
1145  // TODO: support the case that outer dimensions are not all 1s. A
1146  // tensor.expand_shape will be generated in this case.
1147  if (llvm::any_of(packOp.getAllOuterDims(),
1148  [](int64_t dim) { return dim != 1; })) {
1149  return rewriter.notifyMatchFailure(
1150  packOp, "not all outer dimensions of the result are 1s");
1151  }
1152 
1153  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1154  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1155  Location loc = packOp.getLoc();
1156 
1157  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1158  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1159  packOp.getDimAndTileMapping();
1160  int64_t srcRank = packOp.getSourceRank();
1161  int64_t destRank = packOp.getDestRank();
1162  int64_t numTiles = destRank - srcRank;
1163 
1164  if (!llvm::all_of(packOp.getInnerDimsPos(),
1165  [&srcRank, &numTiles](int64_t dimPos) {
1166  return dimPos >= (srcRank - numTiles - 1);
1167  }))
1168  return rewriter.notifyMatchFailure(
1169  packOp, "Attempting to tile non-trailing source dims!");
1170 
1171  // 1. Extract the inner tile sizes.
1172  // Where possible, values are replaced with constant attributes (to match the
1173  // behaviour of `getPackOpSourceOrPaddedSource`).
1174  SmallVector<OpFoldResult> tileSizes;
1175  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
1176  if (dimAndTileMapping.count(i)) {
1177       // Rather than taking the tile size as is, extract the actual constant
1178  // value Attribute where possible, e.g.:
1179  // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
1180  auto [_, tileSize] =
1181  getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
1182  tileSizes.push_back(tileSize);
1183  }
1184  }
1185 
1186  // 2. Transpose the input to match the inner tile order:
1187  // %init = tensor.empty()
1188  // %transposed_tile = linalg.transpose ins(%source_or_padded_source),
1189  // outs(%init)
1190  // Two assumptions are made:
1191  // 1. All outer dims are 1 - the corresponding transposition doesn't matter.
1192  // 2. Inner dims position correspond to the trailing `numTiles` dims.
1193  SmallVector<int64_t> tilesPermNormalized =
1194  getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
1195  SmallVector<int64_t> srcPermForTranspose;
1196  for (int64_t i = 0; i < (srcRank - numTiles); i++)
1197  srcPermForTranspose.push_back(i);
1198 
1199  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
1200 
1201  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
1202  llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: ");
1203  DBGSNL(););
1204 
1205  // 2.1 Create tensor.empty (init value for TransposeOp)
1206  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
1207  oneIdxAttr);
1208  transShapeForEmptyOp.append(tileSizes);
1209 
1210  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
1211  srcPermForTranspose);
1212  Value empty = rewriter.create<tensor::EmptyOp>(
1213  loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());
1214 
1215  // 2.2 Create linalg.transpose
1216  auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
1217  srcPermForTranspose);
1218 
1219  // 3. Insert the inner tile to the destination:
1220  // %inserted_tile = tensor.insert_slice(%transposed_tile)
1221  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1222  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1223  // Outer dims are all 1s!
1224  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
1225  oneIdxAttr);
1226  SmallVector<int64_t> writeShape;
1227 
1228  for (auto tileSize : packOp.getMixedTiles()) {
1229  auto [tileSizeStatic, tileSizeOfr] =
1230  getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
1231  writeSizes.push_back(tileSizeOfr);
1232  writeShape.push_back(tileSizeStatic);
1233  }
1234 
1235  // 4. Replace tensor.packOp with tensor.insert_slice created above
1236  auto insert = rewriter.create<tensor::InsertSliceOp>(
1237  loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
1238  writeSizes, writeStrides);
1239  rewriter.replaceOp(packOp, insert.getResult());
1240 
1241  return success();
1242 }
1243 
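// For illustration, a sketch of the rewrite below on an assumed static
// example with all tiled outer dims equal to 1 (%src, %dest are placeholders):
//   %0 = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//        into %dest : tensor<1x1x8x16xf32> -> tensor<5x10xf32>
// becomes, roughly:
//   %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 8, 16] [1, 1, 1, 1]
//       : tensor<1x1x8x16xf32> to tensor<8x16xf32>
//   %empty = tensor.empty() : tensor<8x16xf32>
//   %tr = linalg.transpose ins(%tile) outs(%empty) permutation = [0, 1]
//   %sub = tensor.extract_slice %tr[0, 0] [5, 10] [1, 1]
//   %0 = tensor.insert_slice %sub into %dest[0, 0] [5, 10] [1, 1]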
1244 LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
1245     tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1246  int64_t srcRank = unpackOp.getSourceRank();
1247  int64_t destRank = unpackOp.getDestRank();
1248  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1249  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1250  if (llvm::any_of(unpackOp.getTiledOuterDims(),
1251  [](int64_t dim) { return dim != 1; })) {
1252  return rewriter.notifyMatchFailure(
1253  unpackOp,
1254  "require the tiled outer dimensions of the result are all 1s");
1255  }
1256 
1257  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1258  // %extracted_tile = tensor.extract_slice(%unpack_op_input)
1259  Location loc = unpackOp.getLoc();
1260  Value source = unpackOp.getSource();
1261  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1262  unpackOp.getDimAndTileMapping();
1263  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1264  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1265 
1266  // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1267  // dims:
1268  // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1269  SmallVector<int64_t> readShapeForExtractSlice;
1270  // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1271  // outer-tiled-dims being all 1), this will be
1272  // [ outer-untiled-dims, tile-sizes ]
1273  SmallVector<OpFoldResult> extractSliceSizes;
1274  // The offset and strides attributes for ExtractSliceOp.
1275  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
1276  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
1277 
1278  // Shape for EmptyOp that's used as the init value for TransposeOp below.
1279  // This should be:
1280  // [ outer-untiled-dims, tile-sizes ]
1281  // However, skip unit dims - TransposeOp (below) applies rank-reduced
1282  // permutation.
1283  SmallVector<OpFoldResult> shapeForEmptyOp;
1284 
1285  for (auto i : llvm::seq<unsigned>(0, destRank)) {
1286  // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1287  //
1288     // As all outer tiled dims are 1, the corresponding
1289     // slice size to read will also be 1. As this will be a rank-reducing "extract
1290  // slice" (i.e. the unit dims will be "collapsed"), there's no need to
1291  // update:
1292  // * the output shape for ExtractSliceOp, nor
1293  // * the shape for EmptyOp.
1294  if (dimAndTileMapping.count(i)) {
1295  extractSliceSizes.push_back(oneIdxAttr);
1296  continue;
1297  }
1298 
1299  // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1300  // outer-untiled-dims
1301  if (ShapedType::isDynamic(srcShape[i])) {
1302  OpFoldResult dynamicDim =
1303  rewriter.create<tensor::DimOp>(loc, source, i).getResult();
1304  extractSliceSizes.push_back(dynamicDim);
1305  shapeForEmptyOp.push_back(dynamicDim);
1306  } else {
1307  extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1308  if (srcShape[i] != 1)
1309  shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1310  }
1311  // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1312  // into account rank-reducing)
1313  if (srcShape[i] != 1) {
1314  readShapeForExtractSlice.push_back(srcShape[i]);
1315  }
1316  }
1317  // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1318  // shape for EmptyOp.
1319  auto mixedTiles = unpackOp.getMixedTiles();
1320  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1321  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1322 
1323  // Explicitly create the type for extract_slice op because the inner tile
1324  // size could be 1. We want to represent the whole inner tile in this case.
1325  auto tileShape = srcShape.drop_front(destRank);
1326  // Append the inner tile shape to the permuted and rank-reduced outer shape.
1327  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1328  Type elemType = unpackOp.getSourceType().getElementType();
1329  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1330  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
1331  loc, readType, unpackOp.getSource(), extractSliceOffsets,
1332  extractSliceSizes, extractSliceStrides);
1333 
1334  // 2. Transpose the tile to match the outer corresponding tile order.
1335   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1336       srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1337  // Unpack is a transition out of packed space so we invert the permutation.
1338  perm = invertPermutationVector(perm);
1339  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1340 
1341  Value empty =
1342  rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
1343  auto transposedOp =
1344  rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
1345 
1346   // 3. Handle incomplete tiles if needed. This truncates trailing data from
1347   // the transposed tile.
1348  int numLoops = shapeForEmptyOp.size();
1349  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
1350  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
1351  SmallVector<OpFoldResult> tileSizes;
1352  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1353  for (auto i : llvm::seq<unsigned>(0, destRank)) {
1354  if (dimAndTileMapping.count(i) || destShape[i] != 1)
1355  tileSizes.push_back(
1356  tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1357  }
1358 
1359  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
1360  loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1361 
1362  // 4. Insert the result to the destination tensor.
1363  SmallVector<OpFoldResult> writeSizes;
1364  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1365  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1366  for (int i = 0, idx = 0; i < destRank; ++i) {
1367  if (dimAndTileMapping.count(i) || destShape[i] != 1)
1368  writeSizes.push_back(tileSizes[idx++]);
1369  else
1370  writeSizes.push_back(oneIdxAttr);
1371  }
1372  auto insert = rewriter.create<tensor::InsertSliceOp>(
1373  loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1374  writeStrides);
1375  rewriter.replaceOp(unpackOp, insert.getResult());
1376 
1377  return success();
1378 }
1379 
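// For illustration only (hypothetical shapes, not taken from a test): a
// tensor.unpack whose outer dims are all unit, e.g.
//   %u = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//        into %dest : tensor<1x1x8x32xf32> -> tensor<5x32xf32>
// is decomposed by the sequence built above roughly into:
//   %tile  = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 8, 32] [1, 1, 1, 1]
//            : tensor<1x1x8x32xf32> to tensor<8x32xf32>
//   %init  = tensor.empty() : tensor<8x32xf32>
//   %t     = linalg.transpose ins(%tile) outs(%init) permutation = [0, 1]
//   %part  = tensor.extract_slice %t[0, 0] [5, 32] [1, 1]
//            : tensor<8x32xf32> to tensor<5x32xf32>
//   %res   = tensor.insert_slice %part into %dest[0, 0] [5, 32] [1, 1]
//            : tensor<5x32xf32> into tensor<5x32xf32>
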
1380 // The following are patterns for downscaling convolution ops with size-1
1381 // window dimensions.
1382 //
1383 // Note that we'd eventually want to write such transformations in a generic
1384 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
1385 // and then turning back to named ops. But for now it's fine to have a few
1386 // patterns matching special ops to get started.
1387 
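// For example (a sketch, not a verbatim test case), with a unit kh/oh pair a
// conv_2d_nhwc_hwcf such as
//   linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
//                             strides = dense<1> : tensor<2xi64>}
//     ins(%in, %ker : tensor<1x1x10x3xf32>, tensor<1x2x3x8xf32>)
//     outs(%out : tensor<1x1x9x8xf32>) -> tensor<1x1x9x8xf32>
// collapses to a conv_1d_nwc_wcf on the rank-reduced operands
// (tensor<1x10x3xf32>, tensor<2x3x8xf32>, tensor<1x9x8xf32>), with the
// corresponding stride/dilation entry dropped.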
1388 template <typename Conv2DOp, typename Conv1DOp>
1389 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
1390  returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
1391  if (convOp.hasPureBufferSemantics())
1392  return failure(); // To be implemented.
1393 
1394  Value input = convOp.getInputs().front();
1395  Value kernel = convOp.getInputs().back();
1396  Value output = convOp.getOutputs().front();
1397 
1398  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1399  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1400  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1401 
1402  auto kernelShape = kernelType.getShape();
1403  auto outputShape = outputType.getShape();
1404 
1405  // Get domain indices based on conv2D layout.
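  // For instance, Conv2DNhwcHwcfOp has an HWCF kernel and an NHWC output, so
  // (khIndex, kwIndex) = (0, 1) and (ohIndex, owIndex) = (1, 2).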
1406  auto [khIndex, kwIndex, ohIndex, owIndex] =
1407  TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
1408  convOp)
1409  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1410  return std::make_tuple(0, 1, 1, 2);
1411  })
1412  .Case([&](linalg::Conv2DNchwFchwOp op) {
1413  return std::make_tuple(2, 3, 2, 3);
1414  })
1415  .Case([&](linalg::PoolingNhwcSumOp op) {
1416  return std::make_tuple(0, 1, 1, 2);
1417  })
1418  .Case([&](linalg::PoolingNchwSumOp op) {
1419  return std::make_tuple(0, 1, 2, 3);
1420  })
1421  .Case([&](linalg::PoolingNhwcMaxOp op) {
1422  return std::make_tuple(0, 1, 1, 2);
1423  })
1424  .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1425  return std::make_tuple(0, 1, 1, 2);
1426  })
1427  .Case([&](linalg::PoolingNhwcMinOp op) {
1428  return std::make_tuple(0, 1, 1, 2);
1429  })
1430  .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1431  return std::make_tuple(0, 1, 1, 2);
1432  })
1433  .Case([&](linalg::PoolingNchwMaxOp op) {
1434  return std::make_tuple(0, 1, 2, 3);
1435  })
1436  .Default([&](Operation *op) {
1437  llvm_unreachable("unexpected conv2d/pool2d operation.");
1438  return std::make_tuple(0, 0, 0, 0);
1439  });
1440 
1441  // Only handle the case where at least one of the window dimensions is
1442  // of size 1. Other cases can rely on tiling to reduce to such cases.
1443  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1444  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1445  bool removeH = (khSize == 1 && ohSize == 1);
1446  bool removeW = (kwSize == 1 && owSize == 1);
1447  if (!removeH && !removeW)
1448  return failure();
1449 
1450  // Get new shapes and types for all operands by removing the size-1
1451  // dimension.
1452  using RTTBuilder = RankedTensorType::Builder;
1453  RankedTensorType newInputType =
1454  RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
1455  RankedTensorType newKernelType =
1456  RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1457  RankedTensorType newOutputType =
1458  RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
1459 
1460  // Rank-reduce operands.
1461  Location loc = convOp.getLoc();
1462  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1463  rewriter, loc, input, newInputType);
1464  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1465  rewriter, loc, kernel, newKernelType);
1466  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1467  rewriter, loc, output, newOutputType);
1468 
1469  // Rank-reduce strides and dilations too.
1470  // TODO: dropDim 1-liner helper.
1471  auto strides =
1472  llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1473  strides.erase(strides.begin() + (removeH ? 0 : 1));
1474  auto stridesAttr = rewriter.getI64VectorAttr(strides);
1475 
1476  auto dilations =
1477  llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1478  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1479  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1480 
1481  auto conv1DOp = rewriter.create<Conv1DOp>(
1482  loc, newOutputType, ValueRange{newInput, newKernel},
1483  ValueRange{newOutput}, stridesAttr, dilationsAttr);
1484 
1485  // Insert back.
1486  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1487  rewriter, loc, conv1DOp.getResult(0), output);
1488  rewriter.replaceOp(convOp, inserted);
1489 
1490  return conv1DOp;
1491 }
1492 
1493 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1494  Conv1DNwcWcfOp>;
1495 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1496  Conv1DNcwFcwOp>;
1497 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
1498  PoolingNwcSumOp>;
1499 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
1500  PoolingNcwSumOp>;
1501 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
1502  PoolingNwcMaxOp>;
1503 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1504  PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1505 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
1506  PoolingNwcMinOp>;
1507 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1508  PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1509 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
1510  PoolingNcwMaxOp>;
1511 
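// Same idea for depthwise convolutions: e.g. (hypothetical shapes) a
// depthwise_conv_2d_nhwc_hwc with kw == 1 and ow == 1 drops the W dimension
// from the NHWC input/output and the HWC kernel, yielding a
// depthwise_conv_1d_nwc_wc over the remaining H window.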
1512 FailureOr<DepthwiseConv1DNwcWcOp>
1513 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1514  DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
1515  if (convOp.hasPureBufferSemantics())
1516  return failure(); // To be implemented.
1517 
1518  Value input = convOp.getInputs().front();
1519  Value kernel = convOp.getInputs().back();
1520  Value output = convOp.getOutputs().front();
1521 
1522  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1523  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1524  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1525 
1526  auto kernelShape = kernelType.getShape();
1527  auto outputShape = outputType.getShape();
1528 
1529  // Only handle the case where at least one of the window dimensions is
1530  // of size 1. Other cases can rely on tiling to reduce to such cases.
1531  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1532  int64_t ohSize = outputShape[1], owSize = outputShape[2];
1533  bool removeH = (khSize == 1 && ohSize == 1);
1534  bool removeW = (kwSize == 1 && owSize == 1);
1535  if (!removeH && !removeW)
1536  return failure();
1537 
1538  // Get new shapes and types for all operands by removing the size-1
1539  // dimension.
1540  using RTTBuilder = RankedTensorType::Builder;
1541  RankedTensorType newInputType =
1542  RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1543  RankedTensorType newKernelType =
1544  RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1545  RankedTensorType newOutputType =
1546  RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1547 
1548  // Rank-reduce operands.
1549  Location loc = convOp.getLoc();
1550  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1551  rewriter, loc, input, newInputType);
1552  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1553  rewriter, loc, kernel, newKernelType);
1554  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1555  rewriter, loc, output, newOutputType);
1556 
1557  // Rank-reduce strides and dilations too.
1558  // TODO: dropDim 1-liner helper.
1559  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1560  strides.erase(strides.begin() + (removeH ? 0 : 1));
1561  auto stridesAttr = rewriter.getI64VectorAttr(strides);
1562 
1563  auto dilations =
1564  llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1565  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1566  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1567 
1568  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1569  loc, newOutputType, ValueRange{newInput, newKernel},
1570  ValueRange{newOutput}, stridesAttr, dilationsAttr);
1571 
1572  // Insert back.
1573  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1574  rewriter, loc, conv1DOp.getResult(0), output);
1575  rewriter.replaceOp(convOp, inserted);
1576 
1577  return conv1DOp;
1578 }
1579 
1580 FailureOr<Conv1DOp>
1581 DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
1582  PatternRewriter &rewriter) const {
1583  if (convOp.hasPureBufferSemantics())
1584  return failure(); // To be implemented.
1585 
1586  Value input = convOp.getInputs().front();
1587  Value kernel = convOp.getInputs().back();
1588  Value output = convOp.getOutputs().front();
1589 
1590  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1591  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1592  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1593 
1594  auto kernelShape = kernelType.getShape();
1595  auto outputShape = outputType.getShape();
1596 
1597  // Only handle the case where at least one of the window dimensions is
1598  // of size 1. Other cases can rely on tiling to reduce to such cases.
1599  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1600  int64_t ohSize = outputShape[0], owSize = outputShape[1];
1601  bool removeH = (khSize == 1 && ohSize == 1);
1602  bool removeW = (kwSize == 1 && owSize == 1);
1603  if (!removeH && !removeW)
1604  return failure();
1605 
1606  // Get new shapes and types for all operands by removing the size-1
1607  // dimension.
1608  using RTTBuilder = RankedTensorType::Builder;
1609  RankedTensorType newInputType =
1610  RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
1611  RankedTensorType newKernelType =
1612  RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1613  RankedTensorType newOutputType =
1614  RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
1615 
1616  // Rank-reduce operands.
1617  Location loc = convOp.getLoc();
1618  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1619  rewriter, loc, input, newInputType);
1620  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1621  rewriter, loc, kernel, newKernelType);
1622  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1623  rewriter, loc, output, newOutputType);
1624 
1625  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
1626  ValueRange{newInput, newKernel},
1627  ValueRange{newOutput});
1628 
1629  // Insert back.
1630  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1631  rewriter, loc, conv1DOp.getResult(0), output);
1632  rewriter.replaceOp(convOp, inserted);
1633 
1634  return conv1DOp;
1635 }
1636 
1637 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
1638  PatternBenefit benefit) {
1639  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
1640  Conv1DNwcWcfOp>,
1641  DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
1642  Conv1DNcwFcwOp>,
1643  DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
1644  patterns.getContext(), benefit);
1645  patterns.add<
1646  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
1647  DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
1648  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
1649  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
1650  PoolingNwcMaxUnsignedOp>,
1651  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
1652  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
1653  PoolingNwcMinUnsignedOp>,
1654  DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
1655  patterns.getContext(), benefit);
1656 }
1657 
1658 void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
1659  // TODO: Add and test patterns for tensor.unpack
1660  patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext());
1661 }
1662 
1663 void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
1664  patterns.add<DecomposePadOpPattern>(patterns.getContext());
1665 }
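
// Example driver (illustrative sketch only; not part of this file): inside a
// pass, these populate* entry points can be combined with the greedy rewriter:
//   RewritePatternSet patterns(ctx);
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   linalg::populateDecomposePadPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();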