MLIR  20.0.0git
PackAndUnpackPatterns.cpp
Go to the documentation of this file.
1 //===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/PatternMatch.h"
14 
15 namespace mlir {
16 namespace tensor {
17 namespace {
18 
19 static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
20  return llvm::all_of(
21  ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
22 }
23 
24 /// Returns the number of shape sizes that is either dynamic or greater than 1.
25 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
26  return llvm::count_if(
27  shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
28 }
29 
30 /// Returns success() if there is only 1 dimension size in non-packed domain
31 /// being greater than 1 and packing only happens on the dimension.
32 /// Note: this method should only be used by pack/unpack to reshape conversion.
33 /// It assumes that non-unit inner tile size must be used by the non-unit
34 /// dimension.
35 static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
36  ArrayRef<int64_t> srcShape,
37  ArrayRef<int64_t> innerPackTileSize) {
38  if (getNumGtOneDims(srcShape) > 1) {
39  return rewriter.notifyMatchFailure(
40  op, "expects non-packed domain to have at most one non-unit dims");
41  }
42  // Non-unit inner tile size must be used by the non-unit dimension. If not, it
43  // will faill on getting reassociation maps.
44  if (getNumGtOneDims(innerPackTileSize) > 1) {
45  return rewriter.notifyMatchFailure(
46  op, "expects at most one non-unit inner tiles");
47  }
48  return success();
49 }
50 
51 // If the `linalgOp` represents a transpose, return the permutation vector for
52 // the transpose. Otherwise, return failure.
53 static FailureOr<SmallVector<int64_t>>
54 getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
55  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
56  return SmallVector<int64_t>(transposeOp.getPermutation());
57  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
58  return failure();
59 
60  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
61  return failure();
62  auto mapRange = linalgOp.getIndexingMapsArray();
63  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
64  mapRange.front() == mapRange.back()) {
65  return failure();
66  }
67  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
68  return failure();
69  AffineMap outMap = mapRange.back();
70  AffineMap inMap = mapRange.front();
71  // To get the permutation, look at each output index and find which
72  // dimension in the input we're reading from for that index.
73  return llvm::map_to_vector(outMap.getResults(),
74  [&](AffineExpr expr) -> int64_t {
75  return *inMap.getResultPosition(expr);
76  });
77 }
78 
79 /// Packing one-dimensional tensor can be expressed as an expand shape op.
80 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
82 
83  FailureOr<Value>
84  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
85  Type newOperandType,
86  ArrayRef<ReassociationIndices> reassociation) const {
87  if (operand.getType() == newOperandType)
88  return operand;
89  return rewriter
90  .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
91  reassociation)
92  .getResult();
93  }
94 
95  /// Returns success() if it is only packing on the innermost dimension.
96  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
97  PackOp packOp) const {
98  auto outerDimsPerm = packOp.getOuterDimsPerm();
99  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
100  return rewriter.notifyMatchFailure(
101  packOp,
102  "expects outer_dims_perm is empty or an identity permutation");
103  }
104 
105  int64_t srcRank = packOp.getSourceRank();
106  ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
107  if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
108  return rewriter.notifyMatchFailure(
109  packOp, "expects packing at the innermost dimension");
110  }
111  return success();
112  }
113 
114  LogicalResult matchAndRewrite(PackOp packOp,
115  PatternRewriter &rewriter) const override {
116  if (packOp.getPaddingValue())
117  return rewriter.notifyMatchFailure(packOp, "expects no padding value");
118 
119  RankedTensorType sourceType = packOp.getSourceType();
120  if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
121  failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
122  packOp.getStaticTiles())) &&
123  !packOp.isLikePad()) {
124  return failure();
125  }
126 
127  RankedTensorType destType = packOp.getDestType();
128  auto reassociation =
129  getReassociationIndicesForReshape(sourceType, destType);
130  if (!reassociation)
131  return failure();
132  FailureOr<Value> expanded =
133  insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
134  *reassociation);
135  if (failed(expanded)) {
136  return rewriter.notifyMatchFailure(
137  packOp, "unable to expand source of tensor.pack");
138  }
139  rewriter.replaceOp(packOp, *expanded);
140  return success();
141  }
142 };
143 
144 struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
146 
147  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
148  Type newOperandType, ArrayAttr reassociation) const {
149  if (operand.getType() == newOperandType)
150  return operand;
151  return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
152  operand, reassociation);
153  }
154 
155  /// Returns success() if it is unpacking on the innermost dimension.
156  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
157  UnPackOp unpackOp) const {
158  auto outerDimsPerm = unpackOp.getOuterDimsPerm();
159  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
160  return rewriter.notifyMatchFailure(
161  unpackOp,
162  "expects outer_dims_perm is empty or an identity permutation");
163  }
164 
165  RankedTensorType sourceType = unpackOp.getSourceType();
166  RankedTensorType destType = unpackOp.getDestType();
167  if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
168  return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
169 
170  ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
171  if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
172  return rewriter.notifyMatchFailure(
173  unpackOp, "expects unpacking on the innermost dimension");
174  }
175 
176  return success();
177  }
178 
179  LogicalResult matchAndRewrite(UnPackOp unpackOp,
180  PatternRewriter &rewriter) const override {
181  RankedTensorType destType = unpackOp.getDestType();
182  if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
183  failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
184  unpackOp.getStaticTiles())) &&
185  !unpackOp.isLikeUnPad()) {
186  return failure();
187  }
188 
189  RankedTensorType sourceType = unpackOp.getSourceType();
190  auto reassociation =
191  getReassociationIndicesForReshape(sourceType, destType);
192  if (!reassociation)
193  return failure();
194  Value collapsed = insertCollapse(
195  rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
196  getReassociationIndicesAttribute(rewriter, *reassociation));
197  rewriter.replaceOp(unpackOp, collapsed);
198  return success();
199  }
200 };
201 
202 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
203 /// the pad op has zero low paddings, or if `pack` has no padding values.
204 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
206 
207  LogicalResult matchAndRewrite(PackOp packOp,
208  PatternRewriter &rewriter) const override {
209  auto padOp = packOp.getSource().getDefiningOp<PadOp>();
210 
211  if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
212  return failure();
213 
214  Value constantPaddingValue = padOp.getConstantPaddingValue();
215  if (!constantPaddingValue)
216  return failure();
217 
218  if (auto paddingValue = packOp.getPaddingValue())
219  if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
220  return failure();
221 
222  rewriter.replaceOpWithNewOp<PackOp>(
223  packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
224  packOp.getMixedTiles(), constantPaddingValue,
225  packOp.getOuterDimsPerm());
226  return success();
227  }
228 };
229 
230 /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
231 /// has extract_slice semantics.
232 struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
234 
235  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
236  PatternRewriter &rewriter) const override {
237  auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
238  if (!unpackOp)
239  return failure();
240 
241  if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
242  return rewriter.notifyMatchFailure(
243  sliceOp, "rank-reduced folding is not supported");
244  }
245 
246  // Check all offsets are zeros, and all strides are ones.
247  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
248  !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
249  return rewriter.notifyMatchFailure(
250  sliceOp, "expects offsets to be 0s and strides to be 1s");
251  }
252 
253  // Create a new empty output tensor.
254  Type elementType = unpackOp.getDestType().getElementType();
255  Value output = rewriter.create<EmptyOp>(
256  sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
257  rewriter.replaceOpWithNewOp<UnPackOp>(
258  sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
259  unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
260  return success();
261  }
262 };
263 
264 // Applies 'permutation' on 'inVec' and stores the result in resVec.
265 // 'inVec' may be empty, in that case it's one-to-one mapping with permutation.
266 // `rank` sets the boundary for permutation i.e., the permutation dim can't be
267 // greater than the rank specified. If it's so then return false.
268 // For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
269 // permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is
270 // not allowed since `3` exceeds the value of the rank in the given range.
271 static bool checkAndPermute(ArrayRef<int64_t> permutation,
272  ArrayRef<int64_t> inVec,
273  SmallVectorImpl<int64_t> &resVec, int64_t rank) {
274 
275  for (unsigned int i = 0; i < rank; ++i) {
276  int64_t remappedPosition = permutation[i];
277  if (remappedPosition >= rank)
278  return false;
279  if (!inVec.empty())
280  remappedPosition = inVec[remappedPosition];
281  resVec.push_back(remappedPosition);
282  }
283 
284  return true;
285 }
286 
287 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
288 /// semantics.
289 struct FoldProducerPackWithConsumerLinalgTransposeOp
290  : public OpInterfaceRewritePattern<linalg::LinalgOp> {
292 
293  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
294  PatternRewriter &rewriter) const override {
295  auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
296 
297  if (!packOp)
298  return failure();
299 
300  FailureOr<SmallVector<int64_t>> maybePerm =
301  getTransposeOpPermutation(linalgOp);
302  if (failed(maybePerm))
303  return failure();
304 
305  auto innerDimsPos = packOp.getInnerDimsPos();
306  auto mixedInnerTiles = packOp.getMixedTiles();
307  auto outerDimsPerm = packOp.getOuterDimsPerm();
308  auto transposePerm = maybePerm.value();
309  SmallVector<int64_t> newOuterDimsPermVec;
310  SmallVector<int64_t> newInnerDimsPosVec;
311  SmallVector<OpFoldResult> newMixedInnerTilesVec;
312  int64_t srcRank = packOp.getSourceRank();
313 
314  if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
315  srcRank))
316  return rewriter.notifyMatchFailure(
317  linalgOp,
318  "Cannot fold in tensor.pack if a tile dimension was transposed "
319  "with a non-tile dimension in linalg.transpose.");
320 
321  // Process transpose operation for tiled inner dimensions
322  for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
323  int64_t remappedPosition = transposePerm[i] - srcRank;
324  newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
325  newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
326  }
327 
328  Value output = packOp.createDestinationTensor(
329  rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
330  newInnerDimsPosVec, newOuterDimsPermVec);
331 
332  rewriter.replaceOpWithNewOp<PackOp>(
333  linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
334  newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
335 
336  return success();
337  }
338 };
339 
340 /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
341 /// semantics.
342 struct FoldConsumerPackWithProducerLinalgTransposeOp
343  : public OpRewritePattern<PackOp> {
345 
346  LogicalResult matchAndRewrite(PackOp packOp,
347  PatternRewriter &rewriter) const override {
348  auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
349  if (!linalgOp)
350  return failure();
351 
352  FailureOr<SmallVector<int64_t>> maybePerm =
353  getTransposeOpPermutation(linalgOp);
354  if (failed(maybePerm))
355  return failure();
356 
357  auto transposePermutation = maybePerm.value();
358  auto outerDimsPerm = packOp.getOuterDimsPerm();
359  auto innerDimsPos = packOp.getInnerDimsPos();
360  SmallVector<int64_t> newInnerDimsPosVec;
361  SmallVector<int64_t> newOuterDimsPermVec =
362  llvm::to_vector(transposePermutation);
363 
364  if (!outerDimsPerm.empty())
365  applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
366 
367  // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
368  // permutation rank won't necessarily be equal in all cases.
369  for (auto dim : innerDimsPos)
370  newInnerDimsPosVec.push_back(transposePermutation[dim]);
371 
372  Value output = packOp.createDestinationTensor(
373  rewriter, packOp.getLoc(), linalgOp->getOperand(0),
374  packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
375 
376  rewriter.replaceOpWithNewOp<PackOp>(
377  packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
378  packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
379 
380  return success();
381  }
382 };
383 
384 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
385 /// transpose semantics.
386 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
387  : public OpInterfaceRewritePattern<linalg::LinalgOp> {
389 
390  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
391  PatternRewriter &rewriter) const override {
392  auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
393 
394  if (!unPackOp)
395  return failure();
396 
397  FailureOr<SmallVector<int64_t>> maybePerm =
398  getTransposeOpPermutation(linalgOp);
399  if (failed(maybePerm))
400  return failure();
401 
402  auto outerDimsPerm = unPackOp.getOuterDimsPerm();
403  auto innerDimsPos = unPackOp.getInnerDimsPos();
404  SmallVector<int64_t> newInnerDimsPosVec;
405  SmallVector<int64_t> newOuterDimsPermVec =
406  invertPermutationVector(maybePerm.value());
407 
408  // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
409  // permutation rank won't necessarily be equal in all cases.
410  for (auto dim : innerDimsPos)
411  newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
412 
413  if (!outerDimsPerm.empty())
414  applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
415 
416  // Reuse the destination of the transpose op.
417  rewriter.replaceOpWithNewOp<UnPackOp>(
418  linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
419  newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
420 
421  return success();
422  }
423 };
424 
425 /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
426 /// transpose semantics.
427 struct FoldConsumerUnPackWithProducerLinalgTransposeOp
428  : public OpRewritePattern<UnPackOp> {
430 
431  LogicalResult matchAndRewrite(UnPackOp unPackOp,
432  PatternRewriter &rewriter) const override {
433  auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
434  if (!linalgOp)
435  return failure();
436 
437  FailureOr<SmallVector<int64_t>> maybePerm =
438  getTransposeOpPermutation(linalgOp);
439  if (failed(maybePerm))
440  return failure();
441 
442  SmallVector<int64_t> inverseTransposePerm =
443  invertPermutationVector(maybePerm.value());
444  auto outerDimsPerm = unPackOp.getOuterDimsPerm();
445  auto innerDimsPos = unPackOp.getInnerDimsPos();
446  int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
447  auto mixedInnerTilesVec = unPackOp.getMixedTiles();
448  SmallVector<int64_t> newOuterDimsPermVec;
449  SmallVector<int64_t> newInnerDimsPosVec;
450  SmallVector<OpFoldResult> newMixedInnerTilesVec;
451 
452  if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
453  newOuterDimsPermVec, destRank))
454  return rewriter.notifyMatchFailure(
455  unPackOp,
456  "Cannot fold in tensor.unpack if a tile dimension was transposed "
457  "with a non-tile dimension in linalg.transpose.");
458 
459  // Process transpose operation for tiled inner dimensions
460  for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
461  int64_t remappedPosition = inverseTransposePerm[i] - destRank;
462  newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
463  newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
464  }
465 
466  Value output = unPackOp.createDestinationTensor(
467  rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
468  newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
469 
470  rewriter.replaceOpWithNewOp<UnPackOp>(
471  unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
472  newMixedInnerTilesVec, newOuterDimsPermVec);
473 
474  return success();
475  }
476 };
477 } // namespace
478 
480  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
481  FoldProducerPackWithConsumerLinalgTransposeOp,
482  FoldConsumerPackWithProducerLinalgTransposeOp,
483  FoldConsumerUnPackWithProducerLinalgTransposeOp,
484  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
485  patterns.getContext());
486 }
487 
489  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
490  patterns.getContext());
491 }
492 
493 } // namespace tensor
494 } // namespace mlir
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:931
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
@ Type
An inlay hint that for a type annotation.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that simplify tensor.pack and tensor.unpack operations.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
Definition: PatternMatch.h:374
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362