MLIR  19.0.0git
PackAndUnpackPatterns.cpp
Go to the documentation of this file.
1 //===- PackAndUnpackPatterns.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/PatternMatch.h"
14 
15 namespace mlir {
16 namespace tensor {
17 namespace {
18 
19 static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
20  return llvm::all_of(
21  ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
22 }
23 
24 /// Returns the number of shape sizes that is either dynamic or greater than 1.
25 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
26  return llvm::count_if(
27  shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
28 }
29 
30 /// Returns success() if there is only 1 dimension size in non-packed domain
31 /// being greater than 1 and packing only happens on the dimension.
32 /// Note: this method should only be used by pack/unpack to reshape conversion.
33 /// It assumes that non-unit inner tile size must be used by the non-unit
34 /// dimension.
35 static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
36  ArrayRef<int64_t> srcShape,
37  ArrayRef<int64_t> innerPackTileSize) {
38  if (getNumGtOneDims(srcShape) > 1) {
39  return rewriter.notifyMatchFailure(
40  op, "expects non-packed domain to have at most one non-unit dims");
41  }
42  // Non-unit inner tile size must be used by the non-unit dimension. If not, it
43  // will faill on getting reassociation maps.
44  if (getNumGtOneDims(innerPackTileSize) > 1) {
45  return rewriter.notifyMatchFailure(
46  op, "expects at most one non-unit inner tiles");
47  }
48  return success();
49 }
50 
51 /// Packing one-dimensional tensor can be expressed as an expand shape op.
52 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
54 
55  FailureOr<Value>
56  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
57  Type newOperandType,
58  ArrayRef<ReassociationIndices> reassociation) const {
59  if (operand.getType() == newOperandType)
60  return operand;
61  return rewriter
62  .create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
63  reassociation)
64  .getResult();
65  }
66 
67  /// Returns success() if it is only packing on the innermost dimension.
68  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
69  PackOp packOp) const {
70  auto outerDimsPerm = packOp.getOuterDimsPerm();
71  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
72  return rewriter.notifyMatchFailure(
73  packOp,
74  "expects outer_dims_perm is empty or an identity permutation");
75  }
76 
77  int64_t srcRank = packOp.getSourceRank();
78  ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
79  if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
80  return rewriter.notifyMatchFailure(
81  packOp, "expects packing at the innermost dimension");
82  }
83  return success();
84  }
85 
86  LogicalResult matchAndRewrite(PackOp packOp,
87  PatternRewriter &rewriter) const override {
88  if (packOp.getPaddingValue())
89  return rewriter.notifyMatchFailure(packOp, "expects no padding value");
90 
91  RankedTensorType sourceType = packOp.getSourceType();
92  if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
93  failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
94  packOp.getStaticTiles()))) {
95  return failure();
96  }
97 
98  RankedTensorType destType = packOp.getDestType();
99  auto reassociation =
100  getReassociationIndicesForReshape(sourceType, destType);
101  if (!reassociation)
102  return failure();
103  FailureOr<Value> expanded =
104  insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
105  *reassociation);
106  if (failed(expanded)) {
107  return rewriter.notifyMatchFailure(
108  packOp, "unable to expand source of tensor.pack");
109  }
110  rewriter.replaceOp(packOp, *expanded);
111  return success();
112  }
113 };
114 
115 struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
117 
118  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
119  Type newOperandType, ArrayAttr reassociation) const {
120  if (operand.getType() == newOperandType)
121  return operand;
122  return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
123  operand, reassociation);
124  }
125 
126  /// Returns success() if it is unpacking on the innermost dimension.
127  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
128  UnPackOp unpackOp) const {
129  auto outerDimsPerm = unpackOp.getOuterDimsPerm();
130  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
131  return rewriter.notifyMatchFailure(
132  unpackOp,
133  "expects outer_dims_perm is empty or an identity permutation");
134  }
135 
136  RankedTensorType sourceType = unpackOp.getSourceType();
137  RankedTensorType destType = unpackOp.getDestType();
138  if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
139  return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
140 
141  ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
142  if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
143  return rewriter.notifyMatchFailure(
144  unpackOp, "expects unpacking on the innermost dimension");
145  }
146 
147  return success();
148  }
149 
150  LogicalResult matchAndRewrite(UnPackOp unpackOp,
151  PatternRewriter &rewriter) const override {
152  RankedTensorType destType = unpackOp.getDestType();
153  if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
154  failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
155  unpackOp.getStaticTiles()))) {
156  return failure();
157  }
158 
159  RankedTensorType sourceType = unpackOp.getSourceType();
160  auto reassociation =
161  getReassociationIndicesForReshape(sourceType, destType);
162  if (!reassociation)
163  return failure();
164  Value collapsed = insertCollapse(
165  rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
166  getReassociationIndicesAttribute(rewriter, *reassociation));
167  rewriter.replaceOp(unpackOp, collapsed);
168  return success();
169  }
170 };
171 
172 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
173 /// the pad op has zero low paddings, or if `pack` has no padding values.
174 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
176 
177  LogicalResult matchAndRewrite(PackOp packOp,
178  PatternRewriter &rewriter) const override {
179  auto padOp = packOp.getSource().getDefiningOp<PadOp>();
180 
181  if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
182  return failure();
183 
184  Value constantPaddingValue = padOp.getConstantPaddingValue();
185  if (!constantPaddingValue)
186  return failure();
187 
188  if (auto paddingValue = packOp.getPaddingValue())
189  if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
190  return failure();
191 
192  rewriter.replaceOpWithNewOp<PackOp>(
193  packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
194  packOp.getMixedTiles(), constantPaddingValue,
195  packOp.getOuterDimsPerm());
196  return success();
197  }
198 };
199 
200 /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
201 /// has extract_slice semantics.
202 struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
204 
205  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
206  PatternRewriter &rewriter) const override {
207  auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
208  if (!unpackOp)
209  return failure();
210 
211  if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
212  return rewriter.notifyMatchFailure(
213  sliceOp, "rank-reduced folding is not supported");
214  }
215 
216  // Check all offsets are zeros, and all strides are ones.
217  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
218  !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
219  return rewriter.notifyMatchFailure(
220  sliceOp, "expects offsets to be 0s and strides to be 1s");
221  }
222 
223  // Create a new empty output tensor.
224  Type elementType = unpackOp.getDestType().getElementType();
225  Value output = rewriter.create<EmptyOp>(
226  sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
227  rewriter.replaceOpWithNewOp<UnPackOp>(
228  sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
229  unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
230  return success();
231  }
232 };
233 
234 // Applies 'permutation' on 'inVec' and stores the result in resVec.
235 // 'inVec' may be empty, in that case it's one-to-one mapping with permutation.
236 // `rank` sets the boundary for permutation i.e., the permutation dim can't be
237 // greater than the rank specified. If it's so then return false.
238 // For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
239 // permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is
240 // not allowed since `3` exceeds the value of the rank in the given range.
241 static bool checkAndPermute(ArrayRef<int64_t> permutation,
242  ArrayRef<int64_t> inVec,
243  SmallVectorImpl<int64_t> &resVec, int64_t rank) {
244 
245  for (unsigned int i = 0; i < rank; ++i) {
246  int64_t remappedPosition = permutation[i];
247 
248  if (!inVec.empty()) {
249  if (remappedPosition >= rank) {
250  return false;
251  }
252  remappedPosition = inVec[remappedPosition];
253  }
254 
255  resVec.push_back(remappedPosition);
256  }
257 
258  return true;
259 }
260 
261 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
262 /// semantics.
263 struct FoldProducerPackWithConsumerLinalgTransposeOp
264  : public OpRewritePattern<linalg::TransposeOp> {
266 
267  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
268  PatternRewriter &rewriter) const override {
269  auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
270 
271  if (!packOp)
272  return failure();
273 
274  auto innerDimsPos = packOp.getInnerDimsPos();
275  auto mixedInnerTiles = packOp.getMixedTiles();
276  auto outerDimsPerm = packOp.getOuterDimsPerm();
277  auto transposePerm = transposeOp.getPermutation();
278  SmallVector<int64_t> newOuterDimsPermVec;
279  SmallVector<int64_t> newInnerDimsPosVec;
280  SmallVector<OpFoldResult> newMixedInnerTilesVec;
281  int64_t srcRank = packOp.getSourceRank();
282 
283  if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
284  srcRank))
285  return rewriter.notifyMatchFailure(
286  transposeOp,
287  "Cannot fold in tensor.pack if a tile dimension was transposed "
288  "with a non-tile dimension in linalg.transpose.");
289 
290  // Process transpose operation for tiled inner dimensions
291  for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
292  int64_t remappedPosition = transposePerm[i] - srcRank;
293  newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
294  newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
295  }
296 
297  Value output = packOp.createDestinationTensor(
298  rewriter, transposeOp.getLoc(), packOp.getSource(),
299  newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
300 
301  rewriter.replaceOpWithNewOp<PackOp>(
302  transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
303  newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
304 
305  return success();
306  }
307 };
308 
309 /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
310 /// semantics.
311 struct FoldConsumerPackWithProducerLinalgTransposeOp
312  : public OpRewritePattern<PackOp> {
314 
315  LogicalResult matchAndRewrite(PackOp packOp,
316  PatternRewriter &rewriter) const override {
317  auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
318 
319  if (!transposeOp)
320  return failure();
321 
322  auto transposePermutation = transposeOp.getPermutation();
323  auto outerDimsPerm = packOp.getOuterDimsPerm();
324  auto innerDimsPos = packOp.getInnerDimsPos();
325  SmallVector<int64_t> newInnerDimsPosVec;
326  SmallVector<int64_t> newOuterDimsPermVec =
327  llvm::to_vector(transposePermutation);
328 
329  if (!outerDimsPerm.empty())
330  applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
331 
332  // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
333  // permutation rank won't necessarily be equal in all cases.
334  for (auto dim : innerDimsPos)
335  newInnerDimsPosVec.push_back(transposePermutation[dim]);
336 
337  Value output = packOp.createDestinationTensor(
338  rewriter, packOp.getLoc(), transposeOp.getOperand(0),
339  packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
340 
341  rewriter.replaceOpWithNewOp<PackOp>(
342  packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
343  packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
344 
345  return success();
346  }
347 };
348 
349 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
350 /// transpose semantics.
351 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
352  : public OpRewritePattern<linalg::TransposeOp> {
354 
355  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
356  PatternRewriter &rewriter) const override {
357  auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
358 
359  if (!unPackOp)
360  return failure();
361 
362  auto transposePermutation = transposeOp.getPermutation();
363  auto outerDimsPerm = unPackOp.getOuterDimsPerm();
364  auto innerDimsPos = unPackOp.getInnerDimsPos();
365  SmallVector<int64_t> newInnerDimsPosVec;
366  SmallVector<int64_t> newOuterDimsPermVec =
367  llvm::to_vector(transposePermutation);
368 
369  if (!outerDimsPerm.empty())
370  applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
371 
372  // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
373  // permutation rank won't necessarily be equal in all cases.
374  for (auto dim : innerDimsPos)
375  newInnerDimsPosVec.push_back(transposePermutation[dim]);
376 
377  // Reuse the destination of the transpose op.
378  rewriter.replaceOpWithNewOp<UnPackOp>(
379  transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
380  newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
381 
382  return success();
383  }
384 };
385 
386 /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
387 /// transpose semantics.
388 struct FoldConsumerUnPackWithProducerLinalgTransposeOp
389  : public OpRewritePattern<UnPackOp> {
391 
392  LogicalResult matchAndRewrite(UnPackOp unPackOp,
393  PatternRewriter &rewriter) const override {
394  auto transposeOp =
395  unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
396 
397  if (!transposeOp)
398  return failure();
399 
400  auto transposePermutation = transposeOp.getPermutation();
401  auto outerDimsPerm = unPackOp.getOuterDimsPerm();
402  auto innerDimsPos = unPackOp.getInnerDimsPos();
403  int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
404  auto mixedInnerTilesVec = unPackOp.getMixedTiles();
405  SmallVector<int64_t> newOuterDimsPermVec;
406  SmallVector<int64_t> newInnerDimsPosVec;
407  SmallVector<OpFoldResult> newMixedInnerTilesVec;
408 
409  if (!checkAndPermute(transposePermutation, outerDimsPerm,
410  newOuterDimsPermVec, destRank))
411  return rewriter.notifyMatchFailure(
412  unPackOp,
413  "Cannot fold in tensor.unpack if a tile dimension was transposed "
414  "with a non-tile dimension in linalg.transpose.");
415 
416  // Process transpose operation for tiled inner dimensions
417  for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
418  int64_t remappedPosition = transposePermutation[i] - destRank;
419  newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
420  newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
421  }
422 
423  Value output = unPackOp.createDestinationTensor(
424  rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
425  newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
426 
427  rewriter.replaceOpWithNewOp<UnPackOp>(
428  unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
429  newMixedInnerTilesVec, newOuterDimsPermVec);
430 
431  return success();
432  }
433 };
434 } // namespace
435 
437  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
438  FoldProducerPackWithConsumerLinalgTransposeOp,
439  FoldConsumerPackWithProducerLinalgTransposeOp,
440  FoldConsumerUnPackWithProducerLinalgTransposeOp,
441  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
442  patterns.getContext());
443 }
444 
446  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
447  patterns.getContext());
448 }
449 
450 } // namespace tensor
451 } // namespace mlir
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
@ Type
An inlay hint for a type annotation.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that simplify tensor.pack and tensor.unpack operations.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362