// ReshapePatterns.cpp — MLIR 22.0.0git (Doxygen page-header residue; see the
// license header below for the canonical file banner).
1 //===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"
17 
18 using namespace mlir;
19 using namespace mlir::tensor;
20 
21 namespace {
22 /// Fold expand_shape(extract_slice) ops that cancel itself out.
23 struct FoldExpandOfRankReducingExtract
24  : public OpRewritePattern<ExpandShapeOp> {
26 
27  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
28  PatternRewriter &rewriter) const override {
29  RankedTensorType resultType = expandShapeOp.getResultType();
30  auto extractSliceOp =
31  expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
32  if (!extractSliceOp)
33  return failure();
34  RankedTensorType srcType = extractSliceOp.getSourceType();
35 
36  // Only cases where the ExpandShapeOp can be folded away entirely are
37  // supported. Moreover, only simple cases where the resulting ExtractSliceOp
38  // has no rank-reduction anymore are supported at the moment.
39  RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
40  srcType, extractSliceOp.getStaticOffsets(),
41  extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
42  if (nonReducingExtractType != resultType)
43  return failure();
44 
45  SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
46  SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
47  SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
48  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
49  expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
50  mixedStrides);
51  return success();
52  }
53 };
54 
55 /// Fold collapse_shape which only removes static dimensions of size `1`
56 /// into extract_slice.
57 struct FoldUnPaddingCollapseIntoExtract
58  : public OpRewritePattern<tensor::CollapseShapeOp> {
60 
61  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
62  PatternRewriter &rewriter) const override {
63  auto extractSliceOp =
64  collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
65  // Collapse cannot be folded away with multiple users of the extract slice
66  // and it is not necessarily beneficial to only convert the collapse into
67  // another extract slice.
68  if (!extractSliceOp || !extractSliceOp->hasOneUse())
69  return failure();
70 
71  // Only fold away simple collapse where all removed dimensions have static
72  // size `1`.
74  collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
76  return rewriter.notifyMatchFailure(collapseShapeOp,
77  "expected unpadding collapse");
78 
79  Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
80  rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
81  extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
82  extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
83  rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
84  return success();
85  }
86 };
87 
88 /// Fold insert_slice(collapse_shape) ops that cancel itself out.
89 template <typename OpTy>
90 struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
92 
93  LogicalResult matchAndRewrite(OpTy insertSliceOp,
94  PatternRewriter &rewriter) const override {
95  auto collapseShapeOp =
96  insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
97  if (!collapseShapeOp)
98  return failure();
99  RankedTensorType srcType = collapseShapeOp.getSrcType();
100 
101  // Only cases where the CollapseShapeOp can be folded away entirely are
102  // supported. Moreover, only simple cases where the resulting InsertSliceOp
103  // has no rank-reduction anymore are supported at the moment.
104  RankedTensorType nonReducingInsertType =
105  RankedTensorType::get(insertSliceOp.getStaticSizes(),
106  insertSliceOp.getDestType().getElementType());
107  if (nonReducingInsertType != srcType)
108  return failure();
109 
110  SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
111  SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
112  SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
113  rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
114  insertSliceOp.getDest(), mixedOffsets,
115  mixedSizes, mixedStrides);
116  return success();
117  }
118 };
119 
120 /// Fold expand_shape which only adds static dimensions of size `1`
121 /// into insert_slice.
122 template <typename OpTy>
123 struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
125 
126  LogicalResult matchAndRewrite(OpTy insertSliceOp,
127  PatternRewriter &rewriter) const override {
128  auto expandShapeOp = insertSliceOp.getSource()
129  .template getDefiningOp<tensor::ExpandShapeOp>();
130  if (!expandShapeOp)
131  return failure();
132 
133  // Only fold away simple expansion where all added dimensions have static
134  // size `1`.
136  expandShapeOp.getResultType(), expandShapeOp.getSrcType());
138  return rewriter.notifyMatchFailure(insertSliceOp,
139  "expected rank increasing expansion");
140 
141  rewriter.modifyOpInPlace(insertSliceOp, [&]() {
142  insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
143  });
144  return success();
145  }
146 };
147 
148 /// Pattern to bubble up a tensor.expand_shape op through a producer
149 /// tensor.collapse_shape op that has non intersecting reassociations.
150 struct BubbleUpExpandThroughParallelCollapse
151  : public OpRewritePattern<tensor::ExpandShapeOp> {
153 
154  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
155  PatternRewriter &rewriter) const override {
156  auto collapseOp =
157  expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
158  if (!collapseOp)
159  return failure();
160  auto expandReInds = expandOp.getReassociationIndices();
161  auto collapseReInds = collapseOp.getReassociationIndices();
162 
163  // Special case where the collapsed tensor to expand is a 0-D tensor,
164  // then the reassociation maps will be empty and not produce valid results.
165  if (expandReInds.size() == 0) {
166  return failure();
167  }
168 
169  // Reshapes are parallel to each other (by construction the number of
170  // reassociations specified in the collapse and expand are the same), if at
171  // any position
172  // 1. either the reassociation indices are of the same size, or
173  // 2. either the reassociation in the collapse or the expand is of size 1.
174  ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
175  ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
176  for (auto [expandReassociation, collapseReassociation] :
177  llvm::zip_equal(expandReInds, collapseReInds)) {
178  if (collapseReassociation.size() == expandReassociation.size()) {
179  // Even if the reassociations are the same, the collapse/expand should
180  // result in the same dimensions. i.e 4x8x2 into 64 should be expanded
181  // into 4x8x2 again. In presense of dynamic dimensions one can only
182  // verify "equality" when there is only one dynamic dimension present,
183  // and all other static dimensions are equal.
184  ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
185  collapseReassociation.front(), collapseReassociation.size());
186  int64_t numCollapsedDynamic =
187  llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
188  ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
189  expandReassociation.front(), expandReassociation.size());
190  int64_t numExpandedDynamic =
191  llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
192  if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
193  collapsedStaticShapes != expandedStaticShapes) {
194  return failure();
195  }
196  continue;
197  }
198  // If the reassociations are not same, one or the other needs to be of
199  // size one.
200  if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
201  return failure();
202  }
203 
204  // Compute new reassociation indices and expanded/collaped shapes.
205  SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
206  Location loc = expandOp->getLoc();
207  SmallVector<OpFoldResult> sourceSizes =
208  tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
209  SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
210  SmallVector<OpFoldResult> newExpandSizes;
211 
212  int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
213  resultSizeIndex = 0;
214 
215  for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
216  auto &collapseReassociation = collapseReInds[idx];
217  auto &expandReassociation = expandReInds[idx];
218 
219  // Case 1. The reassociations are same in the collapse producer
220  // and expand consumer. In the swapped expand, each of the final
221  // dimensions are kept as is in the expand and the collapse. So,
222  // for every element in the `ReassocationIndices` vector add a new
223  // `ReassociationIndices` vector for the swapped expand and collapse
224  // (of size 1).
225  if (collapseReassociation.size() == expandReassociation.size()) {
226  for (size_t i = 0; i < collapseReassociation.size(); ++i) {
227  newCollapseReInds.push_back({newCollapseIndex++});
228  newExpandReInds.push_back({newExpandIndex++});
229  newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
230  sourceSizeIndex++;
231  }
232  continue;
233  }
234 
235  // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
236  // in the expand is of size == 1). In this case, the original dimensions
237  // are preserved on expansion and collapsed subsequently.
238  if (collapseReassociation.size() != 1) {
239  ReassociationIndices newCollapseReassociation;
240  for (size_t i = 0; i < collapseReassociation.size(); ++i) {
241  newCollapseReassociation.push_back(newCollapseIndex++);
242  newExpandReInds.push_back({newExpandIndex++});
243  newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
244  }
245  resultSizeIndex++;
246  newCollapseReInds.push_back(newCollapseReassociation);
247  continue;
248  }
249 
250  // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
251  // in the collapse is of size == 1). In this case, the expansion happens
252  // first and the expanded dimensions are preserved on collapse.
253  ReassociationIndices newExpandReassociation;
254  for (size_t i = 0; i < expandReassociation.size(); ++i) {
255  newExpandReassociation.push_back(newExpandIndex++);
256  newCollapseReInds.push_back({newCollapseIndex++});
257  newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
258  }
259  newExpandReInds.push_back(newExpandReassociation);
260  sourceSizeIndex++;
261  }
262 
263  // Swap reshape order.
264  SmallVector<Value> dynamicSizes;
265  SmallVector<int64_t> staticSizes;
266  dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
267  auto expandResultType = expandOp.getResultType().clone(staticSizes);
268  Value newCollapseSrc = collapseOp.getSrc();
269  // If the number of reassociation indices in the new `expand_shape` op
270  // matches the number of dimensions of the result, then the expand_shape
271  // is a no-op.
272  if (newExpandReInds.size() != newExpandSizes.size()) {
273  newCollapseSrc = tensor::ExpandShapeOp::create(
274  rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
275  newExpandSizes);
276  }
277 
278  // If the number of reassociation indices in the new `collapse_shape` op
279  // matches the number of dimensions of the source, then the collapse_shape
280  // is a no-op.
281  Value replacement = newCollapseSrc;
282  if (newCollapseReInds.size() != newExpandSizes.size()) {
283  replacement = tensor::CollapseShapeOp::create(
284  rewriter, loc, newCollapseSrc, newCollapseReInds);
285  }
286  rewriter.replaceOp(expandOp, replacement);
287  return success();
288  }
289 };
290 
291 /// Converts `tensor.extract_slice(tensor.expand_shape)` to
292 /// `tensor.expand_shape(tensor.extract_slice)`.
293 ///
294 /// For this transformation to be possible, the slice must be fully contiguous
295 /// within each reassociation group of the expand_shape. A slice is defined as
296 /// fully contiguous within a reassociation group if after flattening the
297 /// reassociation group to a single 1D range, then the slice taken out of the
298 /// group could be defined as a single contiguous subrange within that range.
299 ///
300 /// Rank reducing slices are not supported.
301 ///
302 /// Example:
303 /// The transformation is possible because each reassociation group has a
304 /// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
305 /// ```
306 /// BEFORE:
307 /// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
308 /// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
309 /// %slice = tensor.extract_slice %reshape ...
310 /// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
311 ///
312 /// AFTER:
313 /// %slice = tensor.extract_slice %in ...
314 /// tensor<8x16x32xf32> to tensor<8x5x4xf32>
315 /// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
316 /// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
317 /// ```
318 ///
319 /// Note - this pattern could be extended to be a swap pattern between
320 /// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
321 /// implemented only as a bubble up pattern for `tensor.extract_slice`.
322 struct BubbleUpExpandShapeThroughExtractSlice
323  : public OpRewritePattern<tensor::ExtractSliceOp> {
325 
326  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
327  PatternRewriter &rewriter) const override {
328  auto expandShapeOp =
329  sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
330  if (!expandShapeOp) {
331  return rewriter.notifyMatchFailure(
332  sliceOp, "tensor.extract_slice source not produced by expand_shape");
333  }
334  SmallVector<ReassociationIndices> reassociation =
335  expandShapeOp.getReassociationIndices();
336 
337  SmallVector<OpFoldResult> offsets, sizes, strides;
338  if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation,
339  offsets, sizes, strides)))
340  return failure();
341 
342  // The shape of the result can be obtained from the sizes passed in.
343  SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
344  RankedTensorType resultType = sliceOp.getResultType();
345 
346  // Create a new ExtractSliceOp and ExpandShapeOp.
347  Location loc = sliceOp.getLoc();
348  Value newSliceOp = tensor::ExtractSliceOp::create(
349  rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
350  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
351  sliceOp, resultType, newSliceOp,
352  expandShapeOp.getReassociationIndices(), expandedSizes);
353  return success();
354  }
355 };
356 
357 /// Converts `tensor.extract_slice(tensor.collapse_shape)` to
358 /// `tensor.collapse_shape(tensor.extract_slice)`.
359 ///
360 /// For this transformation to be possible - after bubbling up, the extraction
361 /// of the contiguous slice must be representable as a single slice obtained via
362 /// tensor.extract_slice within each reassociation group of the src.
363 ///
364 /// In case the size and offset extracted are static then this is possible if
365 /// the following conditions are met within each reassociation group:
366 /// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
367 /// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
368 /// shape of a desired slice. A slice of shape S can be extracted as a
369 /// contiguous span of elements if and only if there exists an index k in {0, 1,
370 /// ..., n} such that:
371 /// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
372 /// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
373 /// one dimension),
374 /// S_i = A_i for all i > k (that is, all trailing dimensions are preserved
375 /// in full).
376 /// In other words, the slice shape S must be of the form:
377 /// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ]
378 ///
379 /// In case the size and/or offset extracted are dynamic then this is possible
380 /// only if there is single dimension in the reassociation group that has a size
381 /// not equal to 1.
382 /// In other words, the tensor shape must be of the form:
383 /// [ 1, 1, ..., 1, A, 1, ...,1 ]
384 /// Note - it might be possible to enable this pattern for more cases when the
385 /// size/offset are dynamic via performing an analysis of the possible values
386 /// that could be given to the size/offset.
387 ///
388 /// Example:
389 /// The transformation is possible because each reassociation group can be
390 /// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
391 /// [20->10]).
392 /// ```
393 /// BEFORE:
394 /// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
395 /// tensor<8x16x1x7x20f32> to tensor<128x7x20xf32>
396 /// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1]
397 /// tensor<128x7x20xf32> to tensor<32x?x10xf32>
398 ///
399 /// AFTER:
400 /// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
401 // [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32>
402 /// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
403 /// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
404 /// ```
405 ///
406 /// Negative example:
407 /// The transformation is not possible because we cannot use a single slice to
408 /// represent the reassociation group [2x3x10->???]. If we would want the
409 /// collapse to be after the extraction, we would need to extract multiple
410 /// slices and concat them together.
411 /// ```
412 /// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
413 /// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
414 /// tensor<60xf32> to tensor<15xf32>
415 /// ```
416 /// If we would want the collapse to be after the extraction, a possible
417 /// alternate transformation could be to extract multiple slices and concat them
418 /// together:
419 /// ```
420 /// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
421 /// tensor<2x3x10xf32> to tensor <1x1x10xf32>
422 /// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
423 /// tensor<2x3x10xf32> to tensor <1x1x5xf32>
424 /// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
425 /// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
426 /// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
427 /// to tensor<15xf32>
428 /// ```
429 /// But this is not the intended purpose of the transformation.
430 struct BubbleUpCollapseShapeThroughExtractSlice
431  : public OpRewritePattern<tensor::ExtractSliceOp> {
433 
434  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
435  PatternRewriter &rewriter) const override {
436  auto collapseShapeOp =
437  sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
438  if (!collapseShapeOp) {
439  return rewriter.notifyMatchFailure(
440  sliceOp,
441  "tensor.extract_slice source not produced by tensor.collapse_shape");
442  }
443 
444  SmallVector<OpFoldResult> offsets, sizes, strides;
446  rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
447  collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
448  return failure();
449 
450  Value newSliceOp = tensor::ExtractSliceOp::create(
451  rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
452  sizes, strides);
453  rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
454  sliceOp, sliceOp.getResultType(), newSliceOp,
455  collapseShapeOp.getReassociationIndices());
456 
457  return success();
458  }
459 };
460 
461 } // namespace
462 
464  OpBuilder &b, tensor::ExtractSliceOp sliceOp,
465  ArrayRef<ReassociationIndices> reassociation,
466  SmallVectorImpl<OpFoldResult> &collapsedOffsets,
467  SmallVectorImpl<OpFoldResult> &collapsedSizes,
468  SmallVectorImpl<OpFoldResult> &collapsedStrides) {
469  if (!sliceOp.hasUnitStride()) {
470  return failure();
471  }
472 
473  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
474  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
475 
476  if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
477  return failure();
478  }
479 
480  auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
481  OpFoldResult sliceSize, int64_t inputDim) {
482  if (!isZeroInteger(offset))
483  return false;
484  ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
485  FailureOr<bool> maybeEqual =
486  ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
487  return llvm::succeeded(maybeEqual) && maybeEqual.value();
488  };
489 
490  // Check that the slice is contiguous within each reassociation group.
491  // The slice is contiguous only if after the first dimension where a non
492  // unit slice is taken, the slice size on all subsequent dimensions of the
493  // group is equal to the entire size of the dimension.
494  // Examples of contiguous slices:
495  // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
496  // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
497  // Examples of non contiguous slices:
498  // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
499  // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
500  for (const ReassociationIndices &indices : reassociation) {
501  int64_t i = 0;
502  int64_t e = indices.size();
503  // Find the first expanded dim after the first dim with non-unit extracted
504  // size.
505  for (; i < e; ++i) {
506  if (!isOneInteger(sizes[indices[i]])) {
507  // +1 to skip the first non-unit size dim.
508  i++;
509  break;
510  }
511  }
512 
513  // Verify that all subsequent dimensions extract the full size of the
514  // source tensor.
515  for (; i < e; ++i) {
516  int64_t expandedDim = indices[i];
517  if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
518  expandedDim)) {
519  return failure();
520  }
521  }
522  }
523 
524  // The tensor.extract_slice before applying the pattern works on the result
525  // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
526  // referring to the state before applying the pattern are named with the
527  // prefix "expanded", and ones referring to the state after applying the
528  // pattern are named with the prefix "collapsed".
529  Location loc = sliceOp.getLoc();
530  SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
531  SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
532  SmallVector<OpFoldResult> expandedShape =
533  getMixedSizes(b, loc, sliceOp.getSource());
534 
535  // Helper variables and function for accumulating the size values.
536  AffineExpr d0, d1, d2;
537  bindDims(b.getContext(), d0, d1, d2);
538  // Multiply two integers.
539  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
540  auto mulMap = AffineMap::get(2, 0, {d0 * d1});
541  return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
542  };
543 
544  // Compute new offsets, sizes, and strides for tensor.extract_slice.
545  // The new tensor.extract_slice will work on a tensor that has has a rank of
546  // ReassociationIndices.size(). In the loop a single offset, size, and
547  // stride value is computed per reassociation group.
548  for (const ReassociationIndices &indices : reassociation) {
549  // collapsedSize will hold the size of the single dim that represents the
550  // reassociation group in the non expanded tensor.
551  OpFoldResult collapsedSize = b.getIndexAttr(1);
552  // The reassocGroupSizes and reassocGroupOffsets are used to create an
553  // affine.linearize_index op to linearize the single offset value required
554  // for this reassociation group.
555  SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
556 
557  for (long expandedDim : indices) {
558  // reassocGroupSizes and reassocGroupOffsets can be obtained directly
559  // from the expanded state, but the collapsed size requires calculation
560  // as it did not previously exist.
561  reassocGroupSizes.push_back(expandedShape[expandedDim]);
562  reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
563  collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
564  }
565 
566  SmallVector<Value> offsetVals =
567  llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
568  return getValueOrCreateConstantIndexOp(b, loc, ofr);
569  });
570  OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
571  b, loc, offsetVals, reassocGroupSizes,
572  /*disjoint=*/true)
573  .getResult();
574  collapsedOffsets.push_back(collapsedOffset);
575  collapsedSizes.push_back(collapsedSize);
576 
577  // Only unit stride is supported.
578  collapsedStrides.push_back(b.getIndexAttr(1));
579  }
580  return success();
581 }
582 
584  OpBuilder &b, tensor::ExtractSliceOp sliceOp,
585  ArrayRef<ReassociationIndices> reassociation,
586  ArrayRef<int64_t> expandedShape,
587  SmallVectorImpl<OpFoldResult> &expandedOffsets,
588  SmallVectorImpl<OpFoldResult> &expandedSizes,
589  SmallVectorImpl<OpFoldResult> &expandedStrides) {
590  if (!sliceOp.hasUnitStride()) {
591  return failure();
592  }
593 
594  // The tensor.extract_slice before applying the pattern works on the result
595  // of the tensor.collapse_shape, so variables (i.e. inputs for
596  // ExtractSliceOp) referring to the state before applying the pattern are
597  // named with the prefix "collapsed", and ones referring to the state after
598  // applying the pattern are named with the prefix "expanded".
599  SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
600  SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
601  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
602  collapsedSizes.size()) {
603  return failure();
604  }
605 
606  // Compute new offsets, sizes, and strides for tensor.extract_slice.
607  // The new tensor.extract_slice will work on a tensor that has has a rank
608  // equal to the rank of the src of the collapse_shape. In each iteration of
609  // the loop, the offsets and sizes will be computed per reassociation group.
610  expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
611  for (auto [collapsedSize, collapsedOffset, reassocIndices] :
612  llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
613  // CASE #1 - size and/or offset are dynamic.
614  // In this case, the slice can be represented as a contiguous slice only
615  // if there is a single dimension in the reassociation group that has a
616  // size not equal to 1.
617  if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
618  int nonUnitSizeCount = 0;
619  for (int64_t expandedShapeIdx : reassocIndices) {
620  if (expandedShape[expandedShapeIdx] != 1) {
621  nonUnitSizeCount++;
622  expandedSizes.push_back(collapsedSize);
623  expandedOffsets.push_back(collapsedOffset);
624  continue;
625  }
626 
627  expandedSizes.push_back(b.getIndexAttr(1));
628  expandedOffsets.push_back(b.getIndexAttr(0));
629  }
630 
631  if (nonUnitSizeCount != 1) {
632  return failure();
633  }
634  continue;
635  }
636 
637  // CASE #2 = size and offset are static.
638  // Verify that the slice can be represented as a contiguous slice of the
639  // src of the collapse_shape.
640  // Checking this is done on order of most internal dimensions first,
641  // so traversal is done in reverse order of the reassociation group.
642  // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
643  // ...,An] then we first find the size and offset for n...k+1 then for k
644  // and then for k-1...0.
645 
646  // currentCollapsedsize and currentCollapsedOffset are initialized with
647  // the original collapsed size and offset and divided by the expanded
648  // shape size in each dimension as we go along the reassociation group.
649  // In essence we are spreading the original collapsed size and offset over
650  // the various expanded slice dimensions.
651  // The variables are used both to check the validity of the slice and to
652  // compute the expanded sizes and offsets.
653  int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
654  int64_t currentCollapsedOffset =
655  getConstantIntValue(collapsedOffset).value();
656  SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
657  ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
658  reassocIndices.rend());
659  int64_t idx = 0;
660  int64_t reassocGroupSize = reassocIndices.size();
661 
662  // First handle the trailing dimensions where the slice size should be
663  // equal to the tensor shape and the offset should be 0 (n...k+1).
664  for (; idx < reassocGroupSize; ++idx) {
665  int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
666 
667  if (currentCollapsedsize < expandedShapeSize)
668  break;
669 
670  // We need to make sure that the slice size can be set to the shape size
671  // and the offset to 0.
672  if ((currentCollapsedsize % expandedShapeSize) != 0 ||
673  (currentCollapsedOffset % expandedShapeSize) != 0) {
674  return failure();
675  }
676 
677  groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
678  groupExpandedOffsets.push_back(b.getIndexAttr(0));
679 
680  currentCollapsedsize /= expandedShapeSize;
681  currentCollapsedOffset /= expandedShapeSize;
682  }
683 
684  // Now handle the first dim where slicing occurs on (k).
685  if (idx < reassocGroupSize) {
686  int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
687  int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
688  // We need to make sure that the slice size in this dim + offset will
689  // not exceed the shape size.
690  if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
691  return failure();
692  }
693  groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize));
694  groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
695  currentCollapsedOffset /= expandedShapeSize;
696  }
697 
698  // Now handle the leading dimensions where the slice size is equal to 1
699  // (k-1...0).
700  // The size for these dimensions must be 1 because of how we constructed
701  // the slice size of the expanded shape. We spread the original collapsed
702  // size over the expanded shape sizes until we reached dimension k where
703  // the remaining size was smaller than the expanded shape size, and spread
704  // the remaining size on it. So, now we are left with only 1s.
705  for (idx++; idx < reassocGroupSize; ++idx) {
706  int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
707  int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
708  groupExpandedSizes.push_back(b.getIndexAttr(1));
709  groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
710  currentCollapsedOffset /= expandedShapeSize;
711  }
712  expandedSizes.append(groupExpandedSizes.rbegin(),
713  groupExpandedSizes.rend());
714  expandedOffsets.append(groupExpandedOffsets.rbegin(),
715  groupExpandedOffsets.rend());
716  }
717  return success();
718 }
719 
722  patterns
723  .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
724  FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
725  FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
726  FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
727  FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
728  patterns.getContext());
729 }
730 
733  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
734 }
735 
738  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
739  BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
740 }
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
MLIRContext * getContext() const
Definition: Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and to provide a callback that populates a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (replaces the original op).
Definition: PatternMatch.h:519
A variable that can be added to the constraint set as a "column".
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1329
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
LogicalResult getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, SmallVectorImpl< OpFoldResult > &collapsedOffsets, SmallVectorImpl< OpFoldResult > &collapsedSizes, SmallVectorImpl< OpFoldResult > &collapsedStrides)
Computes the offsets, sizes, and strides needed to build a collapsed sliceOp.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other ops.
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor.collapse_shape ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
LogicalResult getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, ArrayRef< int64_t > expandedShape, SmallVectorImpl< OpFoldResult > &expandedOffsets, SmallVectorImpl< OpFoldResult > &expandedSizes, SmallVectorImpl< OpFoldResult > &expandedStrides)
Computes the offsets, sizes, and strides needed to build an expanded sliceOp.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
Definition: BuiltinTypes.h:356
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314