//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// peelLoop transformation.
//===----------------------------------------------------------------------===//

/// Try to peel and canonicalize loop `op` and return the new result.
/// Also applies affine_min/max bounds simplification on the fly where
/// relevant.
// TODO: Add support for scf.parallel and affine.for loops.
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}
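
// Illustrative example (bounds chosen for exposition, not taken from the
// upstream tests): peeling
//   scf.for %iv = %c0 to %ub step %c4 { ... }
// conceptually yields a main loop over full steps plus a partial iteration:
//   %split = affine.apply affine_map<()[s0] -> (s0 - s0 mod 4)>()[%ub]
//   scf.for %iv = %c0 to %split step %c4 { ... }  // affine.min bounds fold
//   scf.for %iv = %split to %ub step %c4 { ... }  // peeled partial iteration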

/// Peel `loops` and apply affine_min/max bounds simplification on the fly
/// where relevant.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}

//===----------------------------------------------------------------------===//
// pack transformation.
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
/// Return true if `map` has zero or one result that is a function of
/// AffineDimExpr(dim).
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
  bool found = false;
  for (AffineExpr e : map.getResults()) {
    if (!e.isFunctionOfDim(dim))
      continue;
    if (found)
      return false;
    found = true;
  }
  return true;
}

static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
  return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
}
#endif // NDEBUG

/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), std::nullopt otherwise.
static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
                                                            int64_t dim) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    AffineExpr expr = map.getResult(i);
    if (!expr.isFunctionOfDim(dim))
      continue;
    return i;
  }
  return std::nullopt;
}

/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
/// `newDim` at `iteratorTypes.size()` by:
///   1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
///   2. Appending a `newDim` to the domain of every indexing map.
///   3. For each operand (i.e. for each map in `indexingMaps`), perform
///      packing by potentially adding a `newDim` result to `map`.
/// The preserved invariant is that `iteratorTypes.size()` is always equal to
/// `map.getNumDims()` for every map in `indexingMaps`.
///
/// Update `indexingMaps` and `iteratorTypes` in place as one step of the
/// update. Return a vector that records the optional packing for each operand.
/// Return failure if the packed indexing cannot be represented with a LinalgOp.
///
/// Further details:
/// ================
/// The current implementation of packing (i.e. data tiling) consists of
/// rewriting a linearized strip-mined form into a higher-dimensional access.
/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
///
/// This rewrite into higher dimensional access is not possible for general
/// AffineExpr in Linalg at the moment; it is restricted to an AffineDimExpr:
/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
/// The rewrite of the access would be a form not representable in Linalg:
/// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
/// Note however that as `J` and `ii` iterate, the accesses do not have a
/// particular alignment, so packing does not achieve alignment in this case.
///
/// In the future, we may want to consider a mixed-form that allows some
/// alignment in the presence of multiple accesses:
///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
/// and would rewrite accesses as:
///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Add the `newDim` to map whatever the case.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // Get the at-most-1 index of the result that is a function of `dim`.
    // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
    // logically chunks dimension `dim` into `K * dim + newDim`, where the
    // packing factor `K` is specified separately.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // We can only pack AffineDimExpr atm.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Record that `operandIdx` is packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}
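
// Worked example for the above (illustrative): packing the reduction dim
// (dim = 2) of a matmul whose indexing maps are
//   (d0, d1, d2) -> (d0, d2)    // LHS
//   (d0, d1, d2) -> (d2, d1)    // RHS
//   (d0, d1, d2) -> (d0, d1)    // RES
// appends a new dim d3 carrying the same (reduction) iterator type and yields
//   (d0, d1, d2, d3) -> (d0, d2, d3)
//   (d0, d1, d2, d3) -> (d2, d1, d3)
//   (d0, d1, d2, d3) -> (d0, d1)
// with packedDimPerIndexingMap = [1, 0, std::nullopt]: the LHS packs its
// result #1, the RHS its result #0, and the RES is unchanged because no
// result of its map is a function of d2.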

namespace {

/// Helper struct to encode packing along one dimension of a LinalgOp.
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Helper struct to encode packing along all dimensions of a LinalgOp.
struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  /// Return all the dims that have been packed for operand @ `operandPos`.
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};

} // namespace

FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                             linalg::PackOp packOp,
                                             bool lowerPadLikeWithInsertSlice) {
  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(),
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  // 2. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getPackInverseDestPerm(packOp);

  // 3. Compute the stripMinedShape: this is the packed shape before any outer
  // or inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
                                     highs, paddingValue, /*nofold=*/false);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      DBGS() << "insertPositions: "
             << llvm::interleaved(packingMetadata.insertPositions);
      DBGSNL(); DBGS() << "outerPositions: "
                       << llvm::interleaved(packingMetadata.outerPositions);
      DBGSNL(); DBGS() << "packedShape: "
                       << llvm::interleaved(packedTensorType.getShape());
      DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
                       << llvm::interleaved(packedToStripMinedShapePerm);
      DBGSNL();
      DBGS() << "reassociations: "
             << llvm::interleaved(llvm::map_range(
                    packingMetadata.reassociations, stringifyReassocIndices));
      DBGSNL();
      DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // Pack ops which operate as simple pads may not produce legal
    // tensor.insert_slice operations when the packed type does not rank reduce
    // to the padded type.
    SliceVerificationResult rankReduces =
        isRankReducedType(packedTensorType, padOp.getResultType());

    if (rankReduces == SliceVerificationResult::Success) {
      // This pack is just a plain pad.
      // Just insert the pad in the higher ranked tensor.
      // Offsets.
      SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                      rewriter.getIndexAttr(0));
      // Strides.
      SmallVector<OpFoldResult> ones(packOp.getDestRank(),
                                     rewriter.getIndexAttr(1));
      SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, packOp.getDest());

      auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
          loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
          /*offsets=*/zeros, sizes, /*strides=*/ones);

      LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););

      rewriter.replaceOp(packOp, insertSliceOp->getResults());

      return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                             /*transposeOp=*/nullptr};
    }
  }

  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 6. Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
             DBGS() << "transpPerm: " << llvm::interleaved(transpPerm);
             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());

  return LowerPackResult{padOp, reshapeOp, transposeOp};
}
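
// Illustrative end-to-end example of the lowering above (static shapes chosen
// for exposition):
//   %p = linalg.pack %src padding_value(%cst : f32)
//        inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %dest
//        : tensor<13x15xf32> -> tensor<2x1x8x16xf32>
// becomes, approximately:
//   %padded = tensor.pad %src low[0, 0] high[3, 1] ... to tensor<16x16xf32>
//   %e = tensor.expand_shape %padded [[0, 1], [2, 3]]
//        : tensor<16x16xf32> into tensor<2x8x1x16xf32>
//   %t = linalg.transpose ins(%e : tensor<2x8x1x16xf32>)
//        outs(%dest : tensor<2x1x8x16xf32>) permutation = [0, 2, 1, 3]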

FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  Location loc = unPackOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad.
    // Just extract the slice from the higher ranked tensor.
    ArrayRef<int64_t> destShape = destTensorType.getShape();
    // The inner dimensions stay the same as the destination tensor, but the
    // outer ones are additional 1s.
    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  // 1. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: this is the packed shape without outer and
  // inner permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
  // permutation.
  SmallVector<OpFoldResult> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, dims, stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL();
      DBGS() << "insertPositions: "
             << llvm::interleaved(packingMetadata.insertPositions);
      DBGSNL(); DBGS() << "packedShape: "
                       << llvm::interleaved(packedTensorType.getShape());
      DBGSNL(); DBGS() << "packedToStripMinedShapePerm: "
                       << llvm::interleaved(packedToStripMinedShapePerm);
      DBGSNL();
      DBGS() << "reassociations: "
             << llvm::interleaved(llvm::map_range(
                    packingMetadata.reassociations, stringifyReassocIndices));
      DBGSNL();
      DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););

  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. ExtractSlice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 6. Inject a copy to preserve DPS.
  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());

  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}
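
// Illustrative counterpart to the pack lowering above (approximate):
//   %u = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//        into %dest : tensor<2x1x8x16xf32> -> tensor<13x15xf32>
// becomes:
//   %init = tensor.empty() : tensor<2x8x1x16xf32>
//   %t = linalg.transpose ins(%src) outs(%init) permutation = [0, 2, 1, 3]
//   %c = tensor.collapse_shape %t [[0, 1], [2, 3]]
//        : tensor<2x8x1x16xf32> into tensor<16x16xf32>
//   %s = tensor.extract_slice %c[0, 0] [13, 15] [1, 1]
//   %r = linalg.copy ins(%s) outs(%dest)   // step 6: keeps the op in DPS form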

SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}

/// Implement packing of a single LinalgOp by `packedSizes`. There must be one
/// packedSizes entry per `linalgOp` iterator. Return the packed Linalg op on
/// success, failure otherwise.
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n"
                    << "maps: " << llvm::interleaved(indexingMaps) << "\n"
                    << "iterators: " << llvm::interleaved(iteratorTypes)
                    << "\n");

  SmallVector<linalg::PackOp> packOps;
  SmallVector<linalg::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));

    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n"
               << "maps: " << llvm::interleaved(indexingMaps) << "\n"
               << "iterators: " << llvm::interleaved(iteratorTypes) << "\n"
               << "packedDimForEachOperand: "
               << llvm::interleaved(packedOperandsDims.packedDimForEachOperand)
               << "\n");
  }

  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands =
      llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(DBGS() << "operand: " << operand << "\n"
                        << "innerPos: " << llvm::interleaved(innerPos) << "\n"
                        << "innerPackSizes: "
                        << llvm::interleaved(innerPackSizes) << "\n");
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = linalg::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !linalg::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(rewriter.create<linalg::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<linalg::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }

  // Step 3. Build the packed op, use the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
      iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    linalg::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<linalg::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(rewriter.create<linalg::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  // Return packedLinalgOp.
  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}
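
// Sketch of a typical call site (hypothetical driver code, assuming static
// tile sizes {4, 8, 16} for a 3-loop matmul):
//   SmallVector<OpFoldResult> sizes =
//       getAsIndexOpFoldResult(rewriter.getContext(), {4, 8, 16});
//   FailureOr<PackResult> res = linalg::pack(rewriter, matmulOp, sizes);
//   if (succeeded(res)) {
//     // res->packOps feed a packed linalg.generic (res->packedLinalgOp);
//     // res->unPackOps restore the original layout of each result.
//   }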

//===----------------------------------------------------------------------===//
// packTranspose transformation.
//===----------------------------------------------------------------------===//

/// Return a copy of `tensorType` after permutation by `permutationVector`.
// Note: Should be a new method in MemRef/RankedTensor/VectorType::Builder
// but this would introduce a dependence on Dialect in IR.
// TODO: Restructure.
static RankedTensorType permuteShape(RankedTensorType tensorType,
                                     ArrayRef<int64_t> permutationVector) {
  SmallVector<int64_t> shape(tensorType.getShape());
  applyPermutationToVector(shape, permutationVector);
  return RankedTensorType::Builder(tensorType).setShape(shape);
}

/// Return a new GenericOp obtained by transposing opOperand by the permutation
/// vector:
/// - the corresponding indexing map is transposed by `permutation`
/// - the corresponding operand value is replaced by `transposedValue`
/// `linalgOp` is replaced by the return op in the process.
/// Asserts that `transposedValue` is of the proper transposed ShapedType.
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "unexpected tensor type mismatch");

  // Compute the transposed indexing map.
  // Sigh unsigned pollution.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}

FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
                      linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  linalg::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand.
  // The leading part just reuses the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // The trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  if (!isPermutationVector(permutation))
    return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");

  // Step 2.b. Save the transposedPackUse operand number in case we need to
  // get the tied OpResult after `linalgOp` has been replaced.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 2.c. Actually perform the transposition.
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  linalg::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);

    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}
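
// Illustrative effect (hypothetical shapes): with outerPerm = {1, 0} and
// innerPerm = {1, 0}, a pack producing tensor<2x3x8x16xf32> is re-created to
// produce tensor<3x2x16x8xf32>; the consuming LinalgOp's indexing map for
// that operand is composed with the same permutation, and any tied unpack is
// transposed symmetrically, so the overall computation is unchanged.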

//===----------------------------------------------------------------------===//
// packMatmulGreedily transformation.
//===----------------------------------------------------------------------===//

/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
/// and n are proper parallel dimensions and k is a proper reduction
/// dimension. Packing occurs by rewriting the op as a linalg.generic and
/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
/// to reorder {m, n, k} into one of the 6 possible forms. The outer
/// dimensions of the operands are not permuted at this time; this is left for
/// future work.
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
  assert(isPermutationVector(mnkOrder) && "expected a permutation");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
                      << numLoops << "\nin: " << linalgOp << "\n");
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
                      << "\n");
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // 2. Normalize linalgOp to a kmn-matmul-like with [red, par, par] most
  // minor iterators. In cases with multiple options for m, n, k, bias towards
  // the most minor embedding.
  // If we wanted a different normalization order, this is where a heuristic
  // would plug in.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "Start packing generic op greedily with (m@" << mPos
                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
                    << "\n";);

  // 2.a. Rewrite as a generic.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
  // iterators. Note that this only normalizes the iteration order and does
  // not change the indexings of any operand.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n");
  // Sign .. unsigned pollution.
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);

  // At this point, the op iterators are normalized to {leading, k, m, n}.
  // The layouts induced by packing will always be:
  //   - LHS{leading_lhs, kk, mm}
  //   - RHS{leading_rhs, kk, nn}
  //   - RES{leading_res, mm, nn}
  // If we wanted to change the packed order, we would reorder (k, m, n) to
  // something else above.
  //
  // Additional permutations of the outer dims of the operands (i.e.
  // leading_lhs, leading_rhs and leading_res) could follow by computing the
  // desired outerPerm for each operand.
  // This is left for future work.

  // TODO: this creates too much IR, go use reifyResultShapes.
  SmallVector<Range, 4> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  // Add leading zeros to match numLoops; we only pack the last 3 dimensions
  // post interchange.
  LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: "
                    << llvm::interleaved(paddedSizesNextMultipleOf) << "\n"
                    << "loopRanges: "
                    << llvm::interleaved(llvm::map_range(
                           loopRanges, [](Range r) { return r.size; }))
                    << "\n");
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LLVM_DEBUG(DBGS() << "adjustedPackedSizes: "
                    << llvm::interleaved(adjustedPackedSizes) << "\n");

  // TODO: If we wanted to give the genericOp a name after packing, after
  // calling `pack` would be a good time. One would still need to check that
  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
  // also allow degenerate matmul cases (i.e. matvec, dot).
  return pack(rewriter, genericOp, adjustedPackedSizes);
}
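
// Illustrative outcome (hedged; the exact tile-dim order depends on mnkOrder):
// for a plain matmul packed with mnkPackedSizes = {8, 16, 32}, each operand
// gains one trailing tile dim per packed iterator its indexing map uses, e.g.
//   LHS tensor<MxKxf32> gains tiles 8 (m) and 32 (k),
//   RHS tensor<KxNxf32> gains tiles 32 (k) and 16 (n),
//   RES tensor<MxNxf32> gains tiles 8 (m) and 16 (n),
// with linalg.pack / linalg.unpack materializing the layout change and
// padding each size to the next requested multiple where applicable.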

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts);
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

/// Fill `dest` using a FillOp with the constant padding value when one
/// exists; otherwise, generate a tensor::GenerateOp that clones the pad body.
Value DecomposePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue) {
    // Move the padding value defined inside the PadOp block to outside.
    if (padValue.getParentBlock() == &padOp.getRegion().front())
      rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();
  }

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                       PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
      padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Generate an InsertSliceOp for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}
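
// Illustrative decomposition (static shapes chosen for exposition):
//   %p = tensor.pad %src low[1, 0] high[1, 2] { tensor.yield %cst }
//        : tensor<4x5xf32> to tensor<6x7xf32>
// becomes:
//   %empty = tensor.empty() : tensor<6x7xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<6x7xf32>)
//   %r = tensor.insert_slice %src into %fill[1, 0] [4, 5] [1, 1]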

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();

  RankedTensorType sourceType = sliceOp.getSourceType();
  RankedTensorType resultType = sliceOp.getResultType();

  // If the extract_slice is not rank-reduced, all shapes are static and the
  // data source is actually used. Rewrite into pad(extract_slice(x)).
  if (sourceType.getRank() == resultType.getRank()) {
    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
    return success();
  }

  // Handle rank-reduced slice by creating another extract_slice op.
  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);

  rewriter.replaceOp(sliceOp, rankReduced);
  return success();
}
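
// Illustrative effect (approximate): rather than slicing the padded tensor,
//   %p = tensor.pad %src low[%l] high[%h] ... : tensor<?xf32> to tensor<?xf32>
//   %s = tensor.extract_slice %p[%off] [%sz] [1]
// is rewritten so the slice reads from %src directly and only the extracted
// tile is padded (or, when the tile lies entirely inside the padding, is
// replaced by a generate/fill of the padding value, subject to
// `zeroSliceGuard`).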

/// If padding value is set, returns a tensor.pad op for the source tensor,
/// with the output shape matching the output of `packOp`. Otherwise, returns
/// the source directly.
///
/// This method assumes that all outer dims for this pack op are 1.
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           linalg::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  // The sizes of dynamic tiles.
  SmallVector<Value> dynamicTileSizes;

  // Collect dims for the padded shape.
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims.
    // These dims should be 1 and we simply preserve them.
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims.
    // As all outer dims == 1, it is safe to use the tile size for the padded
    // shape.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1 Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2 Dynamic tile sizes.
    paddedShape.push_back(ShapedType::kDynamic);

    // Get the value that holds the dynamic size.
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}

// Normalizes a permutation on a higher rank space to its actual size, e.g.
//   perm = [1, 4, 2]
// becomes
//   norm = [0, 2, 1]
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
      vec, [&](int64_t v) { return v != kNonTiledMarker; });
  // This inverts the permutation in addition to normalizing so invert back.
  return invertPermutationVector(normalizedPerm);
}

// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
// assuming rank reduction of unit outer dims.
static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;

  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the innermost dims after packing.
  perm.append(innerDims);

  return perm;
}
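
// Worked example (illustrative): with shape = [1, 8, 4],
// innerDimsPos = [2, 1] and an empty outerDimsPerm, both non-unit dims are
// tiled and the unit dim rank-reduces away, so outerDims is empty and
// innerDims = [0, 1]; the normalized inner permutation [1, 0] then swaps the
// two tile dims, giving perm = [1, 0].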

LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
    linalg::PackOp packOp, PatternRewriter &rewriter) const {
  // TODO: support the case that outer dimensions are not all 1s. A
  // tensor.expand_shape will be generated in this case.
  if (llvm::any_of(packOp.getAllOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "not all outer dimensions of the result are 1s");
  }

  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  Location loc = packOp.getLoc();

  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  int64_t srcRank = packOp.getSourceRank();
  int64_t destRank = packOp.getDestRank();
  int64_t numTiles = destRank - srcRank;

  if (!llvm::all_of(packOp.getInnerDimsPos(),
                    [&srcRank, &numTiles](int64_t dimPos) {
                      return dimPos >= (srcRank - numTiles - 1);
                    }))
    return rewriter.notifyMatchFailure(
        packOp, "Attempting to tile non-trailing source dims!");

  // 1. Extract the inner tile sizes.
  // Where possible, values are replaced with constant attributes (to match the
  // behaviour of `getPackOpSourceOrPaddedSource`).
  SmallVector<OpFoldResult> tileSizes;
  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
    if (dimAndTileMapping.count(i)) {
      // Rather than taking the tile size as is, extract the actual constant
      // value Attribute where possible, e.g.:
      //   [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
      auto [_, tileSize] =
          getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
      tileSizes.push_back(tileSize);
    }
  }

  // 2. Transpose the input to match the inner tile order:
  //   %init = tensor.empty()
  //   %transposed_tile = linalg.transpose ins(%source_or_padded_source),
  //                      outs(%init)
  // Two assumptions are made:
  //   1. All outer dims are 1 - the corresponding transposition doesn't
  //      matter.
  //   2. Inner dims positions correspond to the trailing `numTiles` dims.
  SmallVector<int64_t> tilesPermNormalized =
      getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
  SmallVector<int64_t> srcPermForTranspose;
  for (int64_t i = 0; i < (srcRank - numTiles); i++)
    srcPermForTranspose.push_back(i);

  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));

  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
                    << "perm: " << llvm::interleaved(srcPermForTranspose)
                    << "\n");

  // 2.1 Create tensor.empty (init value for TransposeOp).
  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
                                                 oneIdxAttr);
  transShapeForEmptyOp.append(tileSizes);

  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
                                         srcPermForTranspose);
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());

  // 2.2 Create linalg.transpose.
  auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
                                                           srcPermForTranspose);

  // 3. Insert the inner tile into the destination:
  //   %inserted_tile = tensor.insert_slice(%transposed_tile)
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  // Outer dims are all 1s!
  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
                                       oneIdxAttr);
  SmallVector<int64_t> writeShape;

  for (auto tileSize : packOp.getMixedTiles()) {
    auto [tileSizeStatic, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);
    writeShape.push_back(tileSizeStatic);
  }

  // 4. Replace the pack op with the tensor.insert_slice created above.
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
      writeSizes, writeStrides);
  rewriter.replaceOp(packOp, insert.getResult());

  return success();
}
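
// Illustrative decomposition (static shapes chosen for exposition):
//   %p = linalg.pack %src padding_value(%cst : f32)
//        inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %dest
//        : tensor<5x9xf32> -> tensor<1x1x8x16xf32>
// becomes, approximately:
//   %padded = tensor.pad %src low[0, 0] high[3, 7] ... to tensor<8x16xf32>
//   %init = tensor.empty() : tensor<8x16xf32>
//   %t = linalg.transpose ins(%padded) outs(%init) permutation = [0, 1]
//   %r = tensor.insert_slice %t into %dest[0, 0, 0, 0] [1, 1, 8, 16]
//        [1, 1, 1, 1]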
1261 
1263  linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1264  int64_t srcRank = unpackOp.getSourceRank();
1265  int64_t destRank = unpackOp.getDestRank();
1266  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1267  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1268  if (llvm::any_of(unpackOp.getTiledOuterDims(),
1269  [](int64_t dim) { return dim != 1; })) {
1270  return rewriter.notifyMatchFailure(
1271  unpackOp,
1272  "require the tiled outer dimensions of the result are all 1s");
1273  }
1274 
1275  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1276  // %extracted_tile = tensor.extract_slice(%unpack_op_input)
1277  Location loc = unpackOp.getLoc();
1278  Value source = unpackOp.getSource();
1279  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1280  unpackOp.getDimAndTileMapping();
1281  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1282  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1283 
1284  // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1285  // dims:
1286  // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1287  SmallVector<int64_t> readShapeForExtractSlice;
1288  // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1289  // outer-tiled-dims being all 1), this will be
1290  // [ outer-untiled-dims, tile-sizes ]
1291  SmallVector<OpFoldResult> extractSliceSizes;
1292  // The offset and strides attributes for ExtractSliceOp.
1293  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
1294  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
1295 
1296  // Shape for EmptyOp that's used as the init value for TransposeOp below.
1297  // This should be:
1298  // [ outer-untiled-dims, tile-sizes ]
1299  // However, skip unit dims - TransposeOp (below) applies rank-reduced
1300  // permutation.
1301  SmallVector<OpFoldResult> shapeForEmptyOp;
1302 
1303  for (auto i : llvm::seq<unsigned>(0, destRank)) {
1304  // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1305  //
1306  // As all outer tiled dims are 1, so the corresponding
1307  // slice size to read will also 1. As this will be rank-reducing "extract
1308  // slice" (i.e. the unit dims will be "collapsed"), there's no need to
1309  // update:
1310  // * the output shape for ExtractSliceOp, nor
1311  // * the shape for EmptyOp.
1312  if (dimAndTileMapping.count(i)) {
1313  extractSliceSizes.push_back(oneIdxAttr);
1314  continue;
1315  }
1316 
1317  // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1318  // outer-untiled-dims
1319  if (ShapedType::isDynamic(srcShape[i])) {
1320  OpFoldResult dynamicDim =
1321  rewriter.create<tensor::DimOp>(loc, source, i).getResult();
1322  extractSliceSizes.push_back(dynamicDim);
1323  shapeForEmptyOp.push_back(dynamicDim);
1324  } else {
1325  extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1326  if (srcShape[i] != 1)
1327  shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1328  }
1329  // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1330  // into account rank-reducing)
1331  if (srcShape[i] != 1) {
1332  readShapeForExtractSlice.push_back(srcShape[i]);
1333  }
1334  }
1335  // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1336  // shape for EmptyOp.
1337  auto mixedTiles = unpackOp.getMixedTiles();
1338  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1339  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1340 
1341  // Explicitly create the type for extract_slice op because the inner tile
1342  // size could be 1. We want to represent the whole inner tile in this case.
1343  auto tileShape = srcShape.drop_front(destRank);
1344  // Append the inner tile shape to the permuted and rank-reduced outer shape.
1345  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1346  Type elemType = unpackOp.getSourceType().getElementType();
1347  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1348  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
1349  loc, readType, unpackOp.getSource(), extractSliceOffsets,
1350  extractSliceSizes, extractSliceStrides);
1351 
1352  // 2. Transpose the tile to match the outer corresponding tile order.
1354  srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1355  // Unpack is a transition out of packed space so we invert the permutation.
1356  perm = invertPermutationVector(perm);
1357  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1358 
1359  Value empty =
1360  rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
1361  auto transposedOp =
1362  rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
1363 
1364  // 3. Handle incomplete tiles if needed, truncating trailing data from the
1365  // transposed tile.
1366  int numLoops = shapeForEmptyOp.size();
1367  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
1368  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
1369  SmallVector<OpFoldResult> tileSizes;
1370  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1371  for (auto i : llvm::seq<unsigned>(0, destRank)) {
1372  if (dimAndTileMapping.count(i) || destShape[i] != 1)
1373  tileSizes.push_back(
1374  tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1375  }
1376 
1377  auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
1378  loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1379 
1380  // 4. Insert the result to the destination tensor.
1381  SmallVector<OpFoldResult> writeSizes;
1382  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1383  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1384  for (int i = 0, idx = 0; i < destRank; ++i) {
1385  if (dimAndTileMapping.count(i) || destShape[i] != 1)
1386  writeSizes.push_back(tileSizes[idx++]);
1387  else
1388  writeSizes.push_back(oneIdxAttr);
1389  }
1390  auto insert = rewriter.create<tensor::InsertSliceOp>(
1391  loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1392  writeStrides);
1393  rewriter.replaceOp(unpackOp, insert.getResult());
1394 
1395  return success();
1396 }
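
// Illustrative sketch (added commentary, hypothetical shapes): for an unpack
// whose outer dims are all unit, e.g.
//
//   %0 = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 4]
//       into %dest : tensor<1x1x8x4xf32> -> tensor<5x3xf32>
//
// the four steps above produce roughly:
//
//   // 1. Rank-reduced read of the single inner tile.
//   %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 8, 4] [1, 1, 1, 1]
//       : tensor<1x1x8x4xf32> to tensor<8x4xf32>
//   // 2. Transpose (identity permutation here, since inner_dims_pos is [0, 1]).
//   %init = tensor.empty() : tensor<8x4xf32>
//   %t = linalg.transpose ins(%tile : tensor<8x4xf32>)
//       outs(%init : tensor<8x4xf32>) permutation = [0, 1]
//   // 3. Truncate the incomplete tile to the destination extents.
//   %part = tensor.extract_slice %t[0, 0] [5, 3] [1, 1]
//       : tensor<8x4xf32> to tensor<5x3xf32>
//   // 4. Write the result into the destination tensor.
//   %res = tensor.insert_slice %part into %dest[0, 0] [5, 3] [1, 1]
//       : tensor<5x3xf32> into tensor<5x3xf32>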
1397 
1398 // The following are patterns for downscaling convolution ops with size-1
1399 // window dimensions.
1400 //
1401 // Note that we'd eventually want to write such transformations in a generic
1402 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
1403 // and then turning back to named ops. But for now it's fine to have a few
1404 // patterns matching special ops to get started.
1405 
1406 template <typename Conv2DOp, typename Conv1DOp>
1407 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
1408  returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
1409  if (convOp.hasPureBufferSemantics())
1410  return failure(); // To be implemented.
1411 
1412  Value input = convOp.getInputs().front();
1413  Value kernel = convOp.getInputs().back();
1414  Value output = convOp.getOutputs().front();
1415 
1416  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1417  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1418  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1419 
1420  auto kernelShape = kernelType.getShape();
1421  auto outputShape = outputType.getShape();
1422 
1423  // Get domain indices based on conv2D layout.
1424  auto [khIndex, kwIndex, ohIndex, owIndex] =
1425  TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
1426  convOp)
1427  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1428  return std::make_tuple(0, 1, 1, 2);
1429  })
1430  .Case([&](linalg::Conv2DNchwFchwOp op) {
1431  return std::make_tuple(2, 3, 2, 3);
1432  })
1433  .Case([&](linalg::PoolingNhwcSumOp op) {
1434  return std::make_tuple(0, 1, 1, 2);
1435  })
1436  .Case([&](linalg::PoolingNchwSumOp op) {
1437  return std::make_tuple(0, 1, 2, 3);
1438  })
1439  .Case([&](linalg::PoolingNhwcMaxOp op) {
1440  return std::make_tuple(0, 1, 1, 2);
1441  })
1442  .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1443  return std::make_tuple(0, 1, 1, 2);
1444  })
1445  .Case([&](linalg::PoolingNhwcMinOp op) {
1446  return std::make_tuple(0, 1, 1, 2);
1447  })
1448  .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1449  return std::make_tuple(0, 1, 1, 2);
1450  })
1451  .Case([&](linalg::PoolingNchwMaxOp op) {
1452  return std::make_tuple(0, 1, 2, 3);
1453  })
1454  .Default([&](Operation *op) {
1455  llvm_unreachable("unexpected conv2d/pool2d operation.");
1456  return std::make_tuple(0, 0, 0, 0);
1457  });
1458 
1459  // Only handle the case where at least one of the window dimensions is
1460  // of size 1. Other cases can rely on tiling to reduce to such cases.
1461  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1462  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1463  bool removeH = (khSize == 1 && ohSize == 1);
1464  bool removeW = (kwSize == 1 && owSize == 1);
1465  if (!removeH && !removeW)
1466  return failure();
1467 
1468  // Get new shapes and types for all operands by removing the size-1
1469  // dimension.
1470  using RTTBuilder = RankedTensorType::Builder;
1471  RankedTensorType newInputType =
1472  RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
1473  RankedTensorType newKernelType =
1474  RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1475  RankedTensorType newOutputType =
1476  RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
1477 
1478  // Rank-reduce operands.
1479  Location loc = convOp.getLoc();
1480  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1481  rewriter, loc, input, newInputType);
1482  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1483  rewriter, loc, kernel, newKernelType);
1484  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1485  rewriter, loc, output, newOutputType);
1486 
1487  // Rank-reduce strides and dilations too.
1488  // TODO: dropDim 1-liner helper.
1489  auto strides =
1490  llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1491  strides.erase(strides.begin() + (removeH ? 0 : 1));
1492  auto stridesAttr = rewriter.getI64VectorAttr(strides);
1493 
1494  auto dilations =
1495  llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1496  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1497  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1498 
1499  auto conv1DOp = rewriter.create<Conv1DOp>(
1500  loc, newOutputType, ValueRange{newInput, newKernel},
1501  ValueRange{newOutput}, stridesAttr, dilationsAttr);
1502 
1503  // Insert back.
1504  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1505  rewriter, loc, conv1DOp.getResult(0), output);
1506  rewriter.replaceOp(convOp, inserted);
1507 
1508  return conv1DOp;
1509 }
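
// Illustrative sketch (hypothetical shapes): a convolution whose kh and oh
// dimensions are both 1, e.g.
//
//   %0 = linalg.conv_2d_nhwc_hwcf
//       {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
//       ins(%in, %ker : tensor<1x1x16x4xf32>, tensor<1x3x4x8xf32>)
//       outs(%out : tensor<1x1x14x8xf32>) -> tensor<1x1x14x8xf32>
//
// is rewritten, via rank-reducing extract_slice on the operands and a
// rank-reducing insert_slice on the result, to roughly:
//
//   %1 = linalg.conv_1d_nwc_wcf
//       {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
//       ins(%in1, %ker1 : tensor<1x16x4xf32>, tensor<3x4x8xf32>)
//       outs(%out1 : tensor<1x14x8xf32>) -> tensor<1x14x8xf32>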
1510 
1511 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1512  Conv1DNwcWcfOp>;
1513 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1514  Conv1DNcwFcwOp>;
1515 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
1516  PoolingNwcSumOp>;
1517 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
1518  PoolingNcwSumOp>;
1519 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
1520  PoolingNwcMaxOp>;
1521 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1522  PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1523 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
1524  PoolingNwcMinOp>;
1525 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1526  PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1527 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
1528  PoolingNcwMaxOp>;
1529 
1530 FailureOr<DepthwiseConv1DNwcWcOp>
1531 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1532  DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
1533  if (convOp.hasPureBufferSemantics())
1534  return failure(); // To be implemented.
1535 
1536  Value input = convOp.getInputs().front();
1537  Value kernel = convOp.getInputs().back();
1538  Value output = convOp.getOutputs().front();
1539 
1540  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1541  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1542  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1543 
1544  auto kernelShape = kernelType.getShape();
1545  auto outputShape = outputType.getShape();
1546 
1547  // Only handle the case where at least one of the window dimensions is
1548  // of size 1. Other cases can rely on tiling to reduce to such cases.
1549  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1550  int64_t ohSize = outputShape[1], owSize = outputShape[2];
1551  bool removeH = (khSize == 1 && ohSize == 1);
1552  bool removeW = (kwSize == 1 && owSize == 1);
1553  if (!removeH && !removeW)
1554  return failure();
1555 
1556  // Get new shapes and types for all operands by removing the size-1
1557  // dimension.
1558  using RTTBuilder = RankedTensorType::Builder;
1559  RankedTensorType newInputType =
1560  RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1561  RankedTensorType newKernelType =
1562  RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1563  RankedTensorType newOutputType =
1564  RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1565 
1566  // Rank-reduce operands.
1567  Location loc = convOp.getLoc();
1568  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1569  rewriter, loc, input, newInputType);
1570  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1571  rewriter, loc, kernel, newKernelType);
1572  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1573  rewriter, loc, output, newOutputType);
1574 
1575  // Rank-reduce strides and dilations too.
1576  // TODO: dropDim 1-liner helper.
1577  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1578  strides.erase(strides.begin() + (removeH ? 0 : 1));
1579  auto stridesAttr = rewriter.getI64VectorAttr(strides);
1580 
1581  auto dilations =
1582  llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1583  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1584  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1585 
1586  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
1587  loc, newOutputType, ValueRange{newInput, newKernel},
1588  ValueRange{newOutput}, stridesAttr, dilationsAttr);
1589 
1590  // Insert back.
1591  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1592  rewriter, loc, conv1DOp.getResult(0), output);
1593  rewriter.replaceOp(convOp, inserted);
1594 
1595  return conv1DOp;
1596 }
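
// Illustrative sketch (hypothetical shapes): with kh == oh == 1,
//
//   %0 = linalg.depthwise_conv_2d_nhwc_hwc
//       {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
//       ins(%in, %ker : tensor<1x1x10x3xf32>, tensor<1x3x3xf32>)
//       outs(%out : tensor<1x1x8x3xf32>) -> tensor<1x1x8x3xf32>
//
// becomes roughly:
//
//   %1 = linalg.depthwise_conv_1d_nwc_wc
//       {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
//       ins(%in1, %ker1 : tensor<1x10x3xf32>, tensor<3x3xf32>)
//       outs(%out1 : tensor<1x8x3xf32>) -> tensor<1x8x3xf32>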
1597 
1598 FailureOr<Conv1DOp>
1599 DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
1600  PatternRewriter &rewriter) const {
1601  if (convOp.hasPureBufferSemantics())
1602  return failure(); // To be implemented.
1603 
1604  Value input = convOp.getInputs().front();
1605  Value kernel = convOp.getInputs().back();
1606  Value output = convOp.getOutputs().front();
1607 
1608  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1609  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1610  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1611 
1612  auto kernelShape = kernelType.getShape();
1613  auto outputShape = outputType.getShape();
1614 
1615  // Only handle the case where at least one of the window dimensions is
1616  // of size 1. Other cases can rely on tiling to reduce to such cases.
1617  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1618  int64_t ohSize = outputShape[0], owSize = outputShape[1];
1619  bool removeH = (khSize == 1 && ohSize == 1);
1620  bool removeW = (kwSize == 1 && owSize == 1);
1621  if (!removeH && !removeW)
1622  return failure();
1623 
1624  // Get new shapes and types for all operands by removing the size-1
1625  // dimension.
1626  using RTTBuilder = RankedTensorType::Builder;
1627  RankedTensorType newInputType =
1628  RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
1629  RankedTensorType newKernelType =
1630  RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1631  RankedTensorType newOutputType =
1632  RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
1633 
1634  // Rank-reduce operands.
1635  Location loc = convOp.getLoc();
1636  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1637  rewriter, loc, input, newInputType);
1638  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1639  rewriter, loc, kernel, newKernelType);
1640  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1641  rewriter, loc, output, newOutputType);
1642 
1643  auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType,
1644  ValueRange{newInput, newKernel},
1645  ValueRange{newOutput});
1646 
1647  // Insert back.
1648  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1649  rewriter, loc, conv1DOp.getResult(0), output);
1650  rewriter.replaceOp(convOp, inserted);
1651 
1652  return conv1DOp;
1653 }
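
// Illustrative sketch (hypothetical shapes): for the unstrided conv_2d case,
//
//   %0 = linalg.conv_2d ins(%in, %ker : tensor<1x10xf32>, tensor<1x3xf32>)
//       outs(%out : tensor<1x8xf32>) -> tensor<1x8xf32>
//
// becomes roughly (note conv_2d/conv_1d carry no strides/dilations attributes):
//
//   %1 = linalg.conv_1d ins(%in1, %ker1 : tensor<10xf32>, tensor<3xf32>)
//       outs(%out1 : tensor<8xf32>) -> tensor<8xf32>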
1654 
1655 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
1656  PatternBenefit benefit) {
1657  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
1658  Conv1DNwcWcfOp>,
1659  DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
1660  Conv1DNcwFcwOp>,
1661  DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
1662  patterns.getContext(), benefit);
1663  patterns.add<
1664  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
1665  DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
1666  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
1667  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
1668  PoolingNwcMaxUnsignedOp>,
1669  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
1670  DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
1671  PoolingNwcMinUnsignedOp>,
1672  DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
1673  patterns.getContext(), benefit);
1674 }
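
// Sketch of typical driver-side usage (added for exposition; `funcOp` is a
// hypothetical op the patterns are applied to, e.g. inside a pass):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns, /*benefit=*/1);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();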
1675 
1676 void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
1677  patterns.add<DecomposeOuterUnitDimsPackOpPattern,
1678  DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext());
1679 }
1680 
1681 void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
1682  patterns.add<DecomposePadOpPattern>(patterns.getContext());
1683 }
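
// Illustrative sketch (hypothetical shapes): DecomposePadOpPattern rewrites
//
//   %0 = tensor.pad %src low[0, 1] high[1, 0] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<4x4xf32> to tensor<5x5xf32>
//
// into roughly empty + fill + insert_slice:
//
//   %init = tensor.empty() : tensor<5x5xf32>
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<5x5xf32>)
//       -> tensor<5x5xf32>
//   %res = tensor.insert_slice %src into %fill[0, 1] [4, 4] [1, 1]
//       : tensor<4x4xf32> into tensor<5x5xf32>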