MLIR  19.0.0git
PackAndUnpackPatterns.cpp
Go to the documentation of this file.
1 //===- PackAndUnpackPatterns.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/PatternMatch.h"
14 
15 namespace mlir {
16 namespace tensor {
17 namespace {
18 
19 static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
20  return llvm::all_of(
21  ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
22 }
23 
24 /// Returns the number of shape sizes that is either dynamic or greater than 1.
25 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
26  return llvm::count_if(
27  shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
28 }
29 
30 /// Returns success() if there is only 1 dimension size in non-packed domain
31 /// being greater than 1 and packing only happens on the dimension.
32 /// Note: this method should only be used by pack/unpack to reshape conversion.
33 /// It assumes that non-unit inner tile size must be used by the non-unit
34 /// dimension.
35 static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
36  ArrayRef<int64_t> srcShape,
37  ArrayRef<int64_t> innerPackTileSize) {
38  if (getNumGtOneDims(srcShape) > 1) {
39  return rewriter.notifyMatchFailure(
40  op, "expects non-packed domain to have at most one non-unit dims");
41  }
42  // Non-unit inner tile size must be used by the non-unit dimension. If not, it
43  // will faill on getting reassociation maps.
44  if (getNumGtOneDims(innerPackTileSize) > 1) {
45  return rewriter.notifyMatchFailure(
46  op, "expects at most one non-unit inner tiles");
47  }
48  return success();
49 }
50 
51 /// Packing one-dimensional tensor can be expressed as an expand shape op.
52 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
54 
55  Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
56  Type newOperandType, ArrayAttr reassociation) const {
57  if (operand.getType() == newOperandType)
58  return operand;
59  return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
60  reassociation);
61  }
62 
63  /// Returns success() if it is only packing on the innermost dimension.
64  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
65  PackOp packOp) const {
66  auto outerDimsPerm = packOp.getOuterDimsPerm();
67  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
68  return rewriter.notifyMatchFailure(
69  packOp,
70  "expects outer_dims_perm is empty or an identity permutation");
71  }
72 
73  int64_t srcRank = packOp.getSourceRank();
74  ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
75  if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
76  return rewriter.notifyMatchFailure(
77  packOp, "expects packing at the innermost dimension");
78  }
79  return success();
80  }
81 
82  LogicalResult matchAndRewrite(PackOp packOp,
83  PatternRewriter &rewriter) const override {
84  if (packOp.getPaddingValue())
85  return rewriter.notifyMatchFailure(packOp, "expects no padding value");
86 
87  RankedTensorType sourceType = packOp.getSourceType();
88  if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
89  failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
90  packOp.getStaticTiles()))) {
91  return failure();
92  }
93 
94  RankedTensorType destType = packOp.getDestType();
95  auto reassociation =
96  getReassociationIndicesForReshape(sourceType, destType);
97  if (!reassociation)
98  return failure();
99  Value expanded = insertExpand(
100  rewriter, packOp.getLoc(), packOp.getSource(), destType,
101  getReassociationIndicesAttribute(rewriter, *reassociation));
102  rewriter.replaceOp(packOp, expanded);
103  return success();
104  }
105 };
106 
107 struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
109 
110  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
111  Type newOperandType, ArrayAttr reassociation) const {
112  if (operand.getType() == newOperandType)
113  return operand;
114  return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
115  operand, reassociation);
116  }
117 
118  /// Returns success() if it is unpacking on the innermost dimension.
119  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
120  UnPackOp unpackOp) const {
121  auto outerDimsPerm = unpackOp.getOuterDimsPerm();
122  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
123  return rewriter.notifyMatchFailure(
124  unpackOp,
125  "expects outer_dims_perm is empty or an identity permutation");
126  }
127 
128  RankedTensorType sourceType = unpackOp.getSourceType();
129  RankedTensorType destType = unpackOp.getDestType();
130  if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
131  return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
132 
133  ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
134  if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
135  return rewriter.notifyMatchFailure(
136  unpackOp, "expects unpacking on the innermost dimension");
137  }
138 
139  return success();
140  }
141 
142  LogicalResult matchAndRewrite(UnPackOp unpackOp,
143  PatternRewriter &rewriter) const override {
144  RankedTensorType destType = unpackOp.getDestType();
145  if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
146  failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
147  unpackOp.getStaticTiles()))) {
148  return failure();
149  }
150 
151  RankedTensorType sourceType = unpackOp.getSourceType();
152  auto reassociation =
153  getReassociationIndicesForReshape(sourceType, destType);
154  if (!reassociation)
155  return failure();
156  Value collapsed = insertCollapse(
157  rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
158  getReassociationIndicesAttribute(rewriter, *reassociation));
159  rewriter.replaceOp(unpackOp, collapsed);
160  return success();
161  }
162 };
163 
164 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
165 /// the pad op has zero low paddings, or if `pack` has no padding values.
166 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
168 
169  LogicalResult matchAndRewrite(PackOp packOp,
170  PatternRewriter &rewriter) const override {
171  auto padOp = packOp.getSource().getDefiningOp<PadOp>();
172 
173  if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
174  return failure();
175 
176  Value constantPaddingValue = padOp.getConstantPaddingValue();
177  if (!constantPaddingValue)
178  return failure();
179 
180  if (auto paddingValue = packOp.getPaddingValue())
181  if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
182  return failure();
183 
184  rewriter.replaceOpWithNewOp<PackOp>(
185  packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
186  packOp.getMixedTiles(), constantPaddingValue,
187  packOp.getOuterDimsPerm());
188  return success();
189  }
190 };
191 
192 /// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
193 /// has extract_slice semantics.
194 struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
196 
197  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
198  PatternRewriter &rewriter) const override {
199  auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
200  if (!unpackOp)
201  return failure();
202 
203  if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
204  return rewriter.notifyMatchFailure(
205  sliceOp, "rank-reduced folding is not supported");
206  }
207 
208  // Check all offsets are zeros, and all strides are ones.
209  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
210  !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
211  return rewriter.notifyMatchFailure(
212  sliceOp, "expects offsets to be 0s and strides to be 1s");
213  }
214 
215  // Create a new empty output tensor.
216  Type elementType = unpackOp.getDestType().getElementType();
217  Value output = rewriter.create<EmptyOp>(
218  sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
219  rewriter.replaceOpWithNewOp<UnPackOp>(
220  sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
221  unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
222  return success();
223  }
224 };
225 
226 // Applies 'permutation' on 'inVec' and stores the result in resVec.
227 // 'inVec' may be empty, in that case it's one-to-one mapping with permutation.
228 // `rank` sets the boundary for permutation i.e., the permutation dim can't be
229 // greater than the rank specified. If it's so then return false.
230 // For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
231 // permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is
232 // not allowed since `3` exceeds the value of the rank in the given range.
233 static bool checkAndPermute(ArrayRef<int64_t> permutation,
234  ArrayRef<int64_t> inVec,
235  SmallVectorImpl<int64_t> &resVec, int64_t rank) {
236 
237  for (unsigned int i = 0; i < rank; ++i) {
238  int64_t remappedPosition = permutation[i];
239 
240  if (!inVec.empty()) {
241  if (remappedPosition >= rank) {
242  return false;
243  }
244  remappedPosition = inVec[remappedPosition];
245  }
246 
247  resVec.push_back(remappedPosition);
248  }
249 
250  return true;
251 }
252 
253 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
254 /// semantics.
255 struct FoldProducerPackWithConsumerLinalgTransposeOp
256  : public OpRewritePattern<linalg::TransposeOp> {
258 
259  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
260  PatternRewriter &rewriter) const override {
261  auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
262 
263  if (!packOp)
264  return failure();
265 
266  auto innerDimsPos = packOp.getInnerDimsPos();
267  auto mixedInnerTiles = packOp.getMixedTiles();
268  auto outerDimsPerm = packOp.getOuterDimsPerm();
269  auto transposePerm = transposeOp.getPermutation();
270  SmallVector<int64_t> newOuterDimsPermVec;
271  SmallVector<int64_t> newInnerDimsPosVec;
272  SmallVector<OpFoldResult> newMixedInnerTilesVec;
273  int64_t srcRank = packOp.getSourceRank();
274 
275  if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
276  srcRank))
277  return rewriter.notifyMatchFailure(
278  transposeOp,
279  "Cannot fold in tensor.pack if a tile dimension was transposed "
280  "with a non-tile dimension in linalg.transpose.");
281 
282  // Process transpose operation for tiled inner dimensions
283  for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
284  int64_t remappedPosition = transposePerm[i] - srcRank;
285  newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
286  newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
287  }
288 
289  Value output = packOp.createDestinationTensor(
290  rewriter, transposeOp.getLoc(), packOp.getSource(),
291  newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
292 
293  rewriter.replaceOpWithNewOp<PackOp>(
294  transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
295  newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
296 
297  return success();
298  }
299 };
300 
301 /// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
302 /// semantics.
303 struct FoldConsumerPackWithProducerLinalgTransposeOp
304  : public OpRewritePattern<PackOp> {
306 
307  LogicalResult matchAndRewrite(PackOp packOp,
308  PatternRewriter &rewriter) const override {
309  auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
310 
311  if (!transposeOp)
312  return failure();
313 
314  auto transposePermutation = transposeOp.getPermutation();
315  auto outerDimsPerm = packOp.getOuterDimsPerm();
316  auto innerDimsPos = packOp.getInnerDimsPos();
317  SmallVector<int64_t> newInnerDimsPosVec;
318  SmallVector<int64_t> newOuterDimsPermVec =
319  llvm::to_vector(transposePermutation);
320 
321  if (!outerDimsPerm.empty())
322  applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
323 
324  // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
325  // permutation rank won't necessarily be equal in all cases.
326  for (auto dim : innerDimsPos)
327  newInnerDimsPosVec.push_back(transposePermutation[dim]);
328 
329  Value output = packOp.createDestinationTensor(
330  rewriter, packOp.getLoc(), transposeOp.getOperand(0),
331  packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
332 
333  rewriter.replaceOpWithNewOp<PackOp>(
334  packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
335  packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
336 
337  return success();
338  }
339 };
340 
341 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
342 /// transpose semantics.
343 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
344  : public OpRewritePattern<linalg::TransposeOp> {
346 
347  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
348  PatternRewriter &rewriter) const override {
349  auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
350 
351  if (!unPackOp)
352  return failure();
353 
354  auto transposePermutation = transposeOp.getPermutation();
355  auto outerDimsPerm = unPackOp.getOuterDimsPerm();
356  auto innerDimsPos = unPackOp.getInnerDimsPos();
357  SmallVector<int64_t> newInnerDimsPosVec;
358  SmallVector<int64_t> newOuterDimsPermVec =
359  llvm::to_vector(transposePermutation);
360 
361  if (!outerDimsPerm.empty())
362  applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
363 
364  // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
365  // permutation rank won't necessarily be equal in all cases.
366  for (auto dim : innerDimsPos)
367  newInnerDimsPosVec.push_back(transposePermutation[dim]);
368 
369  Value output = unPackOp.createDestinationTensor(
370  rewriter, transposeOp.getLoc(), unPackOp.getSource(),
371  unPackOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
372 
373  rewriter.replaceOpWithNewOp<UnPackOp>(
374  transposeOp, unPackOp.getSource(), output, newInnerDimsPosVec,
375  unPackOp.getMixedTiles(), newOuterDimsPermVec);
376 
377  return success();
378  }
379 };
380 
381 /// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
382 /// transpose semantics.
383 struct FoldConsumerUnPackWithProducerLinalgTransposeOp
384  : public OpRewritePattern<UnPackOp> {
386 
387  LogicalResult matchAndRewrite(UnPackOp unPackOp,
388  PatternRewriter &rewriter) const override {
389  auto transposeOp =
390  unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
391 
392  if (!transposeOp)
393  return failure();
394 
395  auto transposePermutation = transposeOp.getPermutation();
396  auto outerDimsPerm = unPackOp.getOuterDimsPerm();
397  auto innerDimsPos = unPackOp.getInnerDimsPos();
398  int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
399  auto mixedInnerTilesVec = unPackOp.getMixedTiles();
400  SmallVector<int64_t> newOuterDimsPermVec;
401  SmallVector<int64_t> newInnerDimsPosVec;
402  SmallVector<OpFoldResult> newMixedInnerTilesVec;
403 
404  if (!checkAndPermute(transposePermutation, outerDimsPerm,
405  newOuterDimsPermVec, destRank))
406  return rewriter.notifyMatchFailure(
407  unPackOp,
408  "Cannot fold in tensor.unpack if a tile dimension was transposed "
409  "with a non-tile dimension in linalg.transpose.");
410 
411  // Process transpose operation for tiled inner dimensions
412  for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
413  int64_t remappedPosition = transposePermutation[i] - destRank;
414  newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
415  newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
416  }
417 
418  Value output = unPackOp.createDestinationTensor(
419  rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
420  newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
421 
422  rewriter.replaceOpWithNewOp<UnPackOp>(
423  unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
424  newMixedInnerTilesVec, newOuterDimsPermVec);
425 
426  return success();
427  }
428 };
429 } // namespace
430 
432  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
433  FoldProducerPackWithConsumerLinalgTransposeOp,
434  FoldConsumerPackWithProducerLinalgTransposeOp,
435  FoldConsumerUnPackWithProducerLinalgTransposeOp,
436  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
437  patterns.getContext());
438 }
439 
441  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
442  patterns.getContext());
443 }
444 
445 } // namespace tensor
446 } // namespace mlir
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:930
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
@ Type
An inlay hint for a type annotation.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into tensor.pack and tensor.unpack operations respectively.
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that simplify tensor.pack and tensor.unpack operations.
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape given the source type and the target type when possible.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362