MLIR  21.0.0git
ReshapePatterns.cpp
Go to the documentation of this file.
1 //===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/PatternMatch.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/Support/Debug.h"
17 #include "llvm/Support/LogicalResult.h"
18 
19 using namespace mlir;
20 using namespace mlir::tensor;
21 
22 namespace {
23 /// Fold expand_shape(extract_slice) ops that cancel each other out, i.e.
/// where the expand_shape re-introduces exactly the unit dimensions that a
/// rank-reducing extract_slice dropped, so the pair reduces to a single
/// non-rank-reducing extract_slice.
24 struct FoldExpandOfRankReducingExtract
25  : public OpRewritePattern<ExpandShapeOp> {
// NOTE(review): Doxygen line 26 is missing from this dump (presumably the
// inherited-constructor `using` declaration) — verify against the original
// source.
27 
28  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
29  PatternRewriter &rewriter) const override {
30  RankedTensorType resultType = expandShapeOp.getResultType();
31  auto extractSliceOp =
32  expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
33  if (!extractSliceOp)
34  return failure();
35  RankedTensorType srcType = extractSliceOp.getSourceType();
36 
37  // Only cases where the ExpandShapeOp can be folded away entirely are
38  // supported. Moreover, only simple cases where the resulting ExtractSliceOp
39  // has no rank-reduction anymore are supported at the moment.
// Infer the slice type *without* rank reduction; it must match the expand
// result exactly, otherwise dropping the expand would change the type.
40  RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
41  srcType, extractSliceOp.getStaticOffsets(),
42  extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
43  if (nonReducingExtractType != resultType)
44  return failure();
45 
// Re-emit the slice without rank reduction, replacing the expand_shape.
46  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
47  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
48  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
49  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
50  expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
51  mixedStrides);
52  return success();
53  }
54 };
55 
56 /// Fold collapse_shape which only removes static dimensions of size `1`
57 /// into extract_slice.
58 struct FoldUnPaddingCollapseIntoExtract
59  : public OpRewritePattern<tensor::CollapseShapeOp> {
// NOTE(review): Doxygen line 60 is missing from this dump (presumably the
// inherited-constructor `using` declaration) — verify against the original
// source.
61 
62  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
63  PatternRewriter &rewriter) const override {
64  auto extractSliceOp =
65  collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
66  // Collapse cannot be folded away with multiple users of the extract slice
67  // and it is not necessarily beneficial to only convert the collapse into
68  // another extract slice.
69  if (!extractSliceOp || !extractSliceOp->hasOneUse())
70  return failure();
71 
72  // Only fold away simple collapse where all removed dimensions have static
73  // size `1`.
// NOTE(review): Doxygen lines 74 and 76 are missing from this dump (the
// rank-reduction verification and its failure condition) — verify against
// the original source.
75  collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
77  return rewriter.notifyMatchFailure(collapseShapeOp,
78  "expected unpadding collapse");
79 
// Re-create the slice directly with the collapsed (unit-dims-removed)
// result type so the collapse_shape becomes redundant.
80  Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
81  extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
82  extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
83  extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
84  rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
85  return success();
86  }
87 };
88 
89 /// Fold insert_slice(collapse_shape) ops that cancel each other out, i.e.
/// where the collapse_shape removes exactly the unit dimensions that a
/// rank-reducing insert_slice would drop anyway. Works for both
/// tensor::InsertSliceOp and tensor::ParallelInsertSliceOp via OpTy.
90 template <typename OpTy>
91 struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
// NOTE(review): Doxygen line 92 is missing from this dump (presumably the
// inherited-constructor `using` declaration) — verify against the original
// source.
93 
94  LogicalResult matchAndRewrite(OpTy insertSliceOp,
95  PatternRewriter &rewriter) const override {
96  auto collapseShapeOp =
97  insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
98  if (!collapseShapeOp)
99  return failure();
100  RankedTensorType srcType = collapseShapeOp.getSrcType();
101 
102  // Only cases where the CollapseShapeOp can be folded away entirely are
103  // supported. Moreover, only simple cases where the resulting InsertSliceOp
104  // has no rank-reduction anymore are supported at the moment.
// The insert's full (non-rank-reduced) static-size type must match the
// collapse source type exactly; otherwise the collapse cannot be dropped.
105  RankedTensorType nonReducingInsertType =
106  RankedTensorType::get(insertSliceOp.getStaticSizes(),
107  insertSliceOp.getDestType().getElementType());
108  if (nonReducingInsertType != srcType)
109  return failure();
110 
// Insert the collapse's source directly; the insert becomes
// non-rank-reducing and the collapse_shape disappears.
111  SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
112  SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
113  SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
114  rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
115  insertSliceOp.getDest(), mixedOffsets,
116  mixedSizes, mixedStrides);
117  return success();
118  }
119 };
120 
121 /// Fold expand_shape which only adds static dimensions of size `1`
122 /// into insert_slice. Works for both tensor::InsertSliceOp and
/// tensor::ParallelInsertSliceOp via OpTy.
123 template <typename OpTy>
124 struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
// NOTE(review): Doxygen line 125 is missing from this dump (presumably the
// inherited-constructor `using` declaration) — verify against the original
// source.
126 
127  LogicalResult matchAndRewrite(OpTy insertSliceOp,
128  PatternRewriter &rewriter) const override {
129  auto expandShapeOp = insertSliceOp.getSource()
130  .template getDefiningOp<tensor::ExpandShapeOp>();
131  if (!expandShapeOp)
132  return failure();
133 
134  // Only fold away simple expansion where all added dimensions have static
135  // size `1`.
// NOTE(review): Doxygen lines 136 and 138 are missing from this dump (the
// rank-increase verification and its failure condition) — verify against
// the original source.
137  expandShapeOp.getResultType(), expandShapeOp.getSrcType());
139  return rewriter.notifyMatchFailure(insertSliceOp,
140  "expected rank increasing expansion");
141 
// The insert's own offsets/sizes/strides stay valid; only the source
// operand is swapped to the pre-expansion value (in-place modification).
142  rewriter.modifyOpInPlace(insertSliceOp, [&]() {
143  insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
144  });
145  return success();
146  }
147 };
148 
149 /// Pattern to bubble up a tensor.expand_shape op through a producer
150 /// tensor.collapse_shape op that has non intersecting reassociations.
/// The net effect is to swap the order: expand first, then collapse.
151 struct BubbleUpExpandThroughParallelCollapse
152  : public OpRewritePattern<tensor::ExpandShapeOp> {
// NOTE(review): Doxygen line 153 is missing from this dump (presumably the
// inherited-constructor `using` declaration) — verify against the original
// source.
154 
155  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
156  PatternRewriter &rewriter) const override {
157  auto collapseOp =
158  expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
159  if (!collapseOp)
160  return failure();
161  auto expandReInds = expandOp.getReassociationIndices();
162  auto collapseReInds = collapseOp.getReassociationIndices();
163 
164  // Special case where the collapsed tensor to expand is a 0-D tensor,
165  // then the reassociation maps will be empty and not produce valid results.
166  if (expandReInds.size() == 0) {
167  return failure();
168  }
169 
170  // Reshapes are parallel to each other (by construction the number of
171  // reassociations specified in the collapse and expand are the same), if at
172  // any position
173  // 1. either the reassociation indices are of the same size, or
174  // 2. either the reassociation in the collapse or the expand is of size 1.
175  ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
176  ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
177  for (auto [expandReassociation, collapseReassociation] :
178  llvm::zip_equal(expandReInds, collapseReInds)) {
179  if (collapseReassociation.size() == expandReassociation.size()) {
180  // Even if the reassociations are the same, the collapse/expand should
181  // result in the same dimensions. i.e 4x8x2 into 64 should be expanded
182  // into 4x8x2 again. In presence of dynamic dimensions one can only
183  // verify "equality" when there is only one dynamic dimension present,
184  // and all other static dimensions are equal.
185  ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
186  collapseReassociation.front(), collapseReassociation.size());
187  int64_t numCollapsedDynamic =
188  llvm::count_if(collapsedStaticShapes,
189  [](int64_t d) { return ShapedType::isDynamic(d); });
190  ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
191  expandReassociation.front(), expandReassociation.size());
192  int64_t numExpandedDynamic =
193  llvm::count_if(expandedStaticShapes,
194  [](int64_t d) { return ShapedType::isDynamic(d); });
195  if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
196  collapsedStaticShapes != expandedStaticShapes) {
197  return failure();
198  }
199  continue;
200  }
201  // If the reassociations are not the same, one or the other needs to be
202  // of size one.
203  if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
204  return failure();
205  }
206 
207  // Compute new reassociation indices and expanded/collapsed shapes.
208  SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
209  Location loc = expandOp->getLoc();
210  SmallVector<OpFoldResult> sourceSizes =
211  tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
212  SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
213  SmallVector<OpFoldResult> newExpandSizes;
214 
// Running cursors into the new expand/collapse dims and the original
// source/result size vectors; advanced per case below.
215  int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
216  resultSizeIndex = 0;
217 
218  for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
219  auto &collapseReassociation = collapseReInds[idx];
220  auto &expandReassociation = expandReInds[idx];
221 
222  // Case 1. The reassociations are same in the collapse producer
223  // and expand consumer. In the swapped expand, each of the final
224  // dimensions are kept as is in the expand and the collapse. So,
225  // for every element in the `ReassociationIndices` vector add a new
226  // `ReassociationIndices` vector for the swapped expand and collapse
227  // (of size 1).
228  if (collapseReassociation.size() == expandReassociation.size()) {
229  for (size_t i = 0; i < collapseReassociation.size(); ++i) {
230  newCollapseReInds.push_back({newCollapseIndex++});
231  newExpandReInds.push_back({newExpandIndex++});
232  newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
233  sourceSizeIndex++;
234  }
235  continue;
236  }
237 
238  // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
239  // in the expand is of size == 1). In this case, the original dimensions
240  // are preserved on expansion and collapsed subsequently.
241  if (collapseReassociation.size() != 1) {
242  ReassociationIndices newCollapseReassociation;
243  for (size_t i = 0; i < collapseReassociation.size(); ++i) {
244  newCollapseReassociation.push_back(newCollapseIndex++);
245  newExpandReInds.push_back({newExpandIndex++});
246  newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
247  }
248  resultSizeIndex++;
249  newCollapseReInds.push_back(newCollapseReassociation);
250  continue;
251  }
252 
253  // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
254  // in the collapse is of size == 1). In this case, the expansion happens
255  // first and the expanded dimensions are preserved on collapse.
256  ReassociationIndices newExpandReassociation;
257  for (size_t i = 0; i < expandReassociation.size(); ++i) {
258  newExpandReassociation.push_back(newExpandIndex++);
259  newCollapseReInds.push_back({newCollapseIndex++});
260  newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
261  }
262  newExpandReInds.push_back(newExpandReassociation);
263  sourceSizeIndex++;
264  }
265 
266  // Swap reshape order.
267  SmallVector<Value> dynamicSizes;
268  SmallVector<int64_t> staticSizes;
269  dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
270  auto expandResultType = expandOp.getResultType().clone(staticSizes);
271  Value newCollapseSrc = collapseOp.getSrc();
272  // If the number of reassociation indices in the new `expand_shape` op
273  // matches the number of dimensions of the result, then the expand_shape
274  // is a no-op.
275  if (newExpandReInds.size() != newExpandSizes.size()) {
276  newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
277  loc, expandResultType, newCollapseSrc, newExpandReInds,
278  newExpandSizes);
279  }
280 
281  // If the number of reassociation indices in the new `collapse_shape` op
282  // matches the number of dimensions of the source, then the collapse_shape
283  // is a no-op.
284  Value replacement = newCollapseSrc;
285  if (newCollapseReInds.size() != newExpandSizes.size()) {
286  replacement = rewriter.create<tensor::CollapseShapeOp>(
287  loc, newCollapseSrc, newCollapseReInds);
288  }
289  rewriter.replaceOp(expandOp, replacement);
290  return success();
291  }
292 };
293 
294 /// Converts `tensor.extract_slice(tensor.expand_shape)` to
295 /// `tensor.expand_shape(tensor.extract_slice)`.
296 ///
297 /// For this transformation to be possible, the slice must be fully contiguous
298 /// within each reassociation group of the expand_shape. A slice is defined as
299 /// fully contiguous within a reassociation group if after flattening the
300 /// reassociation group to a single 1D range, then the slice taken out of the
301 /// group could be defined as a single contiguous subrange within that range.
302 ///
303 /// Rank reducing slices are not supported.
304 ///
305 /// Example:
306 /// The transformation is possible because each reassociation group has a
307 /// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
308 /// ```
309 /// BEFORE:
310 /// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
311 /// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
312 /// %slice = tensor.extract_slice %reshape ...
313 /// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
314 ///
315 /// AFTER:
316 /// %slice = tensor.extract_slice %in ...
317 /// tensor<8x16x32xf32> to tensor<8x5x4xf32>
318 /// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
319 /// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
320 /// ```
321 ///
322 /// Note - this pattern could be extended to be a swap pattern between
323 /// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
324 /// implemented only as a bubble up pattern for `tensor.extract_slice`.
325 struct BubbleUpExpandShapeThroughExtractSlice
326  : public OpRewritePattern<tensor::ExtractSliceOp> {
// NOTE(review): Doxygen line 327 is missing from this dump (presumably the
// inherited-constructor `using` declaration) — verify against the original
// source.
328 
329  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
330  PatternRewriter &rewriter) const override {
331  auto expandShapeOp =
332  sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
333 
334  if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
335  rewriter)
336  .failed())
337  return failure();
338 
339  // The tensor.extract_slice before applying the pattern works on the result
340  // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
341  // referring to the state before applying the pattern are named with the
342  // prefix "expanded", and ones referring to the state after applying the
343  // pattern are named with the prefix "collapsed".
344  SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
345  SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
346  SmallVector<OpFoldResult> expandedShape =
347  getMixedValues(expandShapeOp.getStaticOutputShape(),
348  expandShapeOp.getOutputShape(), rewriter);
349 
350  // Helper variables and function for accumulating the size values.
351  Location loc = expandShapeOp->getLoc();
352  AffineExpr d0, d1, d2;
353  bindDims(rewriter.getContext(), d0, d1, d2);
354  // Multiply two integers.
355  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
356  auto mulMap = AffineMap::get(2, 0, {d0 * d1});
357  return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
358  {v1, v2});
359  };
360 
361  // Compute new offsets, sizes, and strides for tensor.extract_slice.
362  // The new tensor.extract_slice will work on a tensor that has a rank of
363  // ReassociationIndices.size(). In the loop a single offset, size, and
364  // stride value is computed per reassociation group.
365  SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
366  collapsedStrides;
367  for (const ReassociationIndices &indices :
368  expandShapeOp.getReassociationIndices()) {
369  // collapsedSize will hold the size of the single dim that represents the
370  // reassociation group in the non expanded tensor.
371  OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
372  // The reassocGroupSizes and reassocGroupOffsets are used to create an
373  // affine.linearize_index op to linearize the single offset value required
374  // for this reassociation group.
375  SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
376 
377  for (long expandedDim : indices) {
378  // reassocGroupSizes and reassocGroupOffsets can be obtained directly
379  // from the expanded state, but the collapsed size requires calculation
380  // as it did not previously exist.
381  reassocGroupSizes.push_back(expandedShape[expandedDim]);
382  reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
383  collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
384  }
385 
// Linearize the per-dim offsets of this group into one collapsed offset
// using affine.linearize_index over the group's expanded sizes.
386  SmallVector<Value> offsetVals =
387  llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
388  return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
389  });
390  OpFoldResult collapsedOffset =
391  rewriter
392  .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
393  reassocGroupSizes,
394  /*disjoint=*/true)
395  .getResult();
396  collapsedOffsets.push_back(collapsedOffset);
397  collapsedSizes.push_back(collapsedSize);
398 
399  // Only unit stride is supported.
400  collapsedStrides.push_back(rewriter.getIndexAttr(1));
401  }
402 
403  // The shape of the result can be obtained from the sizes passed in.
404  SmallVector<Value> dynDims;
405  SmallVector<int64_t> shape;
406  dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
407  RankedTensorType resultType = RankedTensorType::get(
408  shape, expandShapeOp.getResultType().getElementType());
409 
410  // Create a new ExtractSliceOp and ExpandShapeOp.
411  Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
412  loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
413  collapsedStrides);
414  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
415  sliceOp, resultType, newSliceOp,
416  expandShapeOp.getReassociationIndices(), expandedSizes);
417  return success();
418  }
419 
420  // Helper function to check if all the required conditions for the
421  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
422  // met.
423  LogicalResult
424  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
425  tensor::ExpandShapeOp expandShapeOp,
426  PatternRewriter &rewriter) const {
427 
428  if (!expandShapeOp) {
429  return rewriter.notifyMatchFailure(
430  sliceOp, "tensor.extract_slice source not produced by expand_shape");
431  }
432 
433  if (!sliceOp.hasUnitStride()) {
434  return rewriter.notifyMatchFailure(
435  sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
436  "be supported in this transformation.");
437  }
438 
439  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
440  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
441 
// A rank-reducing slice has fewer result dims than size operands.
442  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
443  sizes.size()) {
444  return rewriter.notifyMatchFailure(sliceOp,
445  "unimplemented: rank reducing slice");
446  }
447 
448  SmallVector<OpFoldResult> outputShape =
449  getMixedValues(expandShapeOp.getStaticOutputShape(),
450  expandShapeOp.getOutputShape(), rewriter);
451 
// True iff the offset is the constant 0 and the slice size provably
// equals the full dimension size (via value-bounds analysis).
452  std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
453  isZeroOffsetAndFullSize =
454  [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
455  if (!isConstantIntValue(offset, 0))
456  return false;
457  FailureOr<bool> maybeEqual =
458  ValueBoundsConstraintSet::areEqual(sliceSize, size);
459  return llvm::succeeded(maybeEqual) && maybeEqual.value();
460  };
461 
462  // Check that the slice is contiguous within each reassociation group.
463  // The slice is contiguous only if after the first dimension where a non
464  // unit slice is taken, the slice size on all subsequent dimensions of the
465  // group is equal to the entire size of the dimension.
466  // Examples of contiguous slices:
467  // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
468  // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
469  // Examples of non contiguous slices:
470  // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
471  // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
472  for (const ReassociationIndices &indices :
473  expandShapeOp.getReassociationIndices()) {
474  int64_t i = 0;
475  int64_t e = indices.size();
476  // Find the first expanded dim after the first dim with non-unit extracted
477  // size.
478  for (; i < e; ++i) {
479  if (!isConstantIntValue(sizes[indices[i]], 1)) {
480  // +1 to skip the first non-unit size dim.
481  i++;
482  break;
483  }
484  }
485 
486  // Verify that all subsequent dimensions extract the full size of the
487  // source tensor.
488  for (; i < e; ++i) {
489  int64_t expandedDim = indices[i];
490  if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
491  outputShape[expandedDim])) {
492  return rewriter.notifyMatchFailure(
493  sliceOp, "Not a contiguous slice of the expanded tensor.");
494  }
495  }
496  }
497 
498  return success();
499  }
500 };
501 
502 /// Converts `tensor.extract_slice(tensor.collapse_shape)` to
503 /// `tensor.collapse_shape(tensor.extract_slice)`.
504 ///
505 /// For this transformation to be possible - after bubbling up, the extraction
506 /// of the contiguous slice must be representable as a single slice obtained via
507 /// tensor.extract_slice within each reassociation group of the src.
508 ///
509 /// In case the size and offset extracted are static then this is possible if
510 /// the following conditions are met within each reassociation group:
511 /// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
512 /// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
513 /// shape of a desired slice. A slice of shape S can be extracted as a
514 /// contiguous span of elements if and only if there exists an index k in {0, 1,
515 /// ..., n} such that:
516 /// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
517 /// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
518 /// one dimension),
519 /// S_i = A_i for all i > k (that is, all trailing dimensions are preserved
520 /// in full).
521 /// In other words, the slice shape S must be of the form:
522 /// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ]
523 ///
524 /// In case the size and/or offset extracted are dynamic then this is possible
525 /// only if there is single dimension in the reassociation group that has a size
526 /// not equal to 1.
527 /// In other words, the tensor shape must be of the form:
528 /// [ 1, 1, ..., 1, A, 1, ...,1 ]
529 /// Note - it might be possible to enable this pattern for more cases when the
530 /// size/offset are dynamic via performing an analysis of the possible values
531 /// that could be given to the size/offset.
532 ///
533 /// Example:
534 /// The transformation is possible because each reassociation group can be
535 /// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
536 /// [20->10]).
537 /// ```
538 /// BEFORE:
539 /// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
540 /// tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
541 /// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
542 /// tensor<128x7x20xf32> to tensor<32x?x10xf32>
543 ///
544 /// AFTER:
545 /// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
546 /// [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
547 /// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
548 /// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
549 /// ```
550 ///
551 /// Negative example:
552 /// The transformation is not possible because we cannot use a single slice to
553 /// represent the reassociation group [2x3x10->???]. If we would want the
554 /// collapse to be after the extraction, we would need to extract multiple
555 /// slices and concat them together.
556 /// ```
557 /// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
558 /// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
559 /// tensor<60xf32> to tensor<15xf32>
560 /// ```
561 /// If we would want the collapse to be after the extraction, a possible
562 /// alternate transformation could be to extract multiple slices and concat them
563 /// together:
564 /// ```
565 /// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
566 /// tensor<2x3x10xf32> to tensor <1x1x10xf32>
567 /// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
568 /// tensor<2x3x10xf32> to tensor <1x1x5xf32>
569 /// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
570 /// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
571 /// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
572 /// to tensor<15xf32>
573 /// ```
574 /// But this is not the intended purpose of the transformation.
575 struct BubbleUpCollapseShapeThroughExtractSlice
576  : public OpRewritePattern<tensor::ExtractSliceOp> {
// NOTE(review): Doxygen line 577 is missing from this dump (presumably the
// inherited-constructor `using` declaration) — verify against the original
// source.
578 
579  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
580  PatternRewriter &rewriter) const override {
581  auto collapseShapeOp =
582  sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
583  if (!collapseShapeOp) {
584  return rewriter.notifyMatchFailure(
585  sliceOp,
586  "tensor.extract_slice source not produced by tensor.collapse_shape");
587  }
588 
589  if (!sliceOp.hasUnitStride()) {
590  return rewriter.notifyMatchFailure(
591  sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
592  "be supported in this transformation.");
593  }
594 
595  // The tensor.extract_slice before applying the pattern works on the result
596  // of the tensor.collapse_shape, so variables (i.e. inputs for
597  // ExtractSliceOp) referring to the state before applying the pattern are
598  // named with the prefix "collapsed", and ones referring to the state after
599  // applying the pattern are named with the prefix "expanded".
600  SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
601  SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
602 
// A rank-reducing slice has fewer result dims than size operands.
603  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
604  collapsedSizes.size()) {
605  return rewriter.notifyMatchFailure(sliceOp,
606  "unimplemented: rank reducing slice");
607  }
608 
609  ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
610  SmallVector<ReassociationIndices, 4> reassociationIndices =
611  collapseShapeOp.getReassociationIndices();
612 
613  // Compute new offsets, sizes, and strides for tensor.extract_slice.
614  // The new tensor.extract_slice will work on a tensor that has a rank
615  // equal to the rank of the src of the collapse_shape. In each iteration of
616  // the loop, the offsets and sizes will be computed per reassociation group.
617  SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
618  SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
619  rewriter.getIndexAttr(1));
620 
621  for (auto [collapsedSize, collapsedOffset, reassocIndices] :
622  llvm::zip_equal(collapsedSizes, collapsedOffsets,
623  collapseShapeOp.getReassociationIndices())) {
624  // CASE #1 - size and/or offset are dynamic.
625  // In this case, the slice can be represented as a contiguous slice only
626  // if there is a single dimension in the reassociation group that has a
627  // size not equal to 1.
628  if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
629  int nonUnitSizeCount = 0;
630  for (int64_t expandedShapeIdx : reassocIndices) {
631  if (srcShape[expandedShapeIdx] != 1) {
632  nonUnitSizeCount++;
633  expandedSizes.push_back(collapsedSize);
634  expandedOffsets.push_back(collapsedOffset);
635  continue;
636  }
637 
// Unit source dims take the trivial slice [0, 1).
638  expandedSizes.push_back(rewriter.getIndexAttr(1));
639  expandedOffsets.push_back(rewriter.getIndexAttr(0));
640  }
641 
642  if (nonUnitSizeCount != 1) {
643  return rewriter.notifyMatchFailure(
644  sliceOp,
645  "unsupported: slice cannot be verified to be contiguous");
646  }
647  continue;
648  }
649 
650  // CASE #2 - size and offset are static.
651  // Verify that the slice can be represented as a contiguous slice of the
652  // src of the collapse_shape.
653  // Checking this is done on order of most internal dimensions first,
654  // so traversal is done in reverse order of the reassociation group.
655  // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
656  // ...,An] then we first find the size and offset for n...k+1 then for k
657  // and then for k-1...0.
658 
659  // currentCollapsedsize and currentCollapsedOffset are initialized with
660  // the original collapsed size and offset and divided by the expanded
661  // shape size in each dimension as we go along the reassociation group.
662  // In essence we are spreading the original collapsed size and offset over
663  // the various expanded slice dimensions.
664  // The variables are used both to check the validity of the slice and to
665  // compute the expanded sizes and offsets.
666  int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
667  int64_t currentCollapsedOffset =
668  getConstantIntValue(collapsedOffset).value();
669 
670  SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
671 
672  ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
673  reassocIndices.rend());
674  int64_t idx = 0;
675  int64_t reassocGroupSize = reassocIndices.size();
676 
677  // First handle the trailing dimensions where the slice size should be
678  // equal to the tensor shape and the offset should be 0 (n...k+1).
679  for (; idx < reassocGroupSize; ++idx) {
680  int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
681 
682  if (currentCollapsedsize < expandedShapeSize)
683  break;
684 
685  // We need to make sure that the slice size can be set to the shape size
686  // and the offset to 0.
687  if ((currentCollapsedsize % expandedShapeSize) != 0 ||
688  (currentCollapsedOffset % expandedShapeSize) != 0) {
689  return rewriter.notifyMatchFailure(
690  sliceOp, "unsupported: cannot be extracted as a contiguous slice "
691  "of the src of the collapse_shape");
692  }
693 
694  groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
695  groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
696 
697  currentCollapsedsize /= expandedShapeSize;
698  currentCollapsedOffset /= expandedShapeSize;
699  }
700 
701  // Now handle the first dim where slicing occurs on (k).
702  if (idx < reassocGroupSize) {
703  int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
704  int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
705  // We need to make sure that the slice size in this dim + offset will
706  // not exceed the shape size.
// NOTE(review): `>=` also rejects slices that exactly reach the end of
// the dimension (size + offset == shape size); presumably this is
// deliberately conservative — verify intent against the original source.
707  if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
708  return rewriter.notifyMatchFailure(
709  sliceOp, "unsupported: slice cannot be extracted as a contiguous "
710  "slice of the src of the collapse_shape");
711  }
712 
713  groupExpandedSizes.push_back(
714  rewriter.getIndexAttr(currentCollapsedsize));
715  groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
716 
717  currentCollapsedOffset /= expandedShapeSize;
718  }
719 
720  // Now handle the leading dimensions where the slice size is equal to 1
721  // (k-1...0).
722  // The size for these dimensions must be 1 because of how we constructed
723  // the slice size of the expanded shape. We spread the original collapsed
724  // size over the expanded shape sizes until we reached dimension k where
725  // the remaining size was smaller than the expanded shape size, and spread
726  // the remaining size on it. So, now we are left with only 1s.
727  for (idx++; idx < reassocGroupSize; ++idx) {
728  int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
729  int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
730  groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
731  groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
732  currentCollapsedOffset /= expandedShapeSize;
733  }
734 
// The group was built innermost-first; append in original (outer-first)
// order by reversing.
735  expandedSizes.append(groupExpandedSizes.rbegin(),
736  groupExpandedSizes.rend());
737  expandedOffsets.append(groupExpandedOffsets.rbegin(),
738  groupExpandedOffsets.rend());
739  }
740 
741  Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
742  collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
743  expandedSizes, expandedStrides);
744  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
745  sliceOp, sliceOp.getResultType(), newSliceOp,
746  collapseShapeOp.getReassociationIndices());
747 
748  return success();
749  }
750 };
751 
752 } // namespace
753 
756  patterns
757  .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
758  FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
759  FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
760  FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
761  FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
762  patterns.getContext());
763 }
764 
767  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
768 }
769 
772  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
773  BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
774 }
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
MLIRContext * getContext() const
Definition: Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replaces).
Definition: PatternMatch.h:500
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1217
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other ops.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor.collapse_shape ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
Definition: BuiltinTypes.h:340
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which ShapedType::isDynamic is true are replaced by dynamicValues.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions with static size 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314