MLIR  22.0.0git
ReshapePatterns.cpp
Go to the documentation of this file.
1 //===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"
17 
18 using namespace mlir;
19 using namespace mlir::tensor;
20 
21 namespace {
22 /// Fold expand_shape(extract_slice) ops that cancel itself out.
23 struct FoldExpandOfRankReducingExtract
24  : public OpRewritePattern<ExpandShapeOp> {
26 
27  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
28  PatternRewriter &rewriter) const override {
29  RankedTensorType resultType = expandShapeOp.getResultType();
30  auto extractSliceOp =
31  expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
32  if (!extractSliceOp)
33  return failure();
34  RankedTensorType srcType = extractSliceOp.getSourceType();
35 
36  // Only cases where the ExpandShapeOp can be folded away entirely are
37  // supported. Moreover, only simple cases where the resulting ExtractSliceOp
38  // has no rank-reduction anymore are supported at the moment.
39  RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
40  srcType, extractSliceOp.getStaticOffsets(),
41  extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
42  if (nonReducingExtractType != resultType)
43  return failure();
44 
45  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
46  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
47  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
48  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
49  expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
50  mixedStrides);
51  return success();
52  }
53 };
54 
55 /// Fold collapse_shape which only removes static dimensions of size `1`
56 /// into extract_slice.
57 struct FoldUnPaddingCollapseIntoExtract
58  : public OpRewritePattern<tensor::CollapseShapeOp> {
60 
61  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
62  PatternRewriter &rewriter) const override {
63  auto extractSliceOp =
64  collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
65  // Collapse cannot be folded away with multiple users of the extract slice
66  // and it is not necessarily beneficial to only convert the collapse into
67  // another extract slice.
68  if (!extractSliceOp || !extractSliceOp->hasOneUse())
69  return failure();
70 
71  // Only fold away simple collapse where all removed dimensions have static
72  // size `1`.
74  collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
76  return rewriter.notifyMatchFailure(collapseShapeOp,
77  "expected unpadding collapse");
78 
79  Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
80  rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
81  extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
82  extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
83  rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
84  return success();
85  }
86 };
87 
88 /// Fold insert_slice(collapse_shape) ops that cancel itself out.
89 template <typename OpTy>
90 struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
92 
93  LogicalResult matchAndRewrite(OpTy insertSliceOp,
94  PatternRewriter &rewriter) const override {
95  auto collapseShapeOp =
96  insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
97  if (!collapseShapeOp)
98  return failure();
99  RankedTensorType srcType = collapseShapeOp.getSrcType();
100 
101  // Only cases where the CollapseShapeOp can be folded away entirely are
102  // supported. Moreover, only simple cases where the resulting InsertSliceOp
103  // has no rank-reduction anymore are supported at the moment.
104  RankedTensorType nonReducingInsertType =
105  RankedTensorType::get(insertSliceOp.getStaticSizes(),
106  insertSliceOp.getDestType().getElementType());
107  if (nonReducingInsertType != srcType)
108  return failure();
109 
110  SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
111  SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
112  SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
113  rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
114  insertSliceOp.getDest(), mixedOffsets,
115  mixedSizes, mixedStrides);
116  return success();
117  }
118 };
119 
120 /// Fold expand_shape which only adds static dimensions of size `1`
121 /// into insert_slice.
122 template <typename OpTy>
123 struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
125 
126  LogicalResult matchAndRewrite(OpTy insertSliceOp,
127  PatternRewriter &rewriter) const override {
128  auto expandShapeOp = insertSliceOp.getSource()
129  .template getDefiningOp<tensor::ExpandShapeOp>();
130  if (!expandShapeOp)
131  return failure();
132 
133  // Only fold away simple expansion where all added dimensions have static
134  // size `1`.
136  expandShapeOp.getResultType(), expandShapeOp.getSrcType());
138  return rewriter.notifyMatchFailure(insertSliceOp,
139  "expected rank increasing expansion");
140 
141  rewriter.modifyOpInPlace(insertSliceOp, [&]() {
142  insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
143  });
144  return success();
145  }
146 };
147 
148 /// Pattern to bubble up a tensor.expand_shape op through a producer
149 /// tensor.collapse_shape op that has non intersecting reassociations.
150 struct BubbleUpExpandThroughParallelCollapse
151  : public OpRewritePattern<tensor::ExpandShapeOp> {
153 
154  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
155  PatternRewriter &rewriter) const override {
156  auto collapseOp =
157  expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
158  if (!collapseOp)
159  return failure();
160  auto expandReInds = expandOp.getReassociationIndices();
161  auto collapseReInds = collapseOp.getReassociationIndices();
162 
163  // Special case where the collapsed tensor to expand is a 0-D tensor,
164  // then the reassociation maps will be empty and not produce valid results.
165  if (expandReInds.size() == 0) {
166  return failure();
167  }
168 
169  // Reshapes are parallel to each other (by construction the number of
170  // reassociations specified in the collapse and expand are the same), if at
171  // any position
172  // 1. either the reassociation indices are of the same size, or
173  // 2. either the reassociation in the collapse or the expand is of size 1.
174  ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
175  ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
176  for (auto [expandReassociation, collapseReassociation] :
177  llvm::zip_equal(expandReInds, collapseReInds)) {
178  if (collapseReassociation.size() == expandReassociation.size()) {
179  // Even if the reassociations are the same, the collapse/expand should
180  // result in the same dimensions. i.e 4x8x2 into 64 should be expanded
181  // into 4x8x2 again. In presense of dynamic dimensions one can only
182  // verify "equality" when there is only one dynamic dimension present,
183  // and all other static dimensions are equal.
184  ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
185  collapseReassociation.front(), collapseReassociation.size());
186  int64_t numCollapsedDynamic =
187  llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
188  ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
189  expandReassociation.front(), expandReassociation.size());
190  int64_t numExpandedDynamic =
191  llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
192  if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
193  collapsedStaticShapes != expandedStaticShapes) {
194  return failure();
195  }
196  continue;
197  }
198  // If the reassociations are not same, one or the other needs to be of
199  // size one.
200  if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
201  return failure();
202  }
203 
204  // Compute new reassociation indices and expanded/collaped shapes.
205  SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
206  Location loc = expandOp->getLoc();
207  SmallVector<OpFoldResult> sourceSizes =
208  tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
209  SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
210  SmallVector<OpFoldResult> newExpandSizes;
211 
212  int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
213  resultSizeIndex = 0;
214 
215  for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
216  auto &collapseReassociation = collapseReInds[idx];
217  auto &expandReassociation = expandReInds[idx];
218 
219  // Case 1. The reassociations are same in the collapse producer
220  // and expand consumer. In the swapped expand, each of the final
221  // dimensions are kept as is in the expand and the collapse. So,
222  // for every element in the `ReassocationIndices` vector add a new
223  // `ReassociationIndices` vector for the swapped expand and collapse
224  // (of size 1).
225  if (collapseReassociation.size() == expandReassociation.size()) {
226  for (size_t i = 0; i < collapseReassociation.size(); ++i) {
227  newCollapseReInds.push_back({newCollapseIndex++});
228  newExpandReInds.push_back({newExpandIndex++});
229  newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
230  sourceSizeIndex++;
231  }
232  continue;
233  }
234 
235  // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
236  // in the expand is of size == 1). In this case, the original dimensions
237  // are preserved on expansion and collapsed subsequently.
238  if (collapseReassociation.size() != 1) {
239  ReassociationIndices newCollapseReassociation;
240  for (size_t i = 0; i < collapseReassociation.size(); ++i) {
241  newCollapseReassociation.push_back(newCollapseIndex++);
242  newExpandReInds.push_back({newExpandIndex++});
243  newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
244  }
245  resultSizeIndex++;
246  newCollapseReInds.push_back(newCollapseReassociation);
247  continue;
248  }
249 
250  // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
251  // in the collapse is of size == 1). In this case, the expansion happens
252  // first and the expanded dimensions are preserved on collapse.
253  ReassociationIndices newExpandReassociation;
254  for (size_t i = 0; i < expandReassociation.size(); ++i) {
255  newExpandReassociation.push_back(newExpandIndex++);
256  newCollapseReInds.push_back({newCollapseIndex++});
257  newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
258  }
259  newExpandReInds.push_back(newExpandReassociation);
260  sourceSizeIndex++;
261  }
262 
263  // Swap reshape order.
264  SmallVector<Value> dynamicSizes;
265  SmallVector<int64_t> staticSizes;
266  dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
267  auto expandResultType = expandOp.getResultType().clone(staticSizes);
268  Value newCollapseSrc = collapseOp.getSrc();
269  // If the number of reassociation indices in the new `expand_shape` op
270  // matches the number of dimensions of the result, then the expand_shape
271  // is a no-op.
272  if (newExpandReInds.size() != newExpandSizes.size()) {
273  newCollapseSrc = tensor::ExpandShapeOp::create(
274  rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
275  newExpandSizes);
276  }
277 
278  // If the number of reassociation indices in the new `collapse_shape` op
279  // matches the number of dimensions of the source, then the collapse_shape
280  // is a no-op.
281  Value replacement = newCollapseSrc;
282  if (newCollapseReInds.size() != newExpandSizes.size()) {
283  replacement = tensor::CollapseShapeOp::create(
284  rewriter, loc, newCollapseSrc, newCollapseReInds);
285  }
286  rewriter.replaceOp(expandOp, replacement);
287  return success();
288  }
289 };
290 
291 /// Converts `tensor.extract_slice(tensor.expand_shape)` to
292 /// `tensor.expand_shape(tensor.extract_slice)`.
293 ///
294 /// For this transformation to be possible, the slice must be fully contiguous
295 /// within each reassociation group of the expand_shape. A slice is defined as
296 /// fully contiguous within a reassociation group if after flattening the
297 /// reassociation group to a single 1D range, then the slice taken out of the
298 /// group could be defined as a single contiguous subrange within that range.
299 ///
300 /// Rank reducing slices are not supported.
301 ///
302 /// Example:
303 /// The transformation is possible because each reassociation group has a
304 /// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
305 /// ```
306 /// BEFORE:
307 /// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
308 /// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
309 /// %slice = tensor.extract_slice %reshape ...
310 /// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
311 ///
312 /// AFTER:
313 /// %slice = tensor.extract_slice %in ...
314 /// tensor<8x16x32xf32> to tensor<8x5x4xf32>
315 /// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
316 /// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
317 /// ```
318 ///
319 /// Note - this pattern could be extended to be a swap pattern between
320 /// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
321 /// implemented only as a bubble up pattern for `tensor.extract_slice`.
322 struct BubbleUpExpandShapeThroughExtractSlice
323  : public OpRewritePattern<tensor::ExtractSliceOp> {
325 
326  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
327  PatternRewriter &rewriter) const override {
328  auto expandShapeOp =
329  sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
330 
331  if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
332  rewriter)
333  .failed())
334  return failure();
335 
336  // The tensor.extract_slice before applying the pattern works on the result
337  // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
338  // referring to the state before applying the pattern are named with the
339  // prefix "expanded", and ones referring to the state after applying the
340  // pattern are named with the prefix "collapsed".
341  SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
342  SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
343  SmallVector<OpFoldResult> expandedShape =
344  getMixedValues(expandShapeOp.getStaticOutputShape(),
345  expandShapeOp.getOutputShape(), rewriter);
346 
347  // Helper variables and function for accumulating the size values.
348  Location loc = expandShapeOp->getLoc();
349  AffineExpr d0, d1, d2;
350  bindDims(rewriter.getContext(), d0, d1, d2);
351  // Multiply two integers.
352  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
353  auto mulMap = AffineMap::get(2, 0, {d0 * d1});
354  return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
355  {v1, v2});
356  };
357 
358  // Compute new offsets, sizes, and strides for tensor.extract_slice.
359  // The new tensor.extract_slice will work on a tensor that has has a rank of
360  // ReassociationIndices.size(). In the loop a single offset, size, and
361  // stride value is computed per reassociation group.
362  SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
363  collapsedStrides;
364  for (const ReassociationIndices &indices :
365  expandShapeOp.getReassociationIndices()) {
366  // collapsedSize will hold the size of the single dim that represents the
367  // reassociation group in the non expanded tensor.
368  OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
369  // The reassocGroupSizes and reassocGroupOffsets are used to create an
370  // affine.linearize_index op to linearize the single offset value required
371  // for this reassociation group.
372  SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
373 
374  for (long expandedDim : indices) {
375  // reassocGroupSizes and reassocGroupOffsets can be obtained directly
376  // from the expanded state, but the collapsed size requires calculation
377  // as it did not previously exist.
378  reassocGroupSizes.push_back(expandedShape[expandedDim]);
379  reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
380  collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
381  }
382 
383  SmallVector<Value> offsetVals =
384  llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
385  return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
386  });
387  OpFoldResult collapsedOffset =
388  rewriter
389  .create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
390  reassocGroupSizes,
391  /*disjoint=*/true)
392  .getResult();
393  collapsedOffsets.push_back(collapsedOffset);
394  collapsedSizes.push_back(collapsedSize);
395 
396  // Only unit stride is supported.
397  collapsedStrides.push_back(rewriter.getIndexAttr(1));
398  }
399 
400  // The shape of the result can be obtained from the sizes passed in.
401  SmallVector<Value> dynDims;
402  SmallVector<int64_t> shape;
403  dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
404  RankedTensorType resultType = RankedTensorType::get(
405  shape, expandShapeOp.getResultType().getElementType());
406 
407  // Create a new ExtractSliceOp and ExpandShapeOp.
408  Value newSliceOp = tensor::ExtractSliceOp::create(
409  rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
410  collapsedStrides);
411  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
412  sliceOp, resultType, newSliceOp,
413  expandShapeOp.getReassociationIndices(), expandedSizes);
414  return success();
415  }
416 
417  // Helper function to check if all the required conditions for the
418  // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
419  // met.
420  LogicalResult
421  checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
422  tensor::ExpandShapeOp expandShapeOp,
423  PatternRewriter &rewriter) const {
424 
425  if (!expandShapeOp) {
426  return rewriter.notifyMatchFailure(
427  sliceOp, "tensor.extract_slice source not produced by expand_shape");
428  }
429 
430  if (!sliceOp.hasUnitStride()) {
431  return rewriter.notifyMatchFailure(
432  sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
433  "be supported in this transformation.");
434  }
435 
436  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
437  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
438 
439  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
440  sizes.size()) {
441  return rewriter.notifyMatchFailure(sliceOp,
442  "unimplemented: rank reducing slice");
443  }
444 
445  SmallVector<OpFoldResult> outputShape =
446  getMixedValues(expandShapeOp.getStaticOutputShape(),
447  expandShapeOp.getOutputShape(), rewriter);
448 
449  std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
450  isZeroOffsetAndFullSize =
451  [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
452  if (!isZeroInteger(offset))
453  return false;
454  FailureOr<bool> maybeEqual =
455  ValueBoundsConstraintSet::areEqual(sliceSize, size);
456  return llvm::succeeded(maybeEqual) && maybeEqual.value();
457  };
458 
459  // Check that the slice is contiguous within each reassociation group.
460  // The slice is contiguous only if after the first dimension where a non
461  // unit slice is taken, the slice size on all subsequent dimensions of the
462  // group is equal to the entire size of the dimension.
463  // Examples of contiguous slices:
464  // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
465  // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
466  // Examples of non contiguous slices:
467  // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
468  // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
469  for (const ReassociationIndices &indices :
470  expandShapeOp.getReassociationIndices()) {
471  int64_t i = 0;
472  int64_t e = indices.size();
473  // Find the first expanded dim after the first dim with non-unit extracted
474  // size.
475  for (; i < e; ++i) {
476  if (!isOneInteger(sizes[indices[i]])) {
477  // +1 to skip the first non-unit size dim.
478  i++;
479  break;
480  }
481  }
482 
483  // Verify that all subsequent dimensions extract the full size of the
484  // source tensor.
485  for (; i < e; ++i) {
486  int64_t expandedDim = indices[i];
487  if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
488  outputShape[expandedDim])) {
489  return rewriter.notifyMatchFailure(
490  sliceOp, "Not a contiguous slice of the expanded tensor.");
491  }
492  }
493  }
494 
495  return success();
496  }
497 };
498 
499 /// Converts `tensor.extract_slice(tensor.collapse_shape)` to
500 /// `tensor.collapse_shape(tensor.extract_slice)`.
501 ///
502 /// For this transformation to be possible - after bubbling up, the extraction
503 /// of the contiguous slice must be representable as a single slice obtained via
504 /// tensor.extract_slice within each reassociation group of the src.
505 ///
506 /// In case the size and offset extracted are static then this is possible if
507 /// the following conditions are met within each reassociation group:
508 /// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
509 /// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
510 /// shape of a desired slice. A slice of shape S can be extracted as a
511 /// contiguous span of elements if and only if there exists an index k in {0, 1,
512 /// ..., n} such that:
513 /// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
514 /// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
515 /// one dimension),
516 /// S_i = A_i for all i > k (that is, all trailing dimensions are preserved
517 /// in full).
518 /// In other words, the slice shape S must be of the form:
519 /// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ]
520 ///
521 /// In case the size and/or offset extracted are dynamic then this is possible
522 /// only if there is single dimension in the reassociation group that has a size
523 /// not equal to 1.
524 /// In other words, the tensor shape must be of the form:
525 /// [ 1, 1, ..., 1, A, 1, ...,1 ]
526 /// Note - it might be possible to enable this pattern for more cases when the
527 /// size/offset are dynamic via performing an analysis of the possible values
528 /// that could be given to the size/offset.
529 ///
530 /// Example:
531 /// The transformation is possible because each reassociation group can be
532 /// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
533 /// [20->10]).
534 /// ```
535 /// BEFORE:
536 /// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
537 /// tensor<8x16x1x7x20f32> to tensor<128x7x20xf32>
538 /// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1]
539 /// tensor<128x7x20xf32> to tensor<32x?x10xf32>
540 ///
541 /// AFTER:
542 /// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
543 // [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32>
544 /// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
545 /// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
546 /// ```
547 ///
548 /// Negative example:
549 /// The transformation is not possible because we cannot use a single slice to
550 /// represent the reassociation group [2x3x10->???]. If we would want the
551 /// collapse to be after the extraction, we would need to extract multiple
552 /// slices and concat them together.
553 /// ```
554 /// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
555 /// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
556 /// tensor<60xf32> to tensor<15xf32>
557 /// ```
558 /// If we would want the collapse to be after the extraction, a possible
559 /// alternate transformation could be to extract multiple slices and concat them
560 /// together:
561 /// ```
562 /// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
563 /// tensor<2x3x10xf32> to tensor <1x1x10xf32>
564 /// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
565 /// tensor<2x3x10xf32> to tensor <1x1x5xf32>
566 /// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
567 /// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
568 /// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
569 /// to tensor<15xf32>
570 /// ```
571 /// But this is not the intended purpose of the transformation.
572 struct BubbleUpCollapseShapeThroughExtractSlice
573  : public OpRewritePattern<tensor::ExtractSliceOp> {
575 
576  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
577  PatternRewriter &rewriter) const override {
578  auto collapseShapeOp =
579  sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
580  if (!collapseShapeOp) {
581  return rewriter.notifyMatchFailure(
582  sliceOp,
583  "tensor.extract_slice source not produced by tensor.collapse_shape");
584  }
585 
586  if (!sliceOp.hasUnitStride()) {
587  return rewriter.notifyMatchFailure(
588  sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
589  "be supported in this transformation.");
590  }
591 
592  // The tensor.extract_slice before applying the pattern works on the result
593  // of the tensor.collapse_shape, so variables (i.e. inputs for
594  // ExtractSliceOp) referring to the state before applying the pattern are
595  // named with the prefix "collapsed", and ones referring to the state after
596  // applying the pattern are named with the prefix "expanded".
597  SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
598  SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
599 
600  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
601  collapsedSizes.size()) {
602  return rewriter.notifyMatchFailure(sliceOp,
603  "unimplemented: rank reducing slice");
604  }
605 
606  ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
607  SmallVector<ReassociationIndices, 4> reassociationIndices =
608  collapseShapeOp.getReassociationIndices();
609 
610  // Compute new offsets, sizes, and strides for tensor.extract_slice.
611  // The new tensor.extract_slice will work on a tensor that has has a rank
612  // equal to the rank of the src of the collapse_shape. In each iteration of
613  // the loop, the offsets and sizes will be computed per reassociation group.
614  SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
615  SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
616  rewriter.getIndexAttr(1));
617 
618  for (auto [collapsedSize, collapsedOffset, reassocIndices] :
619  llvm::zip_equal(collapsedSizes, collapsedOffsets,
620  collapseShapeOp.getReassociationIndices())) {
621  // CASE #1 - size and/or offset are dynamic.
622  // In this case, the slice can be represented as a contiguous slice only
623  // if there is a single dimension in the reassociation group that has a
624  // size not equal to 1.
625  if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
626  int nonUnitSizeCount = 0;
627  for (int64_t expandedShapeIdx : reassocIndices) {
628  if (srcShape[expandedShapeIdx] != 1) {
629  nonUnitSizeCount++;
630  expandedSizes.push_back(collapsedSize);
631  expandedOffsets.push_back(collapsedOffset);
632  continue;
633  }
634 
635  expandedSizes.push_back(rewriter.getIndexAttr(1));
636  expandedOffsets.push_back(rewriter.getIndexAttr(0));
637  }
638 
639  if (nonUnitSizeCount != 1) {
640  return rewriter.notifyMatchFailure(
641  sliceOp,
642  "unsupported: slice cannot be verified to be contiguous");
643  }
644  continue;
645  }
646 
647  // CASE #2 = size and offset are static.
648  // Verify that the slice can be represented as a contiguous slice of the
649  // src of the collapse_shape.
650  // Checking this is done on order of most internal dimensions first,
651  // so traversal is done in reverse order of the reassociation group.
652  // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
653  // ...,An] then we first find the size and offset for n...k+1 then for k
654  // and then for k-1...0.
655 
656  // currentCollapsedsize and currentCollapsedOffset are initialized with
657  // the original collapsed size and offset and divided by the expanded
658  // shape size in each dimension as we go along the reassociation group.
659  // In essence we are spreading the original collapsed size and offset over
660  // the various expanded slice dimensions.
661  // The variables are used both to check the validity of the slice and to
662  // compute the expanded sizes and offsets.
663  int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
664  int64_t currentCollapsedOffset =
665  getConstantIntValue(collapsedOffset).value();
666 
667  SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
668 
669  ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
670  reassocIndices.rend());
671  int64_t idx = 0;
672  int64_t reassocGroupSize = reassocIndices.size();
673 
674  // First handle the trailing dimensions where the slice size should be
675  // equal to the tensor shape and the offset should be 0 (n...k+1).
676  for (; idx < reassocGroupSize; ++idx) {
677  int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
678 
679  if (currentCollapsedsize < expandedShapeSize)
680  break;
681 
682  // We need to make sure that the slice size can be set to the shape size
683  // and the offset to 0.
684  if ((currentCollapsedsize % expandedShapeSize) != 0 ||
685  (currentCollapsedOffset % expandedShapeSize) != 0) {
686  return rewriter.notifyMatchFailure(
687  sliceOp, "unsupported: cannot be extracted as a contiguous slice "
688  "of the src of the collapse_shape");
689  }
690 
691  groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
692  groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
693 
694  currentCollapsedsize /= expandedShapeSize;
695  currentCollapsedOffset /= expandedShapeSize;
696  }
697 
698  // Now handle the first dim where slicing occurs on (k).
699  if (idx < reassocGroupSize) {
700  int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
701  int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
702  // We need to make sure that the slice size in this dim + offset will
703  // not exceed the shape size.
704  if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
705  return rewriter.notifyMatchFailure(
706  sliceOp, "unsupported: slice cannot be extracted as a contiguous "
707  "slice of the src of the collapse_shape");
708  }
709 
710  groupExpandedSizes.push_back(
711  rewriter.getIndexAttr(currentCollapsedsize));
712  groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
713 
714  currentCollapsedOffset /= expandedShapeSize;
715  }
716 
717  // Now handle the leading dimensions where the slice size is equal to 1
718  // (k-1...0).
719  // The size for these dimensions must be 1 because of how we constructed
720  // the slice size of the expanded shape. We spread the original collapsed
721  // size over the expanded shape sizes until we reached dimension k where
722  // the remaining size was smaller than the expanded shape size, and spread
723  // the remaining size on it. So, now we are left with only 1s.
724  for (idx++; idx < reassocGroupSize; ++idx) {
725  int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
726  int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
727  groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
728  groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
729  currentCollapsedOffset /= expandedShapeSize;
730  }
731 
732  expandedSizes.append(groupExpandedSizes.rbegin(),
733  groupExpandedSizes.rend());
734  expandedOffsets.append(groupExpandedOffsets.rbegin(),
735  groupExpandedOffsets.rend());
736  }
737 
738  Value newSliceOp = tensor::ExtractSliceOp::create(
739  rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(),
740  expandedOffsets, expandedSizes, expandedStrides);
741  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
742  sliceOp, sliceOp.getResultType(), newSliceOp,
743  collapseShapeOp.getReassociationIndices());
744 
745  return success();
746  }
747 };
748 
749 } // namespace
750 
753  patterns
754  .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
755  FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
756  FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
757  FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
758  FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
759  patterns.getContext());
760 }
761 
764  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
765 }
766 
769  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
770  BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
771 }
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:702
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1331
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
 Appends patterns that bubble up a tensor.extract_slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:356
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
 Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314