Transforms.cpp
1 //===- Transforms.cpp - Linalg transformations as patterns ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as rewrite
10 // patterns.
11 //
12 //===----------------------------------------------------------------------===//
13 
28 #include "mlir/IR/AffineExpr.h"
29 #include "mlir/Support/LLVM.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/DebugLog.h"
33 #include "llvm/Support/InterleavedRange.h"
34 #include "llvm/Support/raw_ostream.h"
35 #include <utility>
36 
37 #define DEBUG_TYPE "linalg-transforms"
38 
39 using namespace mlir;
40 using namespace mlir::linalg;
41 
42 //===----------------------------------------------------------------------===//
43 // Transformations exposed as functional-style API calls.
44 //===----------------------------------------------------------------------===//
45 
46 //===----------------------------------------------------------------------===//
47 // peelLoop transformation.
48 //===----------------------------------------------------------------------===//
49 
50 /// Try to peel and canonicalize loop `op` and return the new result.
51 /// Also applies affine_min/max bounds simplification on the fly where relevant.
52 // TODO: Add support for scf.parallel and affine.for loops.
53 SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
54  Operation *op) {
55  return llvm::TypeSwitch<Operation *, SmallVector<Value>>(op)
56  .Case<scf::ForOp>([&](scf::ForOp forOp) {
57  scf::ForOp partialIteration;
58  if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
59  partialIteration)))
60  return partialIteration->getResults();
61  assert(!partialIteration && "expected that loop was not peeled");
62  return forOp->getResults();
63  })
64  .Default([&](Operation *op) { return op->getResults(); });
65 }
66 
67 /// Peel `loops` and apply affine_min/max bounds simplification on the fly
68 /// where relevant.
69 void mlir::linalg::peelLoops(RewriterBase &rewriter,
70  ArrayRef<scf::ForOp> loops) {
71  for (auto loopOp : loops)
72  peelLoop(rewriter, loopOp);
73 }
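// Example (illustrative usage sketch; the variable names are hypothetical and
// not defined in this file): after tiling a linalg op, the generated scf.for
// loops can be peeled so that the main loop no longer needs affine.min-clamped
// tile sizes:
//
//   SmallVector<scf::ForOp> tiledLoops = /* loops produced by tiling */;
//   linalg::peelLoops(rewriter, tiledLoops);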
74 
75 //===----------------------------------------------------------------------===//
76 // pack transformation.
77 //===----------------------------------------------------------------------===//
78 
79 #ifndef NDEBUG
80 /// Return true if `map` has at most one result that is a function of
81 /// AffineDimExpr(dim).
81 static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
82  bool found = false;
83  for (AffineExpr e : map.getResults()) {
84  if (!e.isFunctionOfDim(dim))
85  continue;
86  if (found)
87  return false;
88  found = true;
89  }
90  return true;
91 }
92 #endif // NDEBUG
93 
94 static std::string stringifyReassocIndices(ReassociationIndicesRef ri) {
95  return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/"");
96 }
97 
98 /// Return the index of the first result of `map` that is a function of
99 /// AffineDimExpr(dim), or std::nullopt if there is none.
100 static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
101  int64_t dim) {
102  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
103  AffineExpr expr = map.getResult(i);
104  if (!expr.isFunctionOfDim(dim))
105  continue;
106  return i;
107  }
108  return std::nullopt;
109 }
110 
111 /// Perform one step of packing of a LinalgOp's metadata along `dim` into the
112 /// `newDim` at `iteratorTypes.size()` by:
113 /// 1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
114 /// 2. Appending a `newDim` to the domain of every indexing map.
115 /// 3. For each operand (i.e. for each map in `indexingMaps`), perform packing
116 /// by potentially adding a `newDim` result to `map`.
117 /// The preserved invariant is that `iteratorTypes.size()` is always equal to
118 /// `map.getNumDims()` for every map in `indexingMaps`.
119 ///
120 /// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
121 /// Return a vector that records the optional packing for each operand.
122 /// Return failure if the packed indexing cannot be represented with a LinalgOp.
123 ///
124 /// Further details:
125 /// ================
126 /// The current implementation of packing (i.e. data tiling) consists of
127 /// rewriting a linearized strip-mined form into a higher-dimensional access.
128 /// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
129 /// `I` into `4 * i + ii`, where `0 <= ii < 4`.
130 /// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
131 ///
132 /// This rewrite into higher dimensional access is not possible for general
133 /// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
134 /// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
135 /// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
136 /// The rewrite of the access would be a form not representable in Linalg:
137 /// `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
138 /// Note however that as `J` and `ii` iterate, the accesses do not have a
139 /// particular alignment, so packing does not achieve alignment in this case.
140 ///
141 /// In the future, we may want to consider a mixed-form that allows some
142 /// alignment in the presence of multiple accesses:
143 /// `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
144 /// And would rewrite accesses as:
145 /// `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
146 static FailureOr<SmallVector<std::optional<int64_t>>>
147 packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
148  SmallVectorImpl<utils::IteratorType> &iteratorTypes,
149  int64_t dim) {
150  int64_t newDim = iteratorTypes.size();
151  iteratorTypes.push_back(iteratorTypes[dim]);
152 
153  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
154  indexingMaps.size(), std::nullopt);
155  SmallVector<AffineMap> newMaps;
156  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
157  ++operandIdx) {
158  AffineMap map = indexingMaps[operandIdx];
159 
160  // Add the `newDim` to the map in all cases.
161  assert(map.getNumDims() == newDim && "num dims invariant violation");
162  map = map.shiftDims(1, newDim);
163 
164  // Get the at-most-1 index of the result that is a function of `dim`.
165  // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
166  // logically chunks dimension `dim` into `K * dim + newDim`, where the
167  // packing factor `K` is specified separately.
168  assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
169  "num results invariant violation");
170  auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
171  if (!maybeOperandDimensionToPack.has_value()) {
172  newMaps.push_back(map);
173  continue;
174  }
175 
176  // We can only pack AffineDimExpr atm.
177  if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
178  return failure();
179 
180  // Add `newDim` to the results of the map.
181  map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
182  map.getNumResults());
183  newMaps.push_back(map);
184 
185  // Record that `operandIdx` is packed.
186  packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
187  }
188  indexingMaps = newMaps;
189 
190  return packedDimPerIndexingMap;
191 }
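// Example (illustrative): for a matmul with
//   maps = [(d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d2, d1),
//           (d0, d1, d2) -> (d0, d1)]
//   iteratorTypes = [parallel, parallel, reduction]
// a single call packLinalgMetadataOnce(maps, iteratorTypes, /*dim=*/2) appends
// a new reduction iterator d3 and rewrites the maps to
//   [(d0, d1, d2, d3) -> (d0, d2, d3), (d0, d1, d2, d3) -> (d2, d1, d3),
//    (d0, d1, d2, d3) -> (d0, d1)]
// and returns packedDimPerIndexingMap = [1, 0, nullopt].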
192 
193 namespace {
194 
195 /// Helper struct to encode packing along one dimension of a LinalgOp.
196 struct PackedOperandsDim {
197  OpFoldResult packedSize;
198  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
199 };
200 
201 /// Helper struct to encode packing along all dimensions of a LinalgOp.
202 struct PackedOperandsDimList {
203  void pushBack(PackedOperandsDim &&packedOperandsDims) {
204  spec.emplace_back(packedOperandsDims);
205  }
206  /// Return all the dims that have been packed for operand @ `operandPos`.
207  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
208  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
209  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);
210 
211 private:
212  SmallVector<PackedOperandsDim> spec;
213 };
214 
215 } // namespace
216 
217 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
218  linalg::PackOp packOp,
219  bool lowerPadLikeWithInsertSlice) {
220  // 1. Filter out NYI cases.
221  auto packedTensorType =
222  cast<RankedTensorType>(packOp->getResultTypes().front());
223  if (llvm::any_of(packOp.getStaticInnerTiles(), ShapedType::isDynamic)) {
224  return rewriter.notifyMatchFailure(
225  packOp,
226  "non-static shape NYI, needs a more powerful tensor.expand_shape op");
227  }
228 
229  Location loc = packOp->getLoc();
230  OpBuilder::InsertionGuard g(rewriter);
231  rewriter.setInsertionPoint(packOp);
232 
233  // 2. Compute the permutation vector to shuffle packed shape into the shape
234  // before any outer or inner permutations have been applied.
235  PackingMetadata packingMetadata = computePackingMetadata(
236  packedTensorType.getRank(), packOp.getInnerDimsPos());
237  SmallVector<int64_t> packedToStripMinedShapePerm =
238  getPackInverseDestPerm(packOp);
239 
240  // 3. Compute the stripMinedShape: this is the packed shape before any outer
241  // or inner permutations have been applied.
242  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
243  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
244 
245  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
246  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
247  rewriter.getIndexAttr(0));
248  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
249  rewriter.getIndexAttr(0));
250  for (auto [pos, innerSize] :
251  llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
252  int outerPos =
253  packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
254  OpFoldResult origSize =
255  tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
256  OpFoldResult outerSize =
257  tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
258  AffineExpr s0, d0, d1;
259  bindDims(rewriter.getContext(), d0, d1);
260  bindSymbols(rewriter.getContext(), s0);
261  auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
262  highs[pos] = affine::makeComposedFoldedAffineApply(
263  rewriter, loc, map, {outerSize, origSize, innerSize});
264  }
265  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
266  RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
267  packingMetadata.reassociations);
268  Value paddingValue = packOp.getPaddingValue();
269  if (!paddingValue) {
270  paddingValue = arith::ConstantOp::create(
271  rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
272  }
273  auto padOp =
274  tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows,
275  highs, paddingValue, /*nofold=*/false);
276 
277  LDBG() << "insertPositions: "
278  << llvm::interleaved(packingMetadata.insertPositions);
279  LDBG() << "outerPositions: "
280  << llvm::interleaved(packingMetadata.outerPositions);
281  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
282  LDBG() << "packedToStripMinedShapePerm: "
283  << llvm::interleaved(packedToStripMinedShapePerm);
284  LDBG() << "reassociations: "
285  << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
286  stringifyReassocIndices));
287  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
288  LDBG() << "collapsed type: " << collapsed;
289 
290  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
291  // Pack ops which operate as simple pads may not produce legal
292  // tensor.insert_slice operations when the packed type does not rank reduce
293  // to the padded type.
294  SliceVerificationResult rankReduces =
295  isRankReducedType(packedTensorType, padOp.getResultType());
296 
297  if (rankReduces == SliceVerificationResult::Success) {
298  // This pack is just a plain pad.
299  // Just insert the pad in the higher ranked tensor.
300  // Offsets.
301  SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
302  rewriter.getIndexAttr(0));
303  // Strides.
304  SmallVector<OpFoldResult> ones(packOp.getDestRank(),
305  rewriter.getIndexAttr(1));
306  SmallVector<OpFoldResult> sizes =
307  tensor::getMixedSizes(rewriter, loc, packOp.getDest());
308 
309  auto insertSliceOp = tensor::InsertSliceOp::create(
310  rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
311  /*offsets=*/zeros, sizes, /*strides=*/ones);
312 
313  LDBG() << "insert_slice op: " << insertSliceOp;
314 
315  rewriter.replaceOp(packOp, insertSliceOp->getResults());
316 
317  return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
318  /*transposeOp=*/nullptr};
319  }
320  }
321 
322  // 5. Expand from the padded result to the stripMinedShape.
323  auto expandShapeResultType =
324  RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
325  auto reshapeOp = tensor::ExpandShapeOp::create(
326  rewriter, loc, expandShapeResultType, padOp.getResult(),
327  packingMetadata.reassociations);
328 
329  // 6. Transpose stripMinedShape to packedShape.
330  SmallVector<int64_t> transpPerm =
331  invertPermutationVector(packedToStripMinedShapePerm);
332  auto transposeOp = linalg::TransposeOp::create(
333  rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);
334 
335  LDBG() << "reshape op: " << reshapeOp;
336  LDBG() << "transpPerm: " << llvm::interleaved(transpPerm);
337  LDBG() << "transpose op: " << transposeOp;
338 
339  // 7. Replace packOp by transposeOp.
340  rewriter.replaceOp(packOp, transposeOp->getResults());
341 
342  return LowerPackResult{padOp, reshapeOp, transposeOp};
343 }
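// Example (illustrative sketch; the shapes are hypothetical and the exact IR
// depends on the input):
//   %packed = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//       into %dest : tensor<128x256xf32> -> tensor<16x16x8x16xf32>
// lowers to roughly:
//   %padded = tensor.pad %src ...   // a no-op pad when the sizes divide evenly
//   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
//       : tensor<128x256xf32> into tensor<16x8x16x16xf32>
//   %res = linalg.transpose ins(%expanded) outs(%dest) permutation = [0, 2, 1, 3]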
344 
345 FailureOr<LowerUnPackOpResult>
346 linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
347  bool lowerUnpadLikeWithExtractSlice) {
348  Location loc = unPackOp->getLoc();
349  OpBuilder::InsertionGuard g(rewriter);
350  rewriter.setInsertionPoint(unPackOp);
351 
352  RankedTensorType packedTensorType = unPackOp.getSourceType();
353  int64_t packedRank = packedTensorType.getRank();
354 
355  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
356  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
357  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
358  // This unpack is just a plain unpad.
359  // Just extract the slice from the higher ranked tensor.
360  ArrayRef<int64_t> destShape = destTensorType.getShape();
361  // The inner dimensions stay the same as the destination tensor, but the
362  // outer ones are additional 1s.
363  SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
364  sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
365 
366  auto extractSliceOp = tensor::ExtractSliceOp::create(
367  rewriter, loc, destTensorType, unPackOp.getSource(),
368  SmallVector<OpFoldResult>(packedRank, zero), sizes,
369  SmallVector<OpFoldResult>(packedRank, one));
370 
371  rewriter.replaceOp(unPackOp, extractSliceOp->getResults());
372 
373  return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
374  /*reshapeOp=*/nullptr, extractSliceOp};
375  }
376 
377  // 1. Compute the permutation vector to shuffle packed shape into the shape
378  // before any outer or inner permutations have been applied.
379  PackingMetadata packingMetadata;
380  SmallVector<int64_t> packedToStripMinedShapePerm =
381  getUnPackInverseSrcPerm(unPackOp, packingMetadata);
382 
383  // 2. Compute the stripMinedShape: this is the packed shape without outer and
384  // inner permutations.
385  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
386  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
387 
388  // 3. Transpose packedShape to stripMinedShape.
389  RankedTensorType stripMinedTensorType =
390  RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
391  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
392  stripMinedTensorType, packingMetadata.reassociations);
393 
394  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
395  // permutation.
396  SmallVector<OpFoldResult> dims =
397  tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
398  applyPermutationToVector(dims, packedToStripMinedShapePerm);
399  auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims,
400  stripMinedTensorType.getElementType());
401  auto transposeOp =
402  linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp,
403  packedToStripMinedShapePerm);
404 
405  LDBG() << "insertPositions: "
406  << llvm::interleaved(packingMetadata.insertPositions);
407  LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape());
408  LDBG() << "packedToStripMinedShapePerm: "
409  << llvm::interleaved(packedToStripMinedShapePerm);
410  LDBG() << "reassociations: "
411  << llvm::interleaved(llvm::map_range(packingMetadata.reassociations,
412  stringifyReassocIndices));
413  LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape);
414  LDBG() << "collapsed type: " << collapsedType;
415 
416  // 4. Collapse from the stripMinedShape to the padded result.
417  auto reshapeOp = tensor::CollapseShapeOp::create(
418  rewriter, loc, collapsedType, transposeOp->getResult(0),
419  packingMetadata.reassociations);
420 
421  // 5. ExtractSlice.
422  int64_t destRank = destTensorType.getRank();
423  auto extractSliceOp = tensor::ExtractSliceOp::create(
424  rewriter, loc, destTensorType, reshapeOp->getResult(0),
425  SmallVector<OpFoldResult>(destRank, zero),
426  tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
427  SmallVector<OpFoldResult>(destRank, one));
428 
429  // 6. Inject a copy to preserve DPS.
430  auto copyOp = linalg::CopyOp::create(
431  rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest());
432 
433  // 7. Replace unPackOp by copyOp.
434  rewriter.replaceOp(unPackOp, copyOp->getResults());
435 
436  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
437 }
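// Summary of the lowering above (illustrative): unless the unpack is a plain
// "unpad" (handled by a single rank-reducing tensor.extract_slice), the op is
// rewritten as
//   tensor.empty + linalg.transpose  // packed layout -> strip-mined layout
//   tensor.collapse_shape            // strip-mined shape -> linearized shape
//   tensor.extract_slice             // drop the amount that was padded
//   linalg.copy                      // write into the destination (DPS)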
438 
439 SmallVector<int64_t>
440 PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
441  SmallVector<int64_t> res;
442  for (auto &i : spec) {
443  if (!i.packedDimForEachOperand[operandPos].has_value())
444  continue;
445  res.push_back(i.packedDimForEachOperand[operandPos].value());
446  }
447  return res;
448 }
449 
450 SmallVector<OpFoldResult>
451 PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
452  SmallVector<OpFoldResult> res;
453  for (auto &i : spec) {
454  if (!i.packedDimForEachOperand[operandPos].has_value())
455  continue;
456  res.push_back(i.packedSize);
457  }
458  return res;
459 }
460 
461 /// Implement packing of a single LinalgOp by performing packing by
462 /// `packedSizes`. There must be one packedSizes entry per `linalgOp` iterator.
463 /// Return the packed Linalg op on success, failure otherwise.
464 FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
465  linalg::LinalgOp linalgOp,
466  ArrayRef<OpFoldResult> packedSizes) {
467  if (packedSizes.size() != linalgOp.getNumLoops()) {
468  return rewriter.notifyMatchFailure(linalgOp,
469  "incorrect number of pack sizes");
470  }
471 
472  Location loc = linalgOp->getLoc();
473  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
474  SmallVector<utils::IteratorType> iteratorTypes =
475  linalgOp.getIteratorTypesArray();
476  LDBG() << "Start packing: " << linalgOp;
477  LDBG() << "maps: " << llvm::interleaved(indexingMaps);
478  LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
479 
480  SmallVector<linalg::PackOp> packOps;
481  SmallVector<linalg::UnPackOp> unPackOps;
482  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
483  PackedOperandsDimList listOfPackedOperandsDim;
484  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
485  std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
486  // Skip tile sizes explicitly set to 0.
487  if (maybeConstant.has_value() && maybeConstant.value() == 0)
488  continue;
489 
490  PackedOperandsDim packedOperandsDims;
491  packedOperandsDims.packedSize = packedSizes[i];
492  FailureOr<SmallVector<std::optional<int64_t>>>
493  maybePackedDimForEachOperand =
494  packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
495  if (failed(maybePackedDimForEachOperand))
496  return failure();
497  packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
498  listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
499 
500  LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
501  LDBG() << "maps: " << llvm::interleaved(indexingMaps);
502  LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
503  LDBG() << "packedDimForEachOperand: "
504  << llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
505  }
506 
507  // Step 2. Propagate packing to all LinalgOp operands.
508  SmallVector<Value> inputsAndInits, results;
509  SmallVector<OpOperand *> initOperands =
510  llvm::to_vector(llvm::make_pointer_range(linalgOp.getDpsInitsMutable()));
511  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
512  for (const auto &operandsList : {inputOperands, initOperands}) {
513  for (OpOperand *opOperand : operandsList) {
514  int64_t pos = opOperand->getOperandNumber();
515  Value operand = opOperand->get();
516  SmallVector<int64_t> innerPos =
517  listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
518  SmallVector<OpFoldResult> innerPackSizes =
519  listOfPackedOperandsDim.extractPackSizesForOperand(pos);
520  LDBG() << "operand: " << operand;
521  LDBG() << "innerPos: " << llvm::interleaved(innerPos);
522  LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes);
523  if (innerPackSizes.empty()) {
524  inputsAndInits.push_back(operand);
525  continue;
526  }
527  Value dest = linalg::PackOp::createDestinationTensor(
528  rewriter, loc, operand, innerPackSizes, innerPos,
529  /*outerDimsPerm=*/{});
530  ShapedType operandType = cast<ShapedType>(operand.getType());
531  bool areConstantTiles =
532  llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
533  return getConstantIntValue(tile).has_value();
534  });
535  if (areConstantTiles && operandType.hasStaticShape() &&
536  !linalg::PackOp::requirePaddingValue(
537  operandType.getShape(), innerPos,
538  cast<ShapedType>(dest.getType()).getShape(), {},
539  innerPackSizes)) {
540  packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest,
541  innerPos, innerPackSizes));
542  } else {
543  // TODO: value of the padding attribute should be determined by
544  // consumers.
545  auto zeroAttr =
546  rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
547  Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
548  packOps.push_back(linalg::PackOp::create(
549  rewriter, loc, operand, dest, innerPos, innerPackSizes, zero));
550  }
551  inputsAndInits.push_back(packOps.back());
552  }
553  }
554 
555  // Step 3. Build the packed op, use the type of `inits` as result types.
556  ValueRange inputs =
557  ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
558  ValueRange inits =
559  ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
560  auto packedLinalgOp =
561  linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(),
562  inputs, inits, indexingMaps, iteratorTypes);
563  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));
564 
565  // Step 4. Propagate packing to all the op results.
566  for (OpResult result : packedLinalgOp->getResults()) {
567  int64_t resultNum = result.getResultNumber();
568  linalg::PackOp maybePackedInit =
569  inits[resultNum].getDefiningOp<linalg::PackOp>();
570  if (!maybePackedInit) {
571  results.push_back(result);
572  continue;
573  }
574  // Build the symmetrical UnPackOp to the existing PackOp.
575  unPackOps.push_back(linalg::UnPackOp::create(
576  rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
577  maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
578  results.push_back(unPackOps.back());
579  }
580 
581  // Step 5. Replace `linalgOp`.
582  rewriter.replaceOp(linalgOp, results);
583 
584  // Return packedLinalgOp.
585  return PackResult{packOps,
586  cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
587  unPackOps};
588 }
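// Example (illustrative; the tile sizes are hypothetical): packing a
// 128x256x512 matmul with packedSizes = [8, 16, 32] for (m, n, k) produces
//   - linalg.pack ops on the LHS, RHS and init operands (with padding where
//     required),
//   - a linalg.generic on the packed operands whose iterators are
//     [par, par, red, par, par, red], with the block (mm, nn, kk) dims
//     appearing as the trailing iterators,
//   - linalg.unpack ops that restore the original result type.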
589 
590 //===----------------------------------------------------------------------===//
591 // packTranspose transformation.
592 //===----------------------------------------------------------------------===//
593 
594 /// Return a copy of `tensorType` after permutation by `permutationVector`.
595 // Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder,
596 // but this would introduce a dependence on Dialect in IR.
597 // TODO: Restructure.
598 static RankedTensorType permuteShape(RankedTensorType tensorType,
599  ArrayRef<int64_t> permutationVector) {
600  SmallVector<int64_t> shape(tensorType.getShape());
601  applyPermutationToVector(shape, permutationVector);
602  return RankedTensorType::Builder(tensorType).setShape(shape);
603 }
604 
605 /// Return a new GenericOp obtained by transposing opOperand by the permutation
606 /// vector:
607 /// - the corresponding indexing map is transposed by `permutation`
608 /// - the corresponding operand value is replaced by `transposedValue`
609 /// `linalgOp` is replaced by the return op in the process.
610 /// Asserts that `transposedValue` is of the proper transposed ShapedType.
611 static LinalgOp transposeOneLinalgOperandAndReplace(
612  RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
613  ArrayRef<int64_t> permutation, Value transposedValue) {
614  // Sanity check the operand.
615  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");
616 
617  // Sanity check of the expected transposed tensor type.
618  auto tensorType = permuteShape(
619  cast<RankedTensorType>(opOperand.get().getType()), permutation);
620  (void)tensorType;
621  assert(tensorType == transposedValue.getType() &&
622  "expected tensor type mismatch");
623 
624  // Compute the transposed indexing map.
625  // Sigh unsigned pollution.
626  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
627  llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
628  AffineMap permutationMap =
629  AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
630  AffineMap transposedMap =
631  permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));
632 
633  // Set the transposed indexing map in the proper position.
634  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
635  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
636  // Set the transposedValue in the proper operand position.
637  SmallVector<Value> operands = linalgOp->getOperands();
638  operands[opOperand.getOperandNumber()] = transposedValue;
639 
640  ValueRange operandsRef(operands);
641  auto transposedGenericOp = linalg::GenericOp::create(
642  rewriter,
643  /*location=*/linalgOp->getLoc(),
644  /*resultTensorTypes=*/
645  operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
646  /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
647  /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
648  /*indexingMaps=*/indexingMaps,
649  /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
650  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
651  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());
652 
653  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
654 }
655 
656 FailureOr<PackTransposeResult>
657 linalg::packTranspose(RewriterBase &rewriter, linalg::PackOp packOp,
658  linalg::LinalgOp linalgOp, linalg::UnPackOp maybeUnPackOp,
659  ArrayRef<int64_t> outerPerm,
660  ArrayRef<int64_t> innerPerm) {
661  Location loc = linalgOp.getLoc();
662 
663  // Step 1. Transpose packOp.
664  rewriter.setInsertionPoint(packOp);
665  linalg::PackOp transposedPackOp =
666  packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);
667 
668  if (!packOp.getResult().hasOneUse())
669  return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");
670 
671  OpOperand &packUse = *packOp->getUses().begin();
672  if (packUse.getOwner() != linalgOp) {
673  return rewriter.notifyMatchFailure(
674  linalgOp, "not a single use by the LinalgOp target");
675  }
676  if (maybeUnPackOp &&
677  (!linalgOp.isDpsInit(&packUse) ||
678  maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
679  return rewriter.notifyMatchFailure(linalgOp,
680  "not produced by the LinalgOp target");
681  }
682 
683  // Step 2. Transpose linalgOp.
684  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
685  // identity. Don't rely on it.
686  int64_t numLeadingDims = packOp.getSourceRank();
687  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
688  // Step 2.a. Compute the permutation on the whole operand.
689  // The leading part just reuses the outerPerm.
690  SmallVector<int64_t> permutation(outerPerm);
691  if (permutation.empty())
692  llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
693  // Trailing part needs to reindex positions by `numLeadingDims`.
694  if (innerPerm.empty()) {
695  llvm::append_range(
696  permutation,
697  llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
698  } else {
699  llvm::append_range(permutation,
700  llvm::map_range(innerPerm, [&](int64_t pos) {
701  return numLeadingDims + pos;
702  }));
703  }
704  if (!isPermutationVector(permutation))
705  return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");
706 
707  // Step 2.b. Save the transposedPackUse operand number in case we need to
708  // get the tied OpResult after `linalgOp` has been replaced.
709  int64_t packUseOperandNumber = packUse.getOperandNumber();
710  // Step 2.c. Actually perform the transposition.
711  rewriter.setInsertionPoint(linalgOp);
712  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
713  rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());
714 
715  // Step 3. Maybe transpose unPackOp.
716  linalg::UnPackOp transposedUnPackOp;
717  if (maybeUnPackOp) {
718  OpOperand &opOperand =
719  transposedLinalgOp->getOpOperand(packUseOperandNumber);
720  OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
721  rewriter.setInsertionPoint(maybeUnPackOp);
722  transposedUnPackOp = maybeUnPackOp.createTransposedClone(
723  rewriter, loc, transposedResult, innerPerm, outerPerm);
724 
725  rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
726  }
727 
728  // Step 4. Finally, replace packOp now that we don't need it anymore.
729  rewriter.replaceOp(packOp, transposedPackOp->getResults());
730 
731  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
732  transposedUnPackOp};
733 }
734 
735 //===----------------------------------------------------------------------===//
736 // packMatmulGreedily transformation.
737 //===----------------------------------------------------------------------===//
738 
739 /// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
740 /// and n are proper parallel dimensions and k is a proper reduction
741 /// dimension. Packing occurs by rewriting the op as a linalg.generic and
742 /// calling linalg::pack with `mnkPackedSizes`. The order of the packed
743 /// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
744 /// to reorder {m, n, k} into one of the 6 possible forms. The outer
745 /// dimensions of the operands are not permuted at this time, this is left for
746 /// future work.
747 FailureOr<PackResult>
748 linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
749  ArrayRef<OpFoldResult> mnkPackedSizes,
750  ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
751  ArrayRef<int64_t> mnkOrder) {
752  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
753  assert((mnkPaddedSizesNextMultipleOf.empty() ||
754  mnkPaddedSizesNextMultipleOf.size() == 3) &&
755  "num of packing sizes next multiple should be empty or of size 3");
756  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
757  assert(isPermutationVector(mnkOrder) && "expected a permutation");
758 
759  int64_t numLoops = linalgOp.getNumLoops();
760  if (numLoops <= 2) {
761  LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops
762  << " in: " << linalgOp;
763  return rewriter.notifyMatchFailure(
764  linalgOp, "need 3+ loops to find a matmul to pack");
765  }
766 
767  // Locally adjust the desired iterator position of mnk and packing sizes.
768  int64_t numPackedDims = mnkPackedSizes.size();
769  SmallVector<int64_t> mmnnkkPos(numPackedDims);
770  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
771  mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
772  SmallVector<OpFoldResult> packedSizes(numPackedDims);
773  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
774  packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
775  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
776  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
777  paddedSizesNextMultipleOf[mnkOrder[i]] =
778  mnkPaddedSizesNextMultipleOf.empty() ? 0
779  : mnkPaddedSizesNextMultipleOf[i];
780  }
781 
782  // 1. Infer dims that are important for matmul.
783  FailureOr<ContractionDimensions> maybeDimensions =
784  inferContractionDims(linalgOp);
785  if (failed(maybeDimensions)) {
786  LDBG() << "couldn't infer matmul iterators in: " << linalgOp;
787  return rewriter.notifyMatchFailure(linalgOp,
788  "couldn't infer matmul iterators");
789  }
790 
791  // 2. Normalize linalgOp to a kmn-matmul-like form with [red, par, par] most
792  // minor iterators. In cases with multiple options for m, n, k bias towards
793  // the most minor embedding.
794  // If we wanted a different normalization order, this is where a heuristic
795  // would have to be plugged in.
796  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
797  kPos = maybeDimensions->k.back();
798  LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@"
799  << nPos << ", k@" << kPos << "): " << linalgOp;
800 
801  // 2.a. Rewrite as a generic.
802  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
803  if (!genericOp) {
804  FailureOr<GenericOp> generalizeResult =
805  generalizeNamedOp(rewriter, linalgOp);
806  assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
807  genericOp = *generalizeResult;
808  }
809 
810  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
811  // iterators. Note that this only normalizes the iteration order and does
812  // not change the indexings of any operand.
813  SmallVector<int64_t> permutation =
814  computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
815  LDBG() << "perm: " << llvm::interleaved(permutation);
816  // Sign .. unsigned pollution.
817  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
818  FailureOr<GenericOp> interchangeResult =
819  interchangeGenericOp(rewriter, genericOp, unsignedPerm);
820  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
821  genericOp = *interchangeResult;
822  LDBG() << "Generalized Op to pack: " << genericOp;
823 
824  // At this point, the op iterators are normalized to {leading, k, m, n}.
825  // The layouts induced by packing will always be:
826  // - LHS{leading_lhs, kk, mm}
827  // - RHS{leading_rhs, kk, nn}
828  // - RES{leading_res, mm, nn}
829  // If we wanted to change the packed order, we would reorder (k, m, n) to
830  // something else above.
831  //
832  // Additional permutations of the outer dims of the operands (i.e.
833  // leading_lhs, leading_rhs and leading_res) could follow by computing the
834  // desired outerPerm for each operand.
835  // This is left for future work.
836 
837  // TODO: this creates too much IR, go use reifyResultShapes.
838  SmallVector<Range, 4> loopRanges =
839  cast<LinalgOp>(genericOp.getOperation())
840  .createLoopRanges(rewriter, genericOp.getLoc());
841 
842  // Add leading zeros to match numLoops; we only pack the last 3 dimensions
843  // post interchange.
844  LDBG() << "paddedSizesNextMultipleOf: "
845  << llvm::interleaved(paddedSizesNextMultipleOf);
846  LDBG() << "loopRanges: "
847  << llvm::interleaved(
848  llvm::map_range(loopRanges, [](Range r) { return r.size; }));
849  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
850  rewriter.getIndexAttr(0));
851  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
852  if (paddedSizesNextMultipleOf[i] == 0) {
853  adjustedPackedSizes.push_back(packedSizes[i]);
854  continue;
855  }
856  AffineExpr d0, s0;
857  bindDims(rewriter.getContext(), d0);
858  bindSymbols(rewriter.getContext(), s0);
859  adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
860  rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
861  {loopRanges[adjustedPackedSizes.size()].size,
862  rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
863  }
864  LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes);
865 
866  // TODO: If we wanted to give the genericOp a name after packing, after
867  // calling `pack` would be a good time. One would still need to check that
868  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
869  // also allow degenerate matmul cases (i.e. matvec, dot).
870  return pack(rewriter, genericOp, adjustedPackedSizes);
871 }
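// Example (illustrative usage sketch; the sizes are hypothetical): calling
// packMatmulGreedily with mnkPackedSizes = {8, 16, 32} and the identity
// mnkOrder on a linalg.matmul first generalizes it to linalg.generic,
// interchanges the iterators so that the inferred m, n and k loops are
// innermost, and then packs those three loops by 8, 16 and 32 respectively,
// so the innermost iterators become the (mm, nn, kk) micro-tile dimensions.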
872 
873 //===----------------------------------------------------------------------===//
874 // Transformations exposed as rewrite patterns.
875 //===----------------------------------------------------------------------===//
876 
877 LinalgTilingOptions &
878 mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
879  assert(!tileSizeComputationFunction && "tile sizes already set");
880  SmallVector<int64_t, 4> tileSizes(ts);
881  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
882  OpBuilder::InsertionGuard guard(b);
883  b.setInsertionPointToStart(
884  &op->getParentOfType<func::FuncOp>().getBody().front());
885  return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
886  Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s);
887  return v;
888  }));
889  };
890  return *this;
891 }
892 
893 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
894  memref::CopyOp copyOp, PatternRewriter &rewriter) const {
895  return vectorizeCopy(rewriter, copyOp);
896 }
897 
898 /// Fill `dest` using the FillOp constant padding value if possible.
899 /// Otherwise, generate a tensor::GenerateOp.
900 Value DecomposePadOpPattern::createFillOrGenerateOp(
901  RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
902  const SmallVector<Value> &dynSizes) const {
903  auto padValue = padOp.getConstantPaddingValue();
904  if (padValue) {
905  // Move the padding value defined inside the PadOp block to outside.
906  if (padValue.getParentBlock() == &padOp.getRegion().front())
907  rewriter.moveOpBefore(padValue.getDefiningOp(), padOp);
908  return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result();
909  }
910 
911  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
912  auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(),
913  padOp.getResultType(), dynSizes);
914  // Copy region to new op.
915  IRMapping bvm;
916  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
917  return generateOp;
918 }
919 
920 LogicalResult
921 DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
922  PatternRewriter &rewriter) const {
923  // Given an OpFoldResult, return an index-typed value.
924  auto getIdxValue = [&](OpFoldResult ofr) {
925  if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
926  return val;
927  return arith::ConstantIndexOp::create(
928  rewriter, padOp.getLoc(),
929  cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
930  .getResult();
931  };
932 
933  auto resultType = padOp.getResultType();
934  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
935  SmallVector<Value> dynSizes;
936  SmallVector<int64_t> staticSizes;
937  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
938  if (resultType.isDynamicDim(dim)) {
939  auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
940  padOp.getSource(), dim));
941  // Add low and high padding value.
942  auto plusLow = rewriter.createOrFold<arith::AddIOp>(
943  padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
944  auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
945  padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
946  dynSizes.push_back(plusHigh);
947  }
948  staticSizes.push_back(resultType.getDimSize(dim));
949  }
950 
951  // Init tensor and fill it with padding.
952  Value emptyTensor =
953  tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes,
954  resultType.getElementType(), dynSizes);
955  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
956 
957  // Generate an InsertSliceOp for copying the PadOp source.
958  auto sourceType = padOp.getSourceType();
959  // Compute size of source of tensor::PadOp.
960  SmallVector<OpFoldResult> srcSizes =
961  tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
962  // Strides of InsertSliceOp are all 1.
963  SmallVector<OpFoldResult> strides(sourceType.getRank(),
964  rewriter.getIndexAttr(1));
965  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
966  padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
967  strides);
968 
969  return success();
970 }
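// Example (illustrative decomposition; the shapes and values are hypothetical):
//   %0 = tensor.pad %src low[1, 2] high[3, 4] { ... tensor.yield %cst }
//       : tensor<4x5xf32> to tensor<8x11xf32>
// becomes roughly:
//   %empty = tensor.empty() : tensor<8x11xf32>
//   %fill = linalg.fill ins(%cst) outs(%empty)
//   %res = tensor.insert_slice %src into %fill[1, 2] [4, 5] [1, 1]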
971 
972 LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
973  tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
974  if (!sliceOp.hasUnitStride())
975  return failure();
976 
977  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
978  if (!padOp)
979  return failure();
980 
981  bool zeroSliceGuard = true;
982  if (controlFn) {
983  if (std::optional<bool> control = controlFn(sliceOp))
984  zeroSliceGuard = *control;
985  else
986  return failure();
987  }
988 
989  FailureOr<TilingResult> tilingResult =
990  tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
991  sliceOp.getMixedSizes(), zeroSliceGuard);
992  if (failed(tilingResult))
993  return failure();
994 
995  RankedTensorType sourceType = sliceOp.getSourceType();
996  RankedTensorType resultType = sliceOp.getResultType();
997 
998  // If the extract_slice is not rank-reduced, all shapes are static and the
999  // data source is actually used. Rewrite into pad(extract_slice(x)).
1000  if (sourceType.getRank() == resultType.getRank()) {
1001  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
1002  return success();
1003  }
1004 
1005  // Handle rank-reduced slice by creating another extract_slice op.
1006  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
1007  rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
1008 
1009  rewriter.replaceOp(sliceOp, rankReduced);
1010  return success();
1011 }
1012 
1013 /// If padding value is set, returns a tensor.pad Op for the source tensor,
1014 /// with the output shape matching the output of `packOp`. Otherwise, returns
1015 /// the source directly.
1016 ///
1017 /// This method assumes that all outer dims for this pack Op are 1.
1018 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
1019  linalg::PackOp packOp) {
1020  Value input = packOp.getSource();
1021  if (!packOp.getPaddingValue()) {
1022  return input;
1023  }
1024 
1025  assert(llvm::all_of(packOp.getAllOuterDims(),
1026  [](int64_t val) { return val == 1; }) &&
1027  "some outer dims are != 1");
1028 
1029  Location loc = packOp.getLoc();
1030  ShapedType inputType = packOp.getSourceType();
1031  int64_t inputRank = inputType.getRank();
1032 
1033  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
1034  packOp.getDimAndTileMapping();
1035 
1036  // The sizes of dynamic tiles
1037  SmallVector<Value> dynamicTileSizes;
1038 
1039  // Collect dims for the padded shape.
1040  SmallVector<int64_t> paddedShape;
1041  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
1042  // 1. Non-tiled outer dims.
1043  // These dims should be 1 and we simply preserve them.
1044  if (!tileAndPosMapping.count(dimIdx)) {
1045  int64_t inputDimSize = inputType.getDimSize(dimIdx);
1046  assert(inputDimSize == 1 &&
1047  "with all outer dims == 1, this non-tiled input dim should be 1!");
1048  paddedShape.push_back(inputDimSize);
1049  continue;
1050  }
1051 
1052  // 2. Tiled outer dims
1053  // As all outer dims == 1, it is safe to use the tile size for the padded
1054  // shape.
1055  OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
1056 
1057  // 2.1 Static tile sizes
1058  std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
1059  if (cstTileSize.has_value()) {
1060  paddedShape.push_back(cstTileSize.value());
1061  continue;
1062  }
1063 
1064  // 2.2 Dynamic tile sizes
1065  paddedShape.push_back(ShapedType::kDynamic);
1066 
1067  // Get the value that holds the dynamic size.
1068  dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
1069  }
1070  auto resultType =
1071  RankedTensorType::get(paddedShape, inputType.getElementType());
1072  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
1073  /*nofold=*/false, loc, builder,
1074  dynamicTileSizes);
1075 }
1076 
1077 // Normalizes a permutation on a higher rank space to its actual size, e.g.
1078 // perm = [1, 4, 2]
1079 // becomes
1080 // norm = [0, 2, 1]
1081 static SmallVector<int64_t>
1082 getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
1083  constexpr int64_t kNonTiledMarker = -1;
1084  SmallVector<int64_t> vec(rank, kNonTiledMarker);
1085  for (auto [index, value] : llvm::enumerate(perm))
1086  vec[value] = index;
1087  SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
1088  vec, [&](int64_t v) { return v != kNonTiledMarker; });
1089  // This inverts the permutation in addition to normalizing so invert back.
1090  return invertPermutationVector(normalizedPerm);
1091 }
1092 
1093 // Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
1094 // assuming rank reduction of unit outer dims.
1095 static SmallVector<int64_t>
1096 getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
1097  ArrayRef<int64_t> innerDimsPos,
1098  ArrayRef<int64_t> outerDimsPerm) {
1099  SmallVector<int64_t> rankReducedOuterDimsPerm;
1100  SmallVector<int64_t> outerDims;
1101  SmallVector<int64_t> innerDims;
1102  int64_t dim = 0;
1103  int64_t unpackedRank = shape.size();
1104  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
1105  if (llvm::is_contained(innerDimsPos, i)) {
1106  innerDims.push_back(dim++);
1107  continue;
1108  }
1109  if (shape[i] == 1)
1110  continue;
1111  outerDims.push_back(dim++);
1112  if (!outerDimsPerm.empty())
1113  rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
1114  }
1115 
1116  // Get the position of the inner dims after permutation.
1117  SmallVector<int64_t> innerPerm =
1118  getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
1119  applyPermutationToVector<int64_t>(innerDims, innerPerm);
1120 
1121  // Ditto for the outer dims.
1122  SmallVector<int64_t> perm = outerDims;
1123 
1124  rankReducedOuterDimsPerm =
1125  getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
1126  if (!rankReducedOuterDimsPerm.empty())
1127  applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);
1128 
1129  // The tile always ends up as the inner most dims after packing.
1130  perm.append(innerDims);
1131 
1132  return perm;
1133 }
1134 
1135 LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
1136  linalg::PackOp packOp, PatternRewriter &rewriter) const {
1137  if (llvm::any_of(packOp.getTiledOuterDims(),
1138  [](int64_t dim) { return dim != 1; })) {
1139  return rewriter.notifyMatchFailure(
1140  packOp, "not all outer dimensions of the result are 1s");
1141  }
1142 
1143  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
1144  auto outerDimsPerm = packOp.getOuterDimsPerm();
1145 
1146  // Verify that there are no non-unit, un-tiled outer dims that are
1147  // permuted; supporting such cases would require refining the logic
1148  // that generates the Transpose Op. (Permuting unit dims is fine, as
1149  // it is effectively a no-op.)
1150  if (!llvm::all_of(outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
1151  static int prev = 0;
1152  // Skip tiled dims - these can be permuted.
1153  if (llvm::is_contained(innerDimsPos, dim))
1154  return true;
1155 
1156  // Check whether this dim has been permuted. Permuting unit dims is fine
1157  // as that's effectively a no-op.
1158  if (dim < prev && (packOp.getType().getShape()[prev] != 1 ||
1159  packOp.getType().getShape()[dim] != 1))
1160  return false;
1161 
1162  prev = dim;
1163  return true;
1164  })) {
1165  return rewriter.notifyMatchFailure(
1166  packOp, "At least one non-unit and un-tiled outer dim is permuted, "
1167  "this is not supported ATM!");
1168  }
1169 
1170  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1171  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1172  Location loc = packOp.getLoc();
1173 
1174  int64_t srcRank = packOp.getSourceRank();
1175  int64_t destRank = packOp.getDestRank();
1176 
1177  // 1. Get the input that is going to be packed. If the input requires padding,
1178  // add a padding operation and return that as the input.
1179  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
1180 
1181  // 2. Transpose the input to match the inner tile order:
1182  // %init = tensor.empty()
1183  // %transposed_tile = linalg.transpose ins(%source_or_padded_source),
1184  // outs(%init)
1185  // Assumptions made:
1186  // - All tiled outer dims are 1 - the corresponding transposition order
1187  // doesn't matter, but requires all dim indices to be present.
1188  // - Un-tiled outer dims remain un-permuted.
1189 
1190  // 2.1 Get the permutation for linalg.transpose:
1191  // [ untiled-dims, inner-dims-pos ]
1192  // Note, this logic assumes that the untiled dims are not permuted.
1193  SmallVector<int64_t> srcPermForTranspose;
1194  for (int64_t i = 0; i < srcRank; i++) {
1195  // We assume the `k` dimensions of the inner dim position, where `k` is the
1196  // rank of the inner tiling, correspond to the last `k` indices of the
1197  // transpose permutation. This is done by adding the indices not contained
1198  // in the inner dimension position in order from 0 to `n`. Where n is the
1199  // rank of the source tensor. For example if we have a source tensor with
1200  // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
1201  // indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
1202  if (llvm::is_contained(innerDimsPos, i))
1203  continue;
1204  srcPermForTranspose.push_back(i);
1205  }
1206  srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
1207 
1208  // 2.2 Create the init tensor for linalg.transpose with the correct shape:
1209  // [ untiled-dims, tiled-dims ]
1210  ShapedType inputTy = cast<ShapedType>(input.getType());
1211  SmallVector<OpFoldResult> shapeForEmptyOp;
1212  for (int64_t i = 0; i < srcRank; i++) {
1213  if (llvm::is_contained(innerDimsPos, i)) {
1214  // The tiled dims are appended after this loop.
1215  continue;
1216  }
1217  if (inputTy.isStaticDim(i))
1218  shapeForEmptyOp.push_back(rewriter.getIndexAttr(inputTy.getShape()[i]));
1219  else
1220  shapeForEmptyOp.emplace_back(
1221  tensor::DimOp::create(rewriter, loc, input, i).getResult());
1222  }
1223  shapeForEmptyOp.append(packOp.getMixedTiles());
1224 
1225  // getMixedTiles() may contain Values pointing to constant ops, not the
1226  // constant attributes. Replace them with a true OpFoldResult.
1227  llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
1228  [&](OpFoldResult ofr) {
1229  if (auto val = llvm::dyn_cast<Value>(ofr))
1230  return getAsOpFoldResult(val);
1231  return ofr;
1232  });
1233 
1234  LDBG() << "Pack permutation: " << packOp;
1235  LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
1236  LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
1237 
1238  Value empty = tensor::EmptyOp::create(
1239  rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
1240 
1241  // 2.3 Create linalg.transpose
1242  auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
1243  srcPermForTranspose);
1244 
1245  // 3. Insert the inner tile into the destination tensor:
1246  // %inserted_tile = tensor.insert_slice(%transposed_tile)
1247 
1248  // Compute the sizes attribute:
1249  // [ outer-dims, tile-sizes ]
1250  // Note that the output from the transpose Op excludes the tiled outer dims.
1251  // However, given the assumption that:
1252  // * all tiled outer dims == 1,
1253  // we can just use a rank-expanding tensor.insert_slice.
1254  SmallVector<OpFoldResult> writeSizes;
1255  for (auto size : packOp.getAllOuterDims()) {
1256  writeSizes.push_back(rewriter.getIndexAttr(size));
1257  }
1258 
1259  for (auto tileSize : packOp.getMixedTiles()) {
1260  auto [_, tileSizeOfr] =
1261  getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
1262  writeSizes.push_back(tileSizeOfr);
1263  }
1264 
1265  // TODO: Add a constructor for tensor.insert_slice that doesn't require
1266  // strides nor offsets.
1267  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1268  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1269 
1270  auto insert = tensor::InsertSliceOp::create(
1271  rewriter, loc, transposedOp.getResult()[0], packOp.getDest(),
1272  writeOffsets, writeSizes, writeStrides);
1273 
1274  // 4. Replace the linalg.pack op with the tensor.insert_slice created above.
1275  rewriter.replaceOp(packOp, insert.getResult());
1276 
1277  return success();
1278 }
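// Example (illustrative; the shapes are hypothetical): a pack whose outer dims
// are all 1, e.g.
//   %packed = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [5, 1]
//       into %dest : tensor<5x1xf32> -> tensor<1x1x5x1xf32>
// decomposes roughly into:
//   %empty = tensor.empty() : tensor<5x1xf32>
//   %tr = linalg.transpose ins(%src) outs(%empty) permutation = [0, 1]
//   %res = tensor.insert_slice %tr into %dest[0, 0, 0, 0] [1, 1, 5, 1]
//       [1, 1, 1, 1]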
1279 
1280 LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
1281  linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1282  int64_t srcRank = unpackOp.getSourceRank();
1283  int64_t destRank = unpackOp.getDestRank();
1284  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
1285  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1286  if (llvm::any_of(unpackOp.getTiledOuterDims(),
1287  [](int64_t dim) { return dim != 1; })) {
1288  return rewriter.notifyMatchFailure(
1289  unpackOp,
1290  "require the tiled outer dimensions of the result are all 1s");
1291  }
1292 
1293  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
1294  // %extracted_tile = tensor.extract_slice(%unpack_op_input)
1295  Location loc = unpackOp.getLoc();
1296  Value source = unpackOp.getSource();
1297  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
1298  unpackOp.getDimAndTileMapping();
1299  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
1300  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1301 
1302  // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
1303  // dims:
1304  // [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
1305  SmallVector<int64_t> readShapeForExtractSlice;
1306  // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
1307  // outer-tiled-dims being all 1), this will be
1308  // [ outer-untiled-dims, tile-sizes ]
1309  SmallVector<OpFoldResult> extractSliceSizes;
1310  // The offset and strides attributes for ExtractSliceOp.
1311  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
1312  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);
1313 
1314  // Shape for EmptyOp that's used as the init value for TransposeOp below.
1315  // This should be:
1316  // [ outer-untiled-dims, tile-sizes ]
1317  // However, skip unit dims - TransposeOp (below) applies rank-reduced
1318  // permutation.
1319  SmallVector<OpFoldResult> shapeForEmptyOp;
1320 
1321  for (auto i : llvm::seq<unsigned>(0, destRank)) {
1322  // Compute sizes attribute for ExtractSliceOp - outer-tiled-dims.
1323  //
1324  // As all outer tiled dims are 1, the corresponding
1325  // slice size to read will also be 1. As this will be rank-reducing "extract
1326  // slice" (i.e. the unit dims will be "collapsed"), there's no need to
1327  // update:
1328  // * the output shape for ExtractSliceOp, nor
1329  // * the shape for EmptyOp.
1330  if (dimAndTileMapping.count(i)) {
1331  extractSliceSizes.push_back(oneIdxAttr);
1332  continue;
1333  }
1334 
1335  // Compute sizes attribute for ExtractSliceOp + EmptyOp -
1336  // outer-untiled-dims
1337  if (ShapedType::isDynamic(srcShape[i])) {
1338  OpFoldResult dynamicDim =
1339  tensor::DimOp::create(rewriter, loc, source, i).getResult();
1340  extractSliceSizes.push_back(dynamicDim);
1341  shapeForEmptyOp.push_back(dynamicDim);
1342  } else {
1343  extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
1344  if (srcShape[i] != 1)
1345  shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
1346  }
1347  // Compute the output shape for ExtractSliceOp - outer-untiled-dims (take
1348  // into account rank-reducing)
1349  if (srcShape[i] != 1) {
1350  readShapeForExtractSlice.push_back(srcShape[i]);
1351  }
1352  }
1353  // Append the tile sizes to "sizes attribute" for ExtractSliceOp and the
1354  // shape for EmptyOp.
1355  auto mixedTiles = unpackOp.getMixedTiles();
1356  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
1357  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());
1358 
1359  // Explicitly create the type for extract_slice op because the inner tile
1360  // size could be 1. We want to represent the whole inner tile in this case.
1361  auto tileShape = srcShape.drop_front(destRank);
1362  // Append the inner tile shape to the permuted and rank-reduced outer shape.
1363  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
1364  Type elemType = unpackOp.getSourceType().getElementType();
1365  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
1366  Value innerTile = tensor::ExtractSliceOp::create(
1367  rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets,
1368  extractSliceSizes, extractSliceStrides);
1369 
1370  // 2. Transpose the tile to match the outer corresponding tile order.
1371  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1372      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
1373  // Unpack is a transition out of packed space so we invert the permutation.
1374  perm = invertPermutationVector(perm);
1375  applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1376 
1377  Value empty =
1378  tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType);
1379  auto transposedOp =
1380  linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm);
1381 
1382  // 3. Handle incomplete tiles if needed. This truncates trailing data from the
1383  // transposed tile.
1384  int numLoops = shapeForEmptyOp.size();
1385  SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
1386  SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
1387  SmallVector<OpFoldResult> tileSizes;
1388  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1389  for (auto i : llvm::seq<unsigned>(0, destRank)) {
1390  if (dimAndTileMapping.count(i) || destShape[i] != 1)
1391  tileSizes.push_back(
1392  tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1393  }
1394 
1395  auto partialTile =
1396  tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0],
1397  tileOffsets, tileSizes, tileStrides);
1398 
1399  // 4. Insert the result to the destination tensor.
1400  SmallVector<OpFoldResult> writeSizes;
1401  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1402  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1403  for (int i = 0, idx = 0; i < destRank; ++i) {
1404  if (dimAndTileMapping.count(i) || destShape[i] != 1)
1405  writeSizes.push_back(tileSizes[idx++]);
1406  else
1407  writeSizes.push_back(oneIdxAttr);
1408  }
1409  auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile,
1410  unpackOp.getDest(), writeOffsets,
1411  writeSizes, writeStrides);
1412  rewriter.replaceOp(unpackOp, insert.getResult());
1413 
1414  return success();
1415 }
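// Illustration (hypothetical shapes and SSA names, not produced verbatim by
// the pattern): an unpack of tensor<1x5x8xf32> into tensor<7x5xf32> with
// inner_dims_pos = [0] and inner_tiles = [8] would decompose roughly into:
//
//   %tile = tensor.extract_slice %src[0, 0, 0] [1, 5, 8] [1, 1, 1]
//       : tensor<1x5x8xf32> to tensor<5x8xf32>
//   %init = tensor.empty() : tensor<8x5xf32>
//   %t = linalg.transpose ins(%tile : tensor<5x8xf32>)
//       outs(%init : tensor<8x5xf32>) permutation = [1, 0]
//   %partial = tensor.extract_slice %t[0, 0] [7, 5] [1, 1]
//       : tensor<8x5xf32> to tensor<7x5xf32>
//   %res = tensor.insert_slice %partial into %dest[0, 0] [7, 5] [1, 1]
//       : tensor<7x5xf32> into tensor<7x5xf32>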
1416 
1417 // The following are patterns for downscaling convolution ops with size-1
1418 // window dimensions.
1419 //
1420 // Note that we'd eventually want to write such transformations in a generic
1421 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
1422 // and then turning back to named ops. But for now it's fine to have a few
1423 // patterns matching special ops to get started.
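//
// Illustration (hypothetical shapes): a linalg.conv_2d_nhwc_hwcf whose kernel
// height and output height are both 1, e.g.
//   ins(tensor<1x1x16x4xf32>, tensor<1x3x4x8xf32>) outs(tensor<1x1x14x8xf32>),
// would be rank-reduced along the unit H dimension by the pattern below and
// rewritten roughly into:
//
//   %in  = tensor.extract_slice %input [...]  : tensor<1x1x16x4xf32> to tensor<1x16x4xf32>
//   %ker = tensor.extract_slice %kernel [...] : tensor<1x3x4x8xf32> to tensor<3x4x8xf32>
//   %out = tensor.extract_slice %init [...]   : tensor<1x1x14x8xf32> to tensor<1x14x8xf32>
//   %cv  = linalg.conv_1d_nwc_wcf
//            ins(%in, %ker : tensor<1x16x4xf32>, tensor<3x4x8xf32>)
//            outs(%out : tensor<1x14x8xf32>) -> tensor<1x14x8xf32>
//   %res = tensor.insert_slice %cv into %init [...]
//       : tensor<1x14x8xf32> into tensor<1x1x14x8xf32>
//
// (The rank-reduced strides/dilations attributes are omitted for brevity.)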
1424 
1425 template <typename Conv2DOp, typename Conv1DOp>
1426 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
1427     returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
1428  if (convOp.hasPureBufferSemantics())
1429  return failure(); // To be implemented.
1430 
1431  Value input = convOp.getInputs().front();
1432  Value kernel = convOp.getInputs().back();
1433  Value output = convOp.getOutputs().front();
1434 
1435  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1436  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1437  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1438 
1439  auto kernelShape = kernelType.getShape();
1440  auto outputShape = outputType.getShape();
1441 
1442  // Get domain indices based on conv2D layout.
1443  auto [khIndex, kwIndex, ohIndex, owIndex] =
1444      TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>(
1445          convOp)
1446  .Case([&](linalg::Conv2DNhwcHwcfOp op) {
1447  return std::make_tuple(0, 1, 1, 2);
1448  })
1449  .Case([&](linalg::Conv2DNchwFchwOp op) {
1450  return std::make_tuple(2, 3, 2, 3);
1451  })
1452  .Case([&](linalg::PoolingNhwcSumOp op) {
1453  return std::make_tuple(0, 1, 1, 2);
1454  })
1455  .Case([&](linalg::PoolingNchwSumOp op) {
1456  return std::make_tuple(0, 1, 2, 3);
1457  })
1458  .Case([&](linalg::PoolingNhwcMaxOp op) {
1459  return std::make_tuple(0, 1, 1, 2);
1460  })
1461  .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) {
1462  return std::make_tuple(0, 1, 1, 2);
1463  })
1464  .Case([&](linalg::PoolingNhwcMinOp op) {
1465  return std::make_tuple(0, 1, 1, 2);
1466  })
1467  .Case([&](linalg::PoolingNhwcMinUnsignedOp op) {
1468  return std::make_tuple(0, 1, 1, 2);
1469  })
1470  .Case([&](linalg::PoolingNchwMaxOp op) {
1471  return std::make_tuple(0, 1, 2, 3);
1472  })
1473  .DefaultUnreachable("unexpected conv2d/pool2d operation.");
1474 
1475  // Only handle the case where at least one of the window dimensions is
1476  // of size 1. Other cases can rely on tiling to reduce to such cases.
1477  int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex];
1478  int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex];
1479  bool removeH = (khSize == 1 && ohSize == 1);
1480  bool removeW = (kwSize == 1 && owSize == 1);
1481  if (!removeH && !removeW)
1482  return failure();
1483 
1484  // Get new shapes and types for all operands by removing the size-1
1485  // dimension.
1486  using RTTBuilder = RankedTensorType::Builder;
1487  RankedTensorType newInputType =
1488  RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex));
1489  RankedTensorType newKernelType =
1490  RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex));
1491  RankedTensorType newOutputType =
1492  RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex));
1493 
1494  // Rank-reduce operands.
1495  Location loc = convOp.getLoc();
1496  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1497      rewriter, loc, input, newInputType);
1498  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1499      rewriter, loc, kernel, newKernelType);
1500  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1501      rewriter, loc, output, newOutputType);
1502 
1503  // Rank-reduce strides and dilations too.
1504  // TODO: dropDim 1-liner helper.
1505  auto strides =
1506  llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>());
1507  strides.erase(strides.begin() + (removeH ? 0 : 1));
1508  auto stridesAttr = rewriter.getI64VectorAttr(strides);
1509 
1510  auto dilations =
1511  llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>());
1512  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1513  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1514 
1515  auto conv1DOp = Conv1DOp::create(
1516  rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
1517  ValueRange{newOutput}, stridesAttr, dilationsAttr);
1518 
1519  // Insert back.
1520  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1521      rewriter, loc, conv1DOp.getResult(0), output);
1522  rewriter.replaceOp(convOp, inserted);
1523 
1524  return conv1DOp;
1525 }
1526 
1527 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
1528  Conv1DNwcWcfOp>;
1529 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
1530  Conv1DNcwFcwOp>;
1531 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp,
1532  PoolingNwcSumOp>;
1533 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp,
1534  PoolingNcwSumOp>;
1535 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp,
1536  PoolingNwcMaxOp>;
1537 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1538     PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>;
1539 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp,
1540  PoolingNwcMinOp>;
1541 template struct linalg::DownscaleSizeOneWindowed2DConvolution<
1542     PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>;
1543 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp,
1544  PoolingNcwMaxOp>;
1545 
1546 FailureOr<DepthwiseConv1DNwcWcOp>
1547 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
1548     DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
1549  if (convOp.hasPureBufferSemantics())
1550  return failure(); // To be implemented.
1551 
1552  Value input = convOp.getInputs().front();
1553  Value kernel = convOp.getInputs().back();
1554  Value output = convOp.getOutputs().front();
1555 
1556  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1557  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1558  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1559 
1560  auto kernelShape = kernelType.getShape();
1561  auto outputShape = outputType.getShape();
1562 
1563  // Only handle the case where at least one of the window dimensions is
1564  // of size 1. Other cases can rely on tiling to reduce to such cases.
1565  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1566  int64_t ohSize = outputShape[1], owSize = outputShape[2];
1567  bool removeH = (khSize == 1 && ohSize == 1);
1568  bool removeW = (kwSize == 1 && owSize == 1);
1569  if (!removeH && !removeW)
1570  return failure();
1571 
1572  // Get new shapes and types for all operands by removing the size-1
1573  // dimension.
1574  using RTTBuilder = RankedTensorType::Builder;
1575  RankedTensorType newInputType =
1576  RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
1577  RankedTensorType newKernelType =
1578  RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1579  RankedTensorType newOutputType =
1580  RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
1581 
1582  // Rank-reduce operands.
1583  Location loc = convOp.getLoc();
1584  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1585      rewriter, loc, input, newInputType);
1586  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1587      rewriter, loc, kernel, newKernelType);
1588  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1589      rewriter, loc, output, newOutputType);
1590 
1591  // Rank-reduce strides and dilations too.
1592  // TODO: dropDim 1-liner helper.
1593  auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>());
1594  strides.erase(strides.begin() + (removeH ? 0 : 1));
1595  auto stridesAttr = rewriter.getI64VectorAttr(strides);
1596 
1597  auto dilations =
1598  llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>());
1599  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
1600  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
1601 
1602  auto conv1DOp = DepthwiseConv1DNwcWcOp::create(
1603  rewriter, loc, newOutputType, ValueRange{newInput, newKernel},
1604  ValueRange{newOutput}, stridesAttr, dilationsAttr);
1605 
1606  // Insert back.
1607  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1608      rewriter, loc, conv1DOp.getResult(0), output);
1609  rewriter.replaceOp(convOp, inserted);
1610 
1611  return conv1DOp;
1612 }
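// Illustration (hypothetical shapes): a linalg.depthwise_conv_2d_nhwc_hwc with
// kh == 1 and oh == 1, e.g.
//   ins(tensor<1x1x16x4xf32>, tensor<1x3x4xf32>) outs(tensor<1x1x14x4xf32>),
// would be rewritten by the code above into a linalg.depthwise_conv_1d_nwc_wc
// on rank-reduced operands,
//   ins(tensor<1x16x4xf32>, tensor<3x4xf32>) outs(tensor<1x14x4xf32>),
// wrapped in the same extract_slice / insert_slice pairs as the generic
// 2-D -> 1-D downscaling above.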
1613 
1614 FailureOr<Conv1DOp>
1615 DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp,
1616                                             PatternRewriter &rewriter) const {
1617  if (convOp.hasPureBufferSemantics())
1618  return failure(); // To be implemented.
1619 
1620  Value input = convOp.getInputs().front();
1621  Value kernel = convOp.getInputs().back();
1622  Value output = convOp.getOutputs().front();
1623 
1624  auto inputType = dyn_cast<RankedTensorType>(input.getType());
1625  auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1626  auto outputType = dyn_cast<RankedTensorType>(output.getType());
1627 
1628  auto kernelShape = kernelType.getShape();
1629  auto outputShape = outputType.getShape();
1630 
1631  // Only handle the case where at least one of the window dimensions is
1632  // of size 1. Other cases can rely on tiling to reduce to such cases.
1633  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
1634  int64_t ohSize = outputShape[0], owSize = outputShape[1];
1635  bool removeH = (khSize == 1 && ohSize == 1);
1636  bool removeW = (kwSize == 1 && owSize == 1);
1637  if (!removeH && !removeW)
1638  return failure();
1639 
1640  // Get new shapes and types for all operands by removing the size-1
1641  // dimension.
1642  using RTTBuilder = RankedTensorType::Builder;
1643  RankedTensorType newInputType =
1644  RTTBuilder(inputType).dropDim((removeH ? 0 : 1));
1645  RankedTensorType newKernelType =
1646  RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
1647  RankedTensorType newOutputType =
1648  RTTBuilder(outputType).dropDim(removeH ? 0 : 1);
1649 
1650  // Rank-reduce operands.
1651  Location loc = convOp.getLoc();
1652  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
1653      rewriter, loc, input, newInputType);
1654  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
1655      rewriter, loc, kernel, newKernelType);
1656  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
1657      rewriter, loc, output, newOutputType);
1658 
1659  auto conv1DOp =
1660  Conv1DOp::create(rewriter, loc, newOutputType,
1661  ValueRange{newInput, newKernel}, ValueRange{newOutput});
1662 
1663  // Insert back.
1664  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
1665      rewriter, loc, conv1DOp.getResult(0), output);
1666  rewriter.replaceOp(convOp, inserted);
1667 
1668  return conv1DOp;
1669 }
1670 
1671 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
1672                                                   PatternBenefit benefit) {
1673  patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
1674  Conv1DNwcWcfOp>,
1675  DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
1676  Conv1DNcwFcwOp>,
1677                DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>(
1678       patterns.getContext(), benefit);
1679  patterns.add<
1680       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>,
1681       DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>,
1682       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>,
1683       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp,
1684  PoolingNwcMaxUnsignedOp>,
1685       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>,
1686       DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp,
1687  PoolingNwcMinUnsignedOp>,
1688       DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>(
1689       patterns.getContext(), benefit);
1690 }
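// A minimal usage sketch (assumes the greedy pattern driver; `funcOp` stands
// for a hypothetical function-like op being rewritten):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   linalg::populateDecomposeConvolutionPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();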
1691 
1692 void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) {
1693   patterns.add<DecomposeOuterUnitDimsPackOpPattern,
1694                DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext());
1695 }
1696 
1697 void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) {
1698   patterns.add<DecomposePadOpPattern>(patterns.getContext());
1699 }
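// Illustration (hypothetical shapes): DecomposePadOpPattern turns a pad with a
// constant padding value such as
//   %p = tensor.pad %src low[0, 1] high[1, 0] { ... tensor.yield %cst : f32 }
//       : tensor<3x4xf32> to tensor<4x5xf32>
// into roughly:
//   %e = tensor.empty() : tensor<4x5xf32>
//   %f = linalg.fill ins(%cst : f32) outs(%e : tensor<4x5xf32>) -> tensor<4x5xf32>
//   %r = tensor.insert_slice %src into %f[0, 1] [3, 4] [1, 1]
//       : tensor<3x4xf32> into tensor<4x5xf32>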