//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace tensor {
namespace {

/// Returns the number of shape sizes that are either dynamic or greater than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
  return llvm::count_if(
      shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}

/// Returns success() if at most one dimension size in the non-packed domain
/// is greater than 1 and packing happens only on that dimension.
/// Note: this method should only be used by pack/unpack to reshape conversion.
/// It assumes that a non-unit inner tile size must be used by the non-unit
/// dimension.
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  if (getNumGtOneDims(srcShape) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dim");
  }
  // A non-unit inner tile size must be used by the non-unit dimension;
  // otherwise, computing the reassociation maps will fail.
  if (getNumGtOneDims(innerPackTileSize) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tile");
  }
  return success();
}

// If the `linalgOp` represents a transpose, return the permutation vector for
// the transpose. Otherwise, return failure.
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();

  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // To get the permutation, look at each output index and find which
  // dimension in the input we're reading from for that index.
  return llvm::map_to_vector(outMap.getResults(),
                             [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(expr);
                             });
}

/// Packing a one-dimensional tensor can be expressed as an expand_shape op.
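/// For example (illustrative IR; shapes and the exact reshape syntax are
/// abbreviated):
///   %packed = tensor.pack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %dst : tensor<64xf32> -> tensor<8x8xf32>
/// can be rewritten as:
///   %expanded = tensor.expand_shape %src [[0, 1]]
///       : tensor<64xf32> into tensor<8x8xf32>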
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  FailureOr<Value>
  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
               Type newOperandType,
               ArrayRef<ReassociationIndices> reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter
        .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                       reassociation)
        .getResult();
  }

  /// Returns success() if it is only packing on the innermost dimension.
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    int64_t srcRank = packOp.getSourceRank();
    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    RankedTensorType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

    RankedTensorType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    FailureOr<Value> expanded =
        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
                     *reassociation);
    if (failed(expanded)) {
      return rewriter.notifyMatchFailure(
          packOp, "unable to expand source of tensor.pack");
    }
    rewriter.replaceOp(packOp, *expanded);
    return success();
  }
};

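/// Unpacking on the innermost dimension (or a 1-D shape) can likewise be
/// expressed as a collapse_shape op. For example (illustrative IR):
///   %unpacked = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [8]
///       into %dst : tensor<8x8xf32> -> tensor<64xf32>
/// can be rewritten as:
///   %collapsed = tensor.collapse_shape %src [[0, 1]]
///       : tensor<8x8xf32> into tensor<64xf32>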
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
                                                    operand, reassociation);
  }

  /// Returns success() if it is unpacking on the innermost dimension.
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    RankedTensorType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }

    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};

/// Fold a `pad` -> `pack` into `pack` if they have the same padding value and
/// the pad op has zero low padding, or if `pack` has no padding value.
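/// For example (illustrative IR; types and the pad body are abbreviated):
///   %padded = tensor.pad %src low[0, 0] high[0, 2] { tensor.yield %cst }
///   %packed = tensor.pack %padded padding_value(%cst : f32) ...
/// folds to:
///   %packed = tensor.pack %src padding_value(%cst : f32) ...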
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<PadOp>();

    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
      return failure();

    Value constantPaddingValue = padOp.getConstantPaddingValue();
    if (!constantPaddingValue)
      return failure();

    if (auto paddingValue = packOp.getPaddingValue())
      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
        return failure();

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
        packOp.getMixedTiles(), constantPaddingValue,
        packOp.getOuterDimsPerm());
    return success();
  }
};

/// Fold an `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
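/// For example (illustrative IR; the unpack attributes are abbreviated):
///   %unpacked = tensor.unpack %src ... into %dst : ... -> tensor<128xf32>
///   %slice = tensor.extract_slice %unpacked[0] [100] [1]
///       : tensor<128xf32> to tensor<100xf32>
/// folds to an unpack into a smaller tensor.empty destination:
///   %empty = tensor.empty() : tensor<100xf32>
///   %slice = tensor.unpack %src ... into %empty : ... -> tensor<100xf32>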
struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();

    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
      return rewriter.notifyMatchFailure(
          sliceOp, "rank-reduced folding is not supported");
    }

    // Check that all offsets are zero and all strides are one.
    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
      return rewriter.notifyMatchFailure(
          sliceOp, "expects offsets to be 0s and strides to be 1s");
    }

    // Create a new empty output tensor.
    Type elementType = unpackOp.getDestType().getElementType();
    Value output = rewriter.create<EmptyOp>(
        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
    return success();
  }
};

// Applies `permutation` to `inVec` and stores the result in `resVec`.
// `inVec` may be empty; in that case the mapping is one-to-one with the
// permutation. `rank` sets the boundary for the permutation, i.e., a permuted
// position must not be greater than or equal to `rank`; if it is, the function
// returns false. For example, permutation {1, 0, 3, 2} with rank 2 is allowed
// since the values in permutation[:rank] do not exceed rank, whereas
// permutation {1, 3, 0, 2} is not allowed since `3` exceeds rank within that
// range.
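// A worked example (values are illustrative): permutation {1, 0, 3, 2} with
// inVec {5, 7} and rank 2 yields resVec = {7, 5}, since positions 0 and 1 are
// swapped and then remapped through inVec.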
static bool checkAndPermute(ArrayRef<int64_t> permutation,
                            ArrayRef<int64_t> inVec,
                            SmallVectorImpl<int64_t> &resVec, int64_t rank) {

  for (unsigned int i = 0; i < rank; ++i) {
    int64_t remappedPosition = permutation[i];
    if (remappedPosition >= rank)
      return false;
    if (!inVec.empty())
      remappedPosition = inVec[remappedPosition];
    resVec.push_back(remappedPosition);
  }

  return true;
}

/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
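/// For example (illustrative shapes): a pack of tensor<128x256xf32> with
/// inner_dims_pos = [0, 1], inner_tiles = [8, 32] producing
/// tensor<16x8x8x32xf32>, followed by a linalg.transpose with
/// permutation = [1, 0, 3, 2], can fold into a single pack with
/// outer_dims_perm = [1, 0], inner_dims_pos = [1, 0], inner_tiles = [32, 8]
/// producing tensor<8x16x32x8xf32>.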
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();

    if (!packOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto transposePerm = maybePerm.value();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    int64_t srcRank = packOp.getSourceRank();

    if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                         srcRank))
      return rewriter.notifyMatchFailure(
          linalgOp,
          "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process the transposition of the tiled inner dimensions.
    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
      int64_t remappedPosition = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    Value output = packOp.createDestinationTensor(
        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
        newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
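/// For example (illustrative shapes): a linalg.transpose of tensor<256x128xf32>
/// with permutation = [1, 0], followed by a pack with inner_dims_pos = [0],
/// inner_tiles = [8], can fold into a single pack of the original source with
/// outer_dims_perm = [1, 0], inner_dims_pos = [1], inner_tiles = [8].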
struct FoldConsumerPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto transposePermutation = maybePerm.value();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto innerDimsPos = packOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        llvm::to_vector(transposePermutation);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Can't use applyPermutationToVector for newInnerDimsPosVec since the
    // input and permutation ranks won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(transposePermutation[dim]);

    Value output = packOp.createDestinationTensor(
        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
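/// For example (illustrative shapes): an unpack of tensor<16x256x8xf32> with
/// inner_dims_pos = [0], inner_tiles = [8] into tensor<128x256xf32>, followed
/// by a linalg.transpose with permutation = [1, 0], can fold into a single
/// unpack with inner_dims_pos = [1], outer_dims_perm = [1, 0] directly into
/// the transpose's destination tensor<256x128xf32>.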
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();

    if (!unPackOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        invertPermutationVector(maybePerm.value());

    // Can't use applyPermutationToVector for newInnerDimsPosVec since the
    // input and permutation ranks won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Reuse the destination of the transpose op.
    rewriter.replaceOpWithNewOp<UnPackOp>(
        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
        newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);

    return success();
  }
};

/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
/// transpose semantics.
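/// For example (illustrative shapes): a linalg.transpose of tensor<16x256x8xf32>
/// with permutation = [1, 0, 2], followed by an unpack with
/// inner_dims_pos = [1], inner_tiles = [8] into tensor<256x128xf32>, can fold
/// into a single unpack of the original source with inner_dims_pos = [1],
/// outer_dims_perm = [1, 0] into a fresh tensor.empty of the same result shape.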
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
      return failure();
    }

    SmallVector<int64_t> inverseTransposePerm =
        invertPermutationVector(maybePerm.value());
    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
    auto mixedInnerTilesVec = unPackOp.getMixedTiles();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                         newOuterDimsPermVec, destRank))
      return rewriter.notifyMatchFailure(
          unPackOp,
          "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process the transposition of the tiled inner dimensions.
    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
      newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    auto elemType =
        cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
    Value output = rewriter.create<tensor::EmptyOp>(
        unPackOp->getLoc(), unpackOpResultDims[0], elemType);

    rewriter.replaceOpWithNewOp<UnPackOp>(
        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, newOuterDimsPermVec);

    return success();
  }
};
} // namespace

void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                  FoldProducerPackWithConsumerLinalgTransposeOp,
                  FoldConsumerPackWithProducerLinalgTransposeOp,
                  FoldConsumerUnPackWithProducerLinalgTransposeOp,
                  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
      patterns.getContext());
}

void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
      patterns.getContext());
}
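
// Example usage (illustrative sketch): these entry points are typically wired
// into a pass that runs the greedy pattern rewriter, e.g.:
//   RewritePatternSet patterns(op->getContext());
//   tensor::populateSimplifyPackAndUnpackPatterns(patterns);
//   tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));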

} // namespace tensor
} // namespace mlir