MLIR  22.0.0git
PackAndUnpackPatterns.cpp
Go to the documentation of this file.
1 //===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/PatternMatch.h"
15 
16 namespace mlir {
17 namespace linalg {
18 namespace {
19 
20 /// Returns the number of shape sizes that is either dynamic or greater than 1.
21 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
22  return llvm::count_if(
23  shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
24 }
25 
26 /// Returns success() if there is only 1 dimension size in non-packed domain
27 /// being greater than 1 and packing only happens on the dimension.
28 /// Note: this method should only be used by pack/unpack to reshape conversion.
29 /// It assumes that non-unit inner tile size must be used by the non-unit
30 /// dimension.
31 static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
32  ArrayRef<int64_t> srcShape,
33  ArrayRef<int64_t> innerPackTileSize) {
34  if (getNumGtOneDims(srcShape) > 1) {
35  return rewriter.notifyMatchFailure(
36  op, "expects non-packed domain to have at most one non-unit dims");
37  }
38  // Non-unit inner tile size must be used by the non-unit dimension. If not, it
39  // will faill on getting reassociation maps.
40  if (getNumGtOneDims(innerPackTileSize) > 1) {
41  return rewriter.notifyMatchFailure(
42  op, "expects at most one non-unit inner tiles");
43  }
44  return success();
45 }
46 
47 // If the `linalgOp` represents a transpose, return the permutation vector for
48 // the transpose. Otherwise, return failure.
49 static FailureOr<SmallVector<int64_t>>
50 getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
51  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
52  return SmallVector<int64_t>(transposeOp.getPermutation());
53  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
54  return failure();
55 
56  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
57  return failure();
58  auto mapRange = linalgOp.getIndexingMapsArray();
59  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
60  mapRange.front() == mapRange.back()) {
61  return failure();
62  }
63  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
64  return failure();
65  AffineMap outMap = mapRange.back();
66  AffineMap inMap = mapRange.front();
67  // To get the permutation, look at each output index and find which
68  // dimension in the input we're reading from for that index.
69  return llvm::map_to_vector(outMap.getResults(),
70  [&](AffineExpr expr) -> int64_t {
71  return *inMap.getResultPosition(expr);
72  });
73 }
74 
75 /// Packing one-dimensional tensor can be expressed as an expand shape op.
76 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
78 
79  FailureOr<Value>
80  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
81  Type newOperandType,
82  ArrayRef<ReassociationIndices> reassociation) const {
83  if (operand.getType() == newOperandType)
84  return operand;
85  return tensor::ExpandShapeOp::create(rewriter, loc, newOperandType, operand,
86  reassociation)
87  .getResult();
88  }
89 
90  /// Returns success() if it is only packing on the innermost dimension.
91  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
92  PackOp packOp) const {
93  auto outerDimsPerm = packOp.getOuterDimsPerm();
95  return rewriter.notifyMatchFailure(
96  packOp,
97  "expects outer_dims_perm is empty or an identity permutation");
98  }
99 
100  int64_t srcRank = packOp.getSourceRank();
101  ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
102  if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
103  return rewriter.notifyMatchFailure(
104  packOp, "expects packing at the innermost dimension");
105  }
106  return success();
107  }
108 
109  LogicalResult matchAndRewrite(PackOp packOp,
110  PatternRewriter &rewriter) const override {
111  if (packOp.getPaddingValue())
112  return rewriter.notifyMatchFailure(packOp, "expects no padding value");
113 
114  RankedTensorType sourceType = packOp.getSourceType();
115  if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
116  failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
117  packOp.getStaticTiles())) &&
118  !packOp.isLikePad()) {
119  return failure();
120  }
121 
122  RankedTensorType destType = packOp.getDestType();
123  auto reassociation =
124  getReassociationIndicesForReshape(sourceType, destType);
125  if (!reassociation)
126  return failure();
127  FailureOr<Value> expanded =
128  insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
129  *reassociation);
130  if (failed(expanded)) {
131  return rewriter.notifyMatchFailure(
132  packOp, "unable to expand source of tensor.pack");
133  }
134  rewriter.replaceOp(packOp, *expanded);
135  return success();
136  }
137 };
138 
139 struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
141 
142  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
143  Type newOperandType, ArrayAttr reassociation) const {
144  if (operand.getType() == newOperandType)
145  return operand;
146  return tensor::CollapseShapeOp::create(rewriter, loc, newOperandType,
147  operand, reassociation);
148  }
149 
150  /// Returns success() if it is unpacking on the innermost dimension.
151  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
152  UnPackOp unpackOp) const {
153  auto outerDimsPerm = unpackOp.getOuterDimsPerm();
155  return rewriter.notifyMatchFailure(
156  unpackOp,
157  "expects outer_dims_perm is empty or an identity permutation");
158  }
159 
160  RankedTensorType sourceType = unpackOp.getSourceType();
161  RankedTensorType destType = unpackOp.getDestType();
162  if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
163  return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
164 
165  ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
166  if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
167  return rewriter.notifyMatchFailure(
168  unpackOp, "expects unpacking on the innermost dimension");
169  }
170 
171  return success();
172  }
173 
174  LogicalResult matchAndRewrite(UnPackOp unpackOp,
175  PatternRewriter &rewriter) const override {
176  RankedTensorType destType = unpackOp.getDestType();
177  if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
178  failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
179  unpackOp.getStaticTiles())) &&
180  !unpackOp.isLikeUnPad()) {
181  return failure();
182  }
183 
184  RankedTensorType sourceType = unpackOp.getSourceType();
185  auto reassociation =
186  getReassociationIndicesForReshape(sourceType, destType);
187  if (!reassociation)
188  return failure();
189  Value collapsed = insertCollapse(
190  rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
191  getReassociationIndicesAttribute(rewriter, *reassociation));
192  rewriter.replaceOp(unpackOp, collapsed);
193  return success();
194  }
195 };
196 
197 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
198 /// the pad op has zero low paddings, or if `pack` has no padding values.
199 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
200 public:
201  FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
202  : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
203 
204  LogicalResult matchAndRewrite(PackOp packOp,
205  PatternRewriter &rewriter) const override {
206  auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
207 
208  if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
209  return failure();
210 
211  // User controlled folding function.
212  if (controlFn && !controlFn(&packOp.getSourceMutable()))
213  return failure();
214 
215  Value constantPaddingValue = padOp.getConstantPaddingValue();
216  if (!constantPaddingValue)
217  return failure();
218 
219  if (auto paddingValue = packOp.getPaddingValue())
220  if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
221  return failure();
222 
223  // Folding is not allowed if it were to introduce artificial padding.
224  // Folding is also disabled in the case of dynamic dimensions and/or tile
225  // sizes - that is because it would be impossible to compute the padding
226  // size and hence to establish whether "artificial" padding would be
227  // created.
228  RankedTensorType unpackedType = packOp.getSourceType();
229  SmallVector<int64_t> outerShapeWithoutTranspose =
231  for (auto [pos, tileSize, high] :
232  llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
233  padOp.getMixedHighPad())) {
234  if (unpackedType.isDynamicDim(pos))
235  return failure();
236  if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
237  return failure();
238  if (ShapedType::isDynamic(tileSize))
239  return failure();
240  std::optional<int64_t> cstHigh = getConstantIntValue(high);
241  if (!cstHigh)
242  return failure();
243  int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
244  unpackedType.getDimSize(pos);
245  // Do not fold the op if it requires artificial padding.
246  if (paddingSize + cstHigh.value() >= tileSize)
247  return failure();
248  }
249 
250  rewriter.replaceOpWithNewOp<PackOp>(
251  packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
252  packOp.getMixedTiles(), constantPaddingValue,
253  packOp.getOuterDimsPerm());
254  return success();
255  }
256 
257 private:
258  ControlFoldIntoPackUnpackFn controlFn;
259 };
260 
261 /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
262 /// has extract_slice semantics.
263 struct FoldUnpackWithExtractSliceOp
264  : public OpRewritePattern<tensor::ExtractSliceOp> {
265 public:
266  FoldUnpackWithExtractSliceOp(MLIRContext *context,
267  ControlFoldIntoPackUnpackFn controlFn)
268  : OpRewritePattern<tensor::ExtractSliceOp>(context),
269  controlFn(std::move(controlFn)) {}
270 
271  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
272  PatternRewriter &rewriter) const override {
273  auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
274  if (!unpackOp)
275  return failure();
276 
277  // User controlled folding function.
278  if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
279  return failure();
280 
281  if (!unpackOp.canFoldSliceOp(sliceOp))
282  return failure();
283 
284  // Create a new empty output tensor.
285  Type elementType = unpackOp.getDestType().getElementType();
286  Value output = tensor::EmptyOp::create(
287  rewriter, sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
288  rewriter.replaceOpWithNewOp<UnPackOp>(
289  sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
290  unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
291  return success();
292  }
293 
294 private:
295  ControlFoldIntoPackUnpackFn controlFn;
296 };
297 
298 // Applies 'permutation' on 'inVec' and stores the result in resVec.
299 // 'inVec' may be empty, in that case it's one-to-one mapping with permutation.
300 // `rank` sets the boundary for permutation i.e., the permutation dim can't be
301 // greater than the rank specified. If it's so then return false.
302 // For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
303 // permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is
304 // not allowed since `3` exceeds the value of the rank in the given range.
305 static bool checkAndPermute(ArrayRef<int64_t> permutation,
306  ArrayRef<int64_t> inVec,
307  SmallVectorImpl<int64_t> &resVec, int64_t rank) {
308 
309  for (unsigned int i = 0; i < rank; ++i) {
310  int64_t remappedPosition = permutation[i];
311  if (remappedPosition >= rank)
312  return false;
313  if (!inVec.empty())
314  remappedPosition = inVec[remappedPosition];
315  resVec.push_back(remappedPosition);
316  }
317 
318  return true;
319 }
320 
321 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
322 /// semantics.
323 struct FoldProducerPackWithConsumerLinalgTransposeOp
324  : public OpInterfaceRewritePattern<linalg::LinalgOp> {
325 
326 public:
327  FoldProducerPackWithConsumerLinalgTransposeOp(
328  MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
329  : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
330  controlFn(std::move(controlFn)) {}
331 
332  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
333  PatternRewriter &rewriter) const override {
334  auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
335 
336  if (!packOp)
337  return failure();
338 
339  // User controlled folding function.
340  if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
341  return failure();
342 
343  FailureOr<SmallVector<int64_t>> maybePerm =
344  getTransposeOpPermutation(linalgOp);
345  if (failed(maybePerm))
346  return failure();
347 
348  auto innerDimsPos = packOp.getInnerDimsPos();
349  auto mixedInnerTiles = packOp.getMixedTiles();
350  auto outerDimsPerm = packOp.getOuterDimsPerm();
351  auto transposePerm = maybePerm.value();
352  SmallVector<int64_t> newOuterDimsPermVec;
353  SmallVector<int64_t> newInnerDimsPosVec;
354  SmallVector<OpFoldResult> newMixedInnerTilesVec;
355  int64_t srcRank = packOp.getSourceRank();
356 
357  if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
358  srcRank))
359  return rewriter.notifyMatchFailure(
360  linalgOp,
361  "Cannot fold in tensor.pack if a tile dimension was transposed "
362  "with a non-tile dimension in linalg.transpose.");
363 
364  // Process transpose operation for tiled inner dimensions
365  for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
366  int64_t remappedPosition = transposePerm[i] - srcRank;
367  newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
368  newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
369  }
370 
371  Value output = packOp.createDestinationTensor(
372  rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
373  newInnerDimsPosVec, newOuterDimsPermVec);
374 
375  rewriter.replaceOpWithNewOp<PackOp>(
376  linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
377  newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
378 
379  return success();
380  }
381 
382 private:
383  ControlFoldIntoPackUnpackFn controlFn;
384 };
385 
386 /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
387 /// semantics.
388 struct FoldConsumerPackWithProducerLinalgTransposeOp
389  : public OpRewritePattern<PackOp> {
390 
391 public:
392  FoldConsumerPackWithProducerLinalgTransposeOp(
393  MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
394  : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
395 
396  LogicalResult matchAndRewrite(PackOp packOp,
397  PatternRewriter &rewriter) const override {
398  auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
399  if (!linalgOp)
400  return failure();
401 
402  // User controlled folding function.
403  if (controlFn && !controlFn(&packOp.getSourceMutable()))
404  return failure();
405 
406  FailureOr<SmallVector<int64_t>> maybePerm =
407  getTransposeOpPermutation(linalgOp);
408  if (failed(maybePerm))
409  return failure();
410 
411  auto transposePermutation = maybePerm.value();
412  auto outerDimsPerm = packOp.getOuterDimsPerm();
413  auto innerDimsPos = packOp.getInnerDimsPos();
414  SmallVector<int64_t> newInnerDimsPosVec;
415  SmallVector<int64_t> newOuterDimsPermVec =
416  llvm::to_vector(transposePermutation);
417 
418  if (!outerDimsPerm.empty())
419  applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
420 
421  // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
422  // permutation rank won't necessarily be equal in all cases.
423  for (auto dim : innerDimsPos)
424  newInnerDimsPosVec.push_back(transposePermutation[dim]);
425 
426  Value output = packOp.createDestinationTensor(
427  rewriter, packOp.getLoc(), linalgOp->getOperand(0),
428  packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
429 
430  rewriter.replaceOpWithNewOp<PackOp>(
431  packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
432  packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
433 
434  return success();
435  }
436 
437 private:
438  ControlFoldIntoPackUnpackFn controlFn;
439 };
440 
441 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
442 /// transpose semantics.
443 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
444  : public OpInterfaceRewritePattern<linalg::LinalgOp> {
445 
446 public:
447  FoldProducerUnPackWithConsumerLinalgTransposeOp(
448  MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
449  : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
450  controlFn(std::move(controlFn)) {}
451 
452  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
453  PatternRewriter &rewriter) const override {
454  auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
455 
456  if (!unPackOp)
457  return failure();
458 
459  // User controlled folding function.
460  if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
461  return failure();
462 
463  FailureOr<SmallVector<int64_t>> maybePerm =
464  getTransposeOpPermutation(linalgOp);
465  if (failed(maybePerm))
466  return failure();
467 
468  auto outerDimsPerm = unPackOp.getOuterDimsPerm();
469  auto innerDimsPos = unPackOp.getInnerDimsPos();
470  SmallVector<int64_t> newInnerDimsPosVec;
471  SmallVector<int64_t> newOuterDimsPermVec =
472  invertPermutationVector(maybePerm.value());
473 
474  // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
475  // permutation rank won't necessarily be equal in all cases.
476  for (auto dim : innerDimsPos)
477  newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
478 
479  if (!outerDimsPerm.empty())
480  applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
481 
482  // Reuse the destination of the transpose op.
483  rewriter.replaceOpWithNewOp<UnPackOp>(
484  linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
485  newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
486 
487  return success();
488  }
489 
490 private:
491  ControlFoldIntoPackUnpackFn controlFn;
492 };
493 
494 /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
495 /// transpose semantics.
496 struct FoldConsumerUnPackWithProducerLinalgTransposeOp
497  : public OpRewritePattern<UnPackOp> {
499 
500 public:
501  FoldConsumerUnPackWithProducerLinalgTransposeOp(
502  MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
503  : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
504 
505  LogicalResult matchAndRewrite(UnPackOp unPackOp,
506  PatternRewriter &rewriter) const override {
507  auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
508  if (!linalgOp)
509  return failure();
510 
511  // User controlled folding function.
512  if (controlFn && !controlFn(&unPackOp.getSourceMutable()))
513  return failure();
514 
515  FailureOr<SmallVector<int64_t>> maybePerm =
516  getTransposeOpPermutation(linalgOp);
517  if (failed(maybePerm))
518  return failure();
519 
520  SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
521  if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
522  return failure();
523  }
524 
525  SmallVector<int64_t> inverseTransposePerm =
526  invertPermutationVector(maybePerm.value());
527  auto outerDimsPerm = unPackOp.getOuterDimsPerm();
528  auto innerDimsPos = unPackOp.getInnerDimsPos();
529  int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
530  auto mixedInnerTilesVec = unPackOp.getMixedTiles();
531  SmallVector<int64_t> newOuterDimsPermVec;
532  SmallVector<int64_t> newInnerDimsPosVec;
533  SmallVector<OpFoldResult> newMixedInnerTilesVec;
534  if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
535  newOuterDimsPermVec, destRank))
536  return rewriter.notifyMatchFailure(
537  unPackOp,
538  "Cannot fold in tensor.unpack if a tile dimension was transposed "
539  "with a non-tile dimension in linalg.transpose.");
540 
541  // Process transpose operation for tiled inner dimensions
542  for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
543  int64_t remappedPosition = inverseTransposePerm[i] - destRank;
544  newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
545  newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
546  }
547 
548  auto elemType =
549  cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
550  Value output = tensor::EmptyOp::create(rewriter, unPackOp->getLoc(),
551  unpackOpResultDims[0], elemType);
552 
553  rewriter.replaceOpWithNewOp<UnPackOp>(
554  unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
555  newMixedInnerTilesVec, newOuterDimsPermVec);
556 
557  return success();
558  }
559 
560 private:
561  ControlFoldIntoPackUnpackFn controlFn;
562 };
563 
564 /// tensor.empty does not define any tensor contents, so an unpadded pack
565 /// can be folded away.
566 struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
568 
569  LogicalResult matchAndRewrite(PackOp packOp,
570  PatternRewriter &rewriter) const override {
571  // Check for tensor.empty source.
572  auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
573  if (!emptyOp)
574  return failure();
575 
576  // Check for padding.
577  // Packing with padding cannot be simply removed.
578  if (packOp.getPaddingValue())
579  return rewriter.notifyMatchFailure(packOp, "expects no padding value");
580 
581  // Replace the pack directly with its destination.
582  rewriter.replaceOp(packOp, packOp.getDest());
583 
584  return success();
585  }
586 };
587 
588 /// tensor.empty does not define any tensor contents, so an unpack
589 /// can be folded away.
590 struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
592 
593  LogicalResult matchAndRewrite(UnPackOp unPackOp,
594  PatternRewriter &rewriter) const override {
595  // Check for tensor.empty source.
596  auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
597  if (!emptyOp)
598  return failure();
599 
600  // Replace the unpack directly with its destination.
601  rewriter.replaceOp(unPackOp, unPackOp.getDest());
602 
603  return success();
604  }
605 };
606 
607 } // namespace
608 
611  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
612  FoldProducerPackWithConsumerLinalgTransposeOp,
613  FoldConsumerPackWithProducerLinalgTransposeOp,
614  FoldConsumerUnPackWithProducerLinalgTransposeOp,
615  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
616  patterns.getContext(), controlFn);
617 }
618 
620  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
621  patterns.getContext());
622 }
623 
626  patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
627  patterns.getContext());
628 }
629 
630 } // namespace linalg
631 } // namespace mlir
SmallVector< int64_t > outerDimsPerm
Definition: LinalgOps.cpp:5181
SmallVector< int64_t > innerDimsPos
Definition: LinalgOps.cpp:5179
std::function< bool(OpOperand *opOperand)> ControlFoldIntoPackUnpackFn
Function type which is used to control folding operations like tensor.pad and tensor....
Definition: Transforms.h:2027
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that simplify tensor.pack and tensor.unpack operations.
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
Definition: LinalgOps.cpp:4959
@ Type
An inlay hint that for a type annotation.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
const FrozenRewritePatternSet & patterns
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319