MLIR  21.0.0git
ReshapePatterns.cpp
Go to the documentation of this file.
//===- ReshapePatterns.cpp - Patterns related to reshapes -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
18 
19 using namespace mlir;
20 using namespace mlir::tensor;
21 
22 namespace {
23 /// Fold expand_shape(extract_slice) ops that cancel itself out.
24 struct FoldExpandOfRankReducingExtract
25  : public OpRewritePattern<ExpandShapeOp> {
27 
28  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
29  PatternRewriter &rewriter) const override {
30  RankedTensorType resultType = expandShapeOp.getResultType();
31  auto extractSliceOp =
32  expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
33  if (!extractSliceOp)
34  return failure();
35  RankedTensorType srcType = extractSliceOp.getSourceType();
36 
37  // Only cases where the ExpandShapeOp can be folded away entirely are
38  // supported. Moreover, only simple cases where the resulting ExtractSliceOp
39  // has no rank-reduction anymore are supported at the moment.
40  RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
41  srcType, extractSliceOp.getStaticOffsets(),
42  extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
43  if (nonReducingExtractType != resultType)
44  return failure();
45 
46  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
47  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
48  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
49  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
50  expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
51  mixedStrides);
52  return success();
53  }
54 };
55 
56 /// Fold collapse_shape which only removes static dimensions of size `1`
57 /// into extract_slice.
58 struct FoldUnPaddingCollapseIntoExtract
59  : public OpRewritePattern<tensor::CollapseShapeOp> {
61 
62  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
63  PatternRewriter &rewriter) const override {
64  auto extractSliceOp =
65  collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
66  // Collapse cannot be folded away with multiple users of the extract slice
67  // and it is not necessarily beneficial to only convert the collapse into
68  // another extract slice.
69  if (!extractSliceOp || !extractSliceOp->hasOneUse())
70  return failure();
71 
72  // Only fold away simple collapse where all removed dimensions have static
73  // size `1`.
75  collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
77  return rewriter.notifyMatchFailure(collapseShapeOp,
78  "expected unpadding collapse");
79 
80  Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
81  extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
82  extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
83  extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
84  rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
85  return success();
86  }
87 };
88 
89 /// Fold insert_slice(collapse_shape) ops that cancel itself out.
90 template <typename OpTy>
91 struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
93 
94  LogicalResult matchAndRewrite(OpTy insertSliceOp,
95  PatternRewriter &rewriter) const override {
96  auto collapseShapeOp =
97  insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
98  if (!collapseShapeOp)
99  return failure();
100  RankedTensorType srcType = collapseShapeOp.getSrcType();
101 
102  // Only cases where the CollapseShapeOp can be folded away entirely are
103  // supported. Moreover, only simple cases where the resulting InsertSliceOp
104  // has no rank-reduction anymore are supported at the moment.
105  RankedTensorType nonReducingInsertType =
106  RankedTensorType::get(insertSliceOp.getStaticSizes(),
107  insertSliceOp.getDestType().getElementType());
108  if (nonReducingInsertType != srcType)
109  return failure();
110 
111  SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
112  SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
113  SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
114  rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
115  insertSliceOp.getDest(), mixedOffsets,
116  mixedSizes, mixedStrides);
117  return success();
118  }
119 };
120 
121 /// Fold expand_shape which only adds static dimensions of size `1`
122 /// into insert_slice.
123 template <typename OpTy>
124 struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
126 
127  LogicalResult matchAndRewrite(OpTy insertSliceOp,
128  PatternRewriter &rewriter) const override {
129  auto expandShapeOp = insertSliceOp.getSource()
130  .template getDefiningOp<tensor::ExpandShapeOp>();
131  if (!expandShapeOp)
132  return failure();
133 
134  // Only fold away simple expansion where all added dimensions have static
135  // size `1`.
137  expandShapeOp.getResultType(), expandShapeOp.getSrcType());
139  return rewriter.notifyMatchFailure(insertSliceOp,
140  "expected rank increasing expansion");
141 
142  rewriter.modifyOpInPlace(insertSliceOp, [&]() {
143  insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
144  });
145  return success();
146  }
147 };
148 
149 /// Pattern to bubble up a tensor.expand_shape op through a producer
150 /// tensor.collapse_shape op that has non intersecting reassociations.
151 struct BubbleUpExpandThroughParallelCollapse
152  : public OpRewritePattern<tensor::ExpandShapeOp> {
154 
155  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
156  PatternRewriter &rewriter) const override {
157  auto collapseOp =
158  expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
159  if (!collapseOp)
160  return failure();
161  auto expandReInds = expandOp.getReassociationIndices();
162  auto collapseReInds = collapseOp.getReassociationIndices();
163 
164  // Special case where the collapsed tensor to expand is a 0-D tensor,
165  // then the reassociation maps will be empty and not produce valid results.
166  if (expandReInds.size() == 0) {
167  return failure();
168  }
169 
170  // Reshapes are parallel to each other (by construction the number of
171  // reassociations specified in the collapse and expand are the same), if at
172  // any position
173  // 1. either the reassociation indices are of the same size, or
174  // 2. either the reassociation in the collapse or the expand is of size 1.
175  ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
176  ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
177  for (auto [expandReassociation, collapseReassociation] :
178  llvm::zip_equal(expandReInds, collapseReInds)) {
179  if (collapseReassociation.size() == expandReassociation.size()) {
180  // Even if the reassociations are the same, the collapse/expand should
181  // result in the same dimensions. i.e 4x8x2 into 64 should be expanded
182  // into 4x8x2 again. In presense of dynamic dimensions one can only
183  // verify "equality" when there is only one dynamic dimension present,
184  // and all other static dimensions are equal.
185  ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
186  collapseReassociation.front(), collapseReassociation.size());
187  int64_t numCollapsedDynamic =
188  llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
189  ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
190  expandReassociation.front(), expandReassociation.size());
191  int64_t numExpandedDynamic =
192  llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
193  if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
194  collapsedStaticShapes != expandedStaticShapes) {
195  return failure();
196  }
197  continue;
198  }
199  // If the reassociations are not same, one or the other needs to be of
200  // size one.
201  if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
202  return failure();
203  }
204 
205  // Compute new reassociation indices and expanded/collaped shapes.
206  SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
207  Location loc = expandOp->getLoc();
208  SmallVector<OpFoldResult> sourceSizes =
209  tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
210  SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
211  SmallVector<OpFoldResult> newExpandSizes;
212 
213  int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
214  resultSizeIndex = 0;
215 
216  for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
217  auto &collapseReassociation = collapseReInds[idx];
218  auto &expandReassociation = expandReInds[idx];
219 
220  // Case 1. The reassociations are same in the collapse producer
221  // and expand consumer. In the swapped expand, each of the final
222  // dimensions are kept as is in the expand and the collapse. So,
223  // for every element in the `ReassocationIndices` vector add a new
224  // `ReassociationIndices` vector for the swapped expand and collapse
225  // (of size 1).
226  if (collapseReassociation.size() == expandReassociation.size()) {
227  for (size_t i = 0; i < collapseReassociation.size(); ++i) {
228  newCollapseReInds.push_back({newCollapseIndex++});
229  newExpandReInds.push_back({newExpandIndex++});
230  newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
231  sourceSizeIndex++;
232  }
233  continue;
234  }
235 
236  // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
237  // in the expand is of size == 1). In this case, the original dimensions
238  // are preserved on expansion and collapsed subsequently.
239  if (collapseReassociation.size() != 1) {
240  ReassociationIndices newCollapseReassociation;
241  for (size_t i = 0; i < collapseReassociation.size(); ++i) {
242  newCollapseReassociation.push_back(newCollapseIndex++);
243  newExpandReInds.push_back({newExpandIndex++});
244  newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
245  }
246  resultSizeIndex++;
247  newCollapseReInds.push_back(newCollapseReassociation);
248  continue;
249  }
250 
251  // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
252  // in the collapse is of size == 1). In this case, the expansion happens
253  // first and the expanded dimensions are preserved on collapse.
254  ReassociationIndices newExpandReassociation;
255  for (size_t i = 0; i < expandReassociation.size(); ++i) {
256  newExpandReassociation.push_back(newExpandIndex++);
257  newCollapseReInds.push_back({newCollapseIndex++});
258  newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
259  }
260  newExpandReInds.push_back(newExpandReassociation);
261  sourceSizeIndex++;
262  }
263 
264  // Swap reshape order.
265  SmallVector<Value> dynamicSizes;
266  SmallVector<int64_t> staticSizes;
267  dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
268  auto expandResultType = expandOp.getResultType().clone(staticSizes);
269  Value newCollapseSrc = collapseOp.getSrc();
270  // If the number of reassociation indices in the new `expand_shape` op
271  // matches the number of dimensions of the result, then the expand_shape
272  // is a no-op.
273  if (newExpandReInds.size() != newExpandSizes.size()) {
274  newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
275  loc, expandResultType, newCollapseSrc, newExpandReInds,
276  newExpandSizes);
277  }
278 
279  // If the number of reassociation indices in the new `collapse_shape` op
280  // matches the number of dimensions of the source, then the collapse_shape
281  // is a no-op.
282  Value replacement = newCollapseSrc;
283  if (newCollapseReInds.size() != newExpandSizes.size()) {
284  replacement = rewriter.create<tensor::CollapseShapeOp>(
285  loc, newCollapseSrc, newCollapseReInds);
286  }
287  rewriter.replaceOp(expandOp, replacement);
288  return success();
289  }
290 };
291 
292 /// Converts `tensor.extract_slice(tensor.expand_shape)` to
293 /// `tensor.expand_shape(tensor.extract_slice)`.
294 ///
295 /// For this transformation to be possible, the slice must be fully contiguous
296 /// within each reassociation group of the expand_shape. A slice is defined as
297 /// fully contiguous within a reassociation group if after flattening the
298 /// reassociation group to a single 1D range, then the slice taken out of the
299 /// group could be defined as a single contiguous subrange within that range.
300 ///
301 /// Rank reducing slices are not supported.
302 ///
303 /// Example:
304 /// The transformation is possible because each reassociation group has a
305 /// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
306 /// ```
307 /// BEFORE:
308 /// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
309 /// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
310 /// %slice = tensor.extract_slice %reshape ...
311 /// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
312 ///
313 /// AFTER:
314 /// %slice = tensor.extract_slice %in ...
315 /// tensor<8x16x32xf32> to tensor<8x5x4xf32>
316 /// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
317 /// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
318 /// ```
319 ///
320 /// Note - this pattern could be extended to be a swap pattern between
321 /// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
322 /// implemented only as a bubble up pattern for `tensor.extract_slice`.
323 struct BubbleUpExpandShapeThroughExtractSlice
324  : public OpRewritePattern<tensor::ExtractSliceOp> {
326 
327  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
328  PatternRewriter &rewriter) const override {
329  auto expandShapeOp =
330  sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
331 
332  if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
333  rewriter)
334  .failed())
335  return failure();
336 
337  // The tensor.extract_slice before applying the pattern works on the result
338  // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
339  // referring to the state before applying the pattern are named with the
340  // prefix "expanded", and ones referring to the state after applying the
341  // pattern are named with the prefix "collapsed".
342  SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
343  SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
344  SmallVector<OpFoldResult> expandedShape =
345  getMixedValues(expandShapeOp.getStaticOutputShape(),
346  expandShapeOp.getOutputShape(), rewriter);
347 
348  // Helper variables and function for accumulating the size values.
349  Location loc = expandShapeOp->getLoc();
350  AffineExpr d0, d1, d2;
351  bindDims(rewriter.getContext(), d0, d1, d2);
352  // Multiply two integers.
353  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
354  auto mulMap = AffineMap::get(2, 0, {d0 * d1});
355  return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
356  {v1, v2});
357  };
358 
359  // Compute new offsets, sizes, and strides for tensor.extract_slice.
360  // The new tensor.extract_slice will work on a tensor that has has a rank of
361  // ReassociationIndices.size(). In the loop a single offset, size, and
362  // stride value is computed per reassociation group.
363  SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
364  collapsedStrides;
365  for (const ReassociationIndices &indices :
366  expandShapeOp.getReassociationIndices()) {
367  // collapsedSize will hold the size of the single dim that represents the
368  // reassociation group in the non expanded tensor.
369  OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
370  // The reassocGroupSizes and reassocGroupOffsets are used to create an
371  // affine.linearize_index op to linearize the single offset value required
372  // for this reassociation group.
373  SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
374 
375  for (long expandedDim : indices) {
376  // reassocGroupSizes and reassocGroupOffsets can be obtained directly
377  // from the expanded state, but the collapsed size requires calculation
378  // as it did not previously exist.
379  reassocGroupSizes.push_back(expandedShape[expandedDim]);
380  reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
381  collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
382  }
383 
384  SmallVector<Value> offsetVals =
385  llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
386  return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
387  });
388  OpFoldResult collapsedOffset =
389  rewriter
390  .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
391  reassocGroupSizes,
392  /*disjoint=*/true)
393  .getResult();
394  collapsedOffsets.push_back(collapsedOffset);
395  collapsedSizes.push_back(collapsedSize);
396 
397  // Only unit stride is supported.
398  collapsedStrides.push_back(rewriter.getIndexAttr(1));
399  }
400 
401  // The shape of the result can be obtained from the sizes passed in.
402  SmallVector<Value> dynDims;
403  SmallVector<int64_t> shape;
404  dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
405  RankedTensorType resultType = RankedTensorType::get(
406  shape, expandShapeOp.getResultType().getElementType());
407 
408  // Create a new ExtractSliceOp and ExpandShapeOp.
409  Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
410  loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
411  collapsedStrides);
412  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
413  sliceOp, resultType, newSliceOp,
414  expandShapeOp.getReassociationIndices(), expandedSizes);
415  return success();
416  }
417 
418  // Helper function to check if all the required conditions for the
419  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
420  // met.
421  LogicalResult
422  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
423  tensor::ExpandShapeOp expandShapeOp,
424  PatternRewriter &rewriter) const {
425 
426  if (!expandShapeOp) {
427  return rewriter.notifyMatchFailure(
428  sliceOp, "tensor.extract_slice source not produced by expand_shape");
429  }
430 
431  if (!sliceOp.hasUnitStride()) {
432  return rewriter.notifyMatchFailure(
433  sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
434  "be supported in this transformation.");
435  }
436 
437  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
438  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
439 
440  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
441  sizes.size()) {
442  return rewriter.notifyMatchFailure(sliceOp,
443  "unimplemented: rank reducing slice");
444  }
445 
446  SmallVector<OpFoldResult> outputShape =
447  getMixedValues(expandShapeOp.getStaticOutputShape(),
448  expandShapeOp.getOutputShape(), rewriter);
449 
450  std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
451  isZeroOffsetAndFullSize =
452  [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
453  if (!isZeroInteger(offset))
454  return false;
455  FailureOr<bool> maybeEqual =
456  ValueBoundsConstraintSet::areEqual(sliceSize, size);
457  return llvm::succeeded(maybeEqual) && maybeEqual.value();
458  };
459 
460  // Check that the slice is contiguous within each reassociation group.
461  // The slice is contiguous only if after the first dimension where a non
462  // unit slice is taken, the slice size on all subsequent dimensions of the
463  // group is equal to the entire size of the dimension.
464  // Examples of contiguous slices:
465  // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
466  // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
467  // Examples of non contiguous slices:
468  // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
469  // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
470  for (const ReassociationIndices &indices :
471  expandShapeOp.getReassociationIndices()) {
472  int64_t i = 0;
473  int64_t e = indices.size();
474  // Find the first expanded dim after the first dim with non-unit extracted
475  // size.
476  for (; i < e; ++i) {
477  if (!isOneInteger(sizes[indices[i]])) {
478  // +1 to skip the first non-unit size dim.
479  i++;
480  break;
481  }
482  }
483 
484  // Verify that all subsequent dimensions extract the full size of the
485  // source tensor.
486  for (; i < e; ++i) {
487  int64_t expandedDim = indices[i];
488  if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
489  outputShape[expandedDim])) {
490  return rewriter.notifyMatchFailure(
491  sliceOp, "Not a contiguous slice of the expanded tensor.");
492  }
493  }
494  }
495 
496  return success();
497  }
498 };
499 
500 /// Converts `tensor.extract_slice(tensor.collapse_shape)` to
501 /// `tensor.collapse_shape(tensor.extract_slice)`.
502 ///
503 /// For this transformation to be possible - after bubbling up, the extraction
504 /// of the contiguous slice must be representable as a single slice obtained via
505 /// tensor.extract_slice within each reassociation group of the src.
506 ///
507 /// In case the size and offset extracted are static then this is possible if
508 /// the following conditions are met within each reassociation group:
509 /// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
510 /// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
511 /// shape of a desired slice. A slice of shape S can be extracted as a
512 /// contiguous span of elements if and only if there exists an index k in {0, 1,
513 /// ..., n} such that:
514 /// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
515 /// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
516 /// one dimension),
517 /// S_i = A_i for all i > k (that is, all trailing dimensions are preserved
518 /// in full).
519 /// In other words, the slice shape S must be of the form:
520 /// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ]
521 ///
522 /// In case the size and/or offset extracted are dynamic then this is possible
523 /// only if there is single dimension in the reassociation group that has a size
524 /// not equal to 1.
525 /// In other words, the tensor shape must be of the form:
526 /// [ 1, 1, ..., 1, A, 1, ...,1 ]
527 /// Note - it might be possible to enable this pattern for more cases when the
528 /// size/offset are dynamic via performing an analysis of the possible values
529 /// that could be given to the size/offset.
530 ///
531 /// Example:
532 /// The transformation is possible because each reassociation group can be
533 /// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
534 /// [20->10]).
535 /// ```
536 /// BEFORE:
537 /// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
538 /// tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
539 /// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1]
540 /// tensor<128x7x20xf32> to tensor<32x?x10xf32>
541 ///
542 /// AFTER:
543 /// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
544 /// [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
545 /// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
546 /// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
547 /// ```
548 ///
549 /// Negative example:
550 /// The transformation is not possible because we cannot use a single slice to
551 /// represent the reassociation group [2x3x10->???]. If we would want the
552 /// collapse to be after the extraction, we would need to extract multiple
553 /// slices and concat them together.
554 /// ```
555 /// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
556 /// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
557 /// tensor<60xf32> to tensor<15xf32>
558 /// ```
559 /// If we would want the collapse to be after the extraction, a possible
560 /// alternate transformation could be to extract multiple slices and concat them
561 /// together:
562 /// ```
563 /// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
564 /// tensor<2x3x10xf32> to tensor <1x1x10xf32>
565 /// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
566 /// tensor<2x3x10xf32> to tensor <1x1x5xf32>
567 /// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
568 /// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
569 /// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
570 /// to tensor<15xf32>
571 /// ```
572 /// But this is not the intended purpose of the transformation.
573 struct BubbleUpCollapseShapeThroughExtractSlice
574  : public OpRewritePattern<tensor::ExtractSliceOp> {
576 
577  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
578  PatternRewriter &rewriter) const override {
579  auto collapseShapeOp =
580  sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
581  if (!collapseShapeOp) {
582  return rewriter.notifyMatchFailure(
583  sliceOp,
584  "tensor.extract_slice source not produced by tensor.collapse_shape");
585  }
586 
587  if (!sliceOp.hasUnitStride()) {
588  return rewriter.notifyMatchFailure(
589  sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
590  "be supported in this transformation.");
591  }
592 
593  // The tensor.extract_slice before applying the pattern works on the result
594  // of the tensor.collapse_shape, so variables (i.e. inputs for
595  // ExtractSliceOp) referring to the state before applying the pattern are
596  // named with the prefix "collapsed", and ones referring to the state after
597  // applying the pattern are named with the prefix "expanded".
598  SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
599  SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
600 
601  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
602  collapsedSizes.size()) {
603  return rewriter.notifyMatchFailure(sliceOp,
604  "unimplemented: rank reducing slice");
605  }
606 
607  ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
608  SmallVector<ReassociationIndices, 4> reassociationIndices =
609  collapseShapeOp.getReassociationIndices();
610 
611  // Compute new offsets, sizes, and strides for tensor.extract_slice.
612  // The new tensor.extract_slice will work on a tensor that has has a rank
613  // equal to the rank of the src of the collapse_shape. In each iteration of
614  // the loop, the offsets and sizes will be computed per reassociation group.
615  SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
616  SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
617  rewriter.getIndexAttr(1));
618 
619  for (auto [collapsedSize, collapsedOffset, reassocIndices] :
620  llvm::zip_equal(collapsedSizes, collapsedOffsets,
621  collapseShapeOp.getReassociationIndices())) {
622  // CASE #1 - size and/or offset are dynamic.
623  // In this case, the slice can be represented as a contiguous slice only
624  // if there is a single dimension in the reassociation group that has a
625  // size not equal to 1.
626  if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
627  int nonUnitSizeCount = 0;
628  for (int64_t expandedShapeIdx : reassocIndices) {
629  if (srcShape[expandedShapeIdx] != 1) {
630  nonUnitSizeCount++;
631  expandedSizes.push_back(collapsedSize);
632  expandedOffsets.push_back(collapsedOffset);
633  continue;
634  }
635 
636  expandedSizes.push_back(rewriter.getIndexAttr(1));
637  expandedOffsets.push_back(rewriter.getIndexAttr(0));
638  }
639 
640  if (nonUnitSizeCount != 1) {
641  return rewriter.notifyMatchFailure(
642  sliceOp,
643  "unsupported: slice cannot be verified to be contiguous");
644  }
645  continue;
646  }
647 
648  // CASE #2 = size and offset are static.
649  // Verify that the slice can be represented as a contiguous slice of the
650  // src of the collapse_shape.
651  // Checking this is done on order of most internal dimensions first,
652  // so traversal is done in reverse order of the reassociation group.
653  // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
654  // ...,An] then we first find the size and offset for n...k+1 then for k
655  // and then for k-1...0.
656 
657  // currentCollapsedsize and currentCollapsedOffset are initialized with
658  // the original collapsed size and offset and divided by the expanded
659  // shape size in each dimension as we go along the reassociation group.
660  // In essence we are spreading the original collapsed size and offset over
661  // the various expanded slice dimensions.
662  // The variables are used both to check the validity of the slice and to
663  // compute the expanded sizes and offsets.
664  int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
665  int64_t currentCollapsedOffset =
666  getConstantIntValue(collapsedOffset).value();
667 
668  SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
669 
670  ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
671  reassocIndices.rend());
672  int64_t idx = 0;
673  int64_t reassocGroupSize = reassocIndices.size();
674 
675  // First handle the trailing dimensions where the slice size should be
676  // equal to the tensor shape and the offset should be 0 (n...k+1).
677  for (; idx < reassocGroupSize; ++idx) {
678  int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
679 
680  if (currentCollapsedsize < expandedShapeSize)
681  break;
682 
683  // We need to make sure that the slice size can be set to the shape size
684  // and the offset to 0.
685  if ((currentCollapsedsize % expandedShapeSize) != 0 ||
686  (currentCollapsedOffset % expandedShapeSize) != 0) {
687  return rewriter.notifyMatchFailure(
688  sliceOp, "unsupported: cannot be extracted as a contiguous slice "
689  "of the src of the collapse_shape");
690  }
691 
692  groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
693  groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
694 
695  currentCollapsedsize /= expandedShapeSize;
696  currentCollapsedOffset /= expandedShapeSize;
697  }
698 
699  // Now handle the first dim where slicing occurs on (k).
700  if (idx < reassocGroupSize) {
701  int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
702  int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
703  // We need to make sure that the slice size in this dim + offset will
704  // not exceed the shape size.
705  if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
706  return rewriter.notifyMatchFailure(
707  sliceOp, "unsupported: slice cannot be extracted as a contiguous "
708  "slice of the src of the collapse_shape");
709  }
710 
711  groupExpandedSizes.push_back(
712  rewriter.getIndexAttr(currentCollapsedsize));
713  groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
714 
715  currentCollapsedOffset /= expandedShapeSize;
716  }
717 
718  // Now handle the leading dimensions where the slice size is equal to 1
719  // (k-1...0).
720  // The size for these dimensions must be 1 because of how we constructed
721  // the slice size of the expanded shape. We spread the original collapsed
722  // size over the expanded shape sizes until we reached dimension k where
723  // the remaining size was smaller than the expanded shape size, and spread
724  // the remaining size on it. So, now we are left with only 1s.
725  for (idx++; idx < reassocGroupSize; ++idx) {
726  int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
727  int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
728  groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
729  groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
730  currentCollapsedOffset /= expandedShapeSize;
731  }
732 
733  expandedSizes.append(groupExpandedSizes.rbegin(),
734  groupExpandedSizes.rend());
735  expandedOffsets.append(groupExpandedOffsets.rbegin(),
736  groupExpandedOffsets.rend());
737  }
738 
739  Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
740  collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
741  expandedSizes, expandedStrides);
742  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
743  sliceOp, sliceOp.getResultType(), newSliceOp,
744  collapseShapeOp.getReassociationIndices());
745 
746  return success();
747  }
748 };
749 
750 } // namespace
751 
754  patterns
755  .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
756  FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
757  FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
758  FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
759  FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
760  patterns.getContext());
761 }
762 
765  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
766 }
767 
770  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
771  BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
772 }
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1225
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:73
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:340
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314