MLIR 22.0.0git
ReshapePatterns.cpp
Go to the documentation of this file.
1//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"
17
18using namespace mlir;
19using namespace mlir::tensor;
20
21namespace {
22/// Fold expand_shape(extract_slice) ops that cancel itself out.
23struct FoldExpandOfRankReducingExtract
24 : public OpRewritePattern<ExpandShapeOp> {
25 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
28 PatternRewriter &rewriter) const override {
29 RankedTensorType resultType = expandShapeOp.getResultType();
30 auto extractSliceOp =
31 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
32 if (!extractSliceOp)
33 return failure();
34 RankedTensorType srcType = extractSliceOp.getSourceType();
35
36 // Only cases where the ExpandShapeOp can be folded away entirely are
37 // supported. Moreover, only simple cases where the resulting ExtractSliceOp
38 // has no rank-reduction anymore are supported at the moment.
39 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
40 srcType, extractSliceOp.getStaticOffsets(),
41 extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
42 if (nonReducingExtractType != resultType)
43 return failure();
44
45 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
46 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
47 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
48 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
49 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
50 mixedStrides);
51 return success();
52 }
53};
54
55/// Fold collapse_shape which only removes static dimensions of size `1`
56/// into extract_slice.
57struct FoldUnPaddingCollapseIntoExtract
58 : public OpRewritePattern<tensor::CollapseShapeOp> {
59 using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
60
61 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
62 PatternRewriter &rewriter) const override {
63 auto extractSliceOp =
64 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
65 // Collapse cannot be folded away with multiple users of the extract slice
66 // and it is not necessarily beneficial to only convert the collapse into
67 // another extract slice.
68 if (!extractSliceOp || !extractSliceOp->hasOneUse())
69 return failure();
70
71 // Only fold away simple collapse where all removed dimensions have static
72 // size `1`.
74 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
75 if (res != SliceVerificationResult::Success)
76 return rewriter.notifyMatchFailure(collapseShapeOp,
77 "expected unpadding collapse");
78
79 Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
80 rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
81 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
82 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
83 rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
84 return success();
85 }
86};
87
88/// Fold insert_slice(collapse_shape) ops that cancel itself out.
89template <typename OpTy>
90struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
91 using OpRewritePattern<OpTy>::OpRewritePattern;
92
93 LogicalResult matchAndRewrite(OpTy insertSliceOp,
94 PatternRewriter &rewriter) const override {
95 auto collapseShapeOp =
96 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
97 if (!collapseShapeOp)
98 return failure();
99 RankedTensorType srcType = collapseShapeOp.getSrcType();
100
101 // Only cases where the CollapseShapeOp can be folded away entirely are
102 // supported. Moreover, only simple cases where the resulting InsertSliceOp
103 // has no rank-reduction anymore are supported at the moment.
104 RankedTensorType nonReducingInsertType =
105 RankedTensorType::get(insertSliceOp.getStaticSizes(),
106 insertSliceOp.getDestType().getElementType());
107 if (nonReducingInsertType != srcType)
108 return failure();
109
110 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
111 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
112 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
113 rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
114 insertSliceOp.getDest(), mixedOffsets,
115 mixedSizes, mixedStrides);
116 return success();
117 }
118};
119
120/// Fold expand_shape which only adds static dimensions of size `1`
121/// into insert_slice.
122template <typename OpTy>
123struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
124 using OpRewritePattern<OpTy>::OpRewritePattern;
125
126 LogicalResult matchAndRewrite(OpTy insertSliceOp,
127 PatternRewriter &rewriter) const override {
128 auto expandShapeOp = insertSliceOp.getSource()
129 .template getDefiningOp<tensor::ExpandShapeOp>();
130 if (!expandShapeOp)
131 return failure();
132
133 // Only fold away simple expansion where all added dimensions have static
134 // size `1`.
136 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
137 if (res != SliceVerificationResult::Success)
138 return rewriter.notifyMatchFailure(insertSliceOp,
139 "expected rank increasing expansion");
140
141 rewriter.modifyOpInPlace(insertSliceOp, [&]() {
142 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
143 });
144 return success();
145 }
146};
147
148/// Pattern to bubble up a tensor.expand_shape op through a producer
149/// tensor.collapse_shape op that has non intersecting reassociations.
150struct BubbleUpExpandThroughParallelCollapse
151 : public OpRewritePattern<tensor::ExpandShapeOp> {
152 using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
153
154 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
155 PatternRewriter &rewriter) const override {
156 auto collapseOp =
157 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
158 if (!collapseOp)
159 return failure();
160 auto expandReInds = expandOp.getReassociationIndices();
161 auto collapseReInds = collapseOp.getReassociationIndices();
162
163 // Special case where the collapsed tensor to expand is a 0-D tensor,
164 // then the reassociation maps will be empty and not produce valid results.
165 if (expandReInds.size() == 0) {
166 return failure();
167 }
168
169 // Reshapes are parallel to each other (by construction the number of
170 // reassociations specified in the collapse and expand are the same), if at
171 // any position
172 // 1. either the reassociation indices are of the same size, or
173 // 2. either the reassociation in the collapse or the expand is of size 1.
174 ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
175 ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
176 for (auto [expandReassociation, collapseReassociation] :
177 llvm::zip_equal(expandReInds, collapseReInds)) {
178 if (collapseReassociation.size() == expandReassociation.size()) {
179 // Even if the reassociations are the same, the collapse/expand should
180 // result in the same dimensions. i.e 4x8x2 into 64 should be expanded
181 // into 4x8x2 again. In presense of dynamic dimensions one can only
182 // verify "equality" when there is only one dynamic dimension present,
183 // and all other static dimensions are equal.
184 ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
185 collapseReassociation.front(), collapseReassociation.size());
186 int64_t numCollapsedDynamic =
187 llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
188 ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
189 expandReassociation.front(), expandReassociation.size());
190 int64_t numExpandedDynamic =
191 llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
192 if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
193 collapsedStaticShapes != expandedStaticShapes) {
194 return failure();
195 }
196 continue;
197 }
198 // If the reassociations are not same, one or the other needs to be of
199 // size one.
200 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
201 return failure();
202 }
203
204 // Compute new reassociation indices and expanded/collaped shapes.
205 SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
206 Location loc = expandOp->getLoc();
207 SmallVector<OpFoldResult> sourceSizes =
208 tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
209 SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
210 SmallVector<OpFoldResult> newExpandSizes;
211
212 int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
213 resultSizeIndex = 0;
214
215 for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
216 auto &collapseReassociation = collapseReInds[idx];
217 auto &expandReassociation = expandReInds[idx];
218
219 // Case 1. The reassociations are same in the collapse producer
220 // and expand consumer. In the swapped expand, each of the final
221 // dimensions are kept as is in the expand and the collapse. So,
222 // for every element in the `ReassocationIndices` vector add a new
223 // `ReassociationIndices` vector for the swapped expand and collapse
224 // (of size 1).
225 if (collapseReassociation.size() == expandReassociation.size()) {
226 for (size_t i = 0; i < collapseReassociation.size(); ++i) {
227 newCollapseReInds.push_back({newCollapseIndex++});
228 newExpandReInds.push_back({newExpandIndex++});
229 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
230 sourceSizeIndex++;
231 }
232 continue;
233 }
234
235 // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
236 // in the expand is of size == 1). In this case, the original dimensions
237 // are preserved on expansion and collapsed subsequently.
238 if (collapseReassociation.size() != 1) {
239 ReassociationIndices newCollapseReassociation;
240 for (size_t i = 0; i < collapseReassociation.size(); ++i) {
241 newCollapseReassociation.push_back(newCollapseIndex++);
242 newExpandReInds.push_back({newExpandIndex++});
243 newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
244 }
245 resultSizeIndex++;
246 newCollapseReInds.push_back(newCollapseReassociation);
247 continue;
248 }
249
250 // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
251 // in the collapse is of size == 1). In this case, the expansion happens
252 // first and the expanded dimensions are preserved on collapse.
253 ReassociationIndices newExpandReassociation;
254 for (size_t i = 0; i < expandReassociation.size(); ++i) {
255 newExpandReassociation.push_back(newExpandIndex++);
256 newCollapseReInds.push_back({newCollapseIndex++});
257 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
258 }
259 newExpandReInds.push_back(newExpandReassociation);
260 sourceSizeIndex++;
261 }
262
263 // Swap reshape order.
264 SmallVector<Value> dynamicSizes;
265 SmallVector<int64_t> staticSizes;
266 dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
267 auto expandResultType = expandOp.getResultType().clone(staticSizes);
268 Value newCollapseSrc = collapseOp.getSrc();
269 // If the number of reassociation indices in the new `expand_shape` op
270 // matches the number of dimensions of the result, then the expand_shape
271 // is a no-op.
272 if (newExpandReInds.size() != newExpandSizes.size()) {
273 newCollapseSrc = tensor::ExpandShapeOp::create(
274 rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
275 newExpandSizes);
276 }
277
278 // If the number of reassociation indices in the new `collapse_shape` op
279 // matches the number of dimensions of the source, then the collapse_shape
280 // is a no-op.
281 Value replacement = newCollapseSrc;
282 if (newCollapseReInds.size() != newExpandSizes.size()) {
283 replacement = tensor::CollapseShapeOp::create(
284 rewriter, loc, newCollapseSrc, newCollapseReInds);
285 }
286 rewriter.replaceOp(expandOp, replacement);
287 return success();
288 }
289};
290
291/// Converts `tensor.extract_slice(tensor.expand_shape)` to
292/// `tensor.expand_shape(tensor.extract_slice)`.
293///
294/// For this transformation to be possible, the slice must be fully contiguous
295/// within each reassociation group of the expand_shape. A slice is defined as
296/// fully contiguous within a reassociation group if after flattening the
297/// reassociation group to a single 1D range, then the slice taken out of the
298/// group could be defined as a single contiguous subrange within that range.
299///
300/// Rank reducing slices are not supported.
301///
302/// Example:
303/// The transformation is possible because each reassociation group has a
304/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
305/// ```
306/// BEFORE:
307/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
308/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
309/// %slice = tensor.extract_slice %reshape ...
310/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
311///
312/// AFTER:
313/// %slice = tensor.extract_slice %in ...
314/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
315/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
316/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
317/// ```
318///
319/// Note - this pattern could be extended to be a swap pattern between
320/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
321/// implemented only as a bubble up pattern for `tensor.extract_slice`.
322struct BubbleUpExtractSliceThroughExpandShape
323 : public OpRewritePattern<tensor::ExtractSliceOp> {
324 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
325
326 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
327 PatternRewriter &rewriter) const override {
328 auto expandShapeOp =
329 sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
330 if (!expandShapeOp) {
331 return rewriter.notifyMatchFailure(
332 sliceOp, "tensor.extract_slice source not produced by expand_shape");
333 }
334 SmallVector<ReassociationIndices> reassociation =
335 expandShapeOp.getReassociationIndices();
336
337 SmallVector<OpFoldResult> offsets, sizes, strides;
338 if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation,
339 offsets, sizes, strides)))
340 return failure();
341
342 // The shape of the result can be obtained from the sizes passed in.
343 SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
344 RankedTensorType resultType = sliceOp.getResultType();
345
346 // Create a new ExtractSliceOp and ExpandShapeOp.
347 Location loc = sliceOp.getLoc();
348 Value newSliceOp = tensor::ExtractSliceOp::create(
349 rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
350 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
351 sliceOp, resultType, newSliceOp,
352 expandShapeOp.getReassociationIndices(), expandedSizes);
353 return success();
354 }
355};
356
357/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
358/// `tensor.collapse_shape(tensor.extract_slice)`.
359///
360/// For this transformation to be possible - after bubbling up, the extraction
361/// of the contiguous slice must be representable as a single slice obtained via
362/// tensor.extract_slice within each reassociation group of the src.
363///
364/// In case the size and offset extracted are static then this is possible if
365/// the following conditions are met within each reassociation group:
366/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
367/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
368/// shape of a desired slice. A slice of shape S can be extracted as a
369/// contiguous span of elements if and only if there exists an index k in {0, 1,
370/// ..., n} such that:
371/// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
372/// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
373/// one dimension),
374/// S_i = A_i for all i > k (that is, all trailing dimensions are preserved
375/// in full).
376/// In other words, the slice shape S must be of the form:
377/// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ]
378///
379/// In case the size and/or offset extracted are dynamic then this is possible
380/// only if there is single dimension in the reassociation group that has a size
381/// not equal to 1.
382/// In other words, the tensor shape must be of the form:
383/// [ 1, 1, ..., 1, A, 1, ...,1 ]
384/// Note - it might be possible to enable this pattern for more cases when the
385/// size/offset are dynamic via performing an analysis of the possible values
386/// that could be given to the size/offset.
387///
388/// Example:
389/// The transformation is possible because each reassociation group can be
390/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
391/// [20->10]).
392/// ```
393/// BEFORE:
394/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
395/// tensor<8x16x1x7x20f32> to tensor<128x7x20xf32>
396/// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1]
397/// tensor<128x7x20xf32> to tensor<32x?x10xf32>
398///
399/// AFTER:
400/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
401// [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32>
402/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
403/// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
404/// ```
405///
406/// Negative example:
407/// The transformation is not possible because we cannot use a single slice to
408/// represent the reassociation group [2x3x10->???]. If we would want the
409/// collapse to be after the extraction, we would need to extract multiple
410/// slices and concat them together.
411/// ```
412/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
413/// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
414/// tensor<60xf32> to tensor<15xf32>
415/// ```
416/// If we would want the collapse to be after the extraction, a possible
417/// alternate transformation could be to extract multiple slices and concat them
418/// together:
419/// ```
420/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
421/// tensor<2x3x10xf32> to tensor <1x1x10xf32>
422/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
423/// tensor<2x3x10xf32> to tensor <1x1x5xf32>
424/// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
425/// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
426/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
427/// to tensor<15xf32>
428/// ```
429/// But this is not the intended purpose of the transformation.
430struct BubbleUpExtractSliceThroughCollapseShape
431 : public OpRewritePattern<tensor::ExtractSliceOp> {
432 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
433
434 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
435 PatternRewriter &rewriter) const override {
436 auto collapseShapeOp =
437 sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
438 if (!collapseShapeOp) {
439 return rewriter.notifyMatchFailure(
440 sliceOp,
441 "tensor.extract_slice source not produced by tensor.collapse_shape");
442 }
443
444 SmallVector<OpFoldResult> offsets, sizes, strides;
446 rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
447 collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
448 return failure();
449
450 Value newSliceOp = tensor::ExtractSliceOp::create(
451 rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
452 sizes, strides);
453 rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
454 sliceOp, sliceOp.getResultType(), newSliceOp,
455 collapseShapeOp.getReassociationIndices());
456
457 return success();
458 }
459};
460
461} // namespace
462
464 OpBuilder &b, tensor::ExtractSliceOp sliceOp,
465 ArrayRef<ReassociationIndices> reassociation,
466 SmallVectorImpl<OpFoldResult> &collapsedOffsets,
467 SmallVectorImpl<OpFoldResult> &collapsedSizes,
468 SmallVectorImpl<OpFoldResult> &collapsedStrides) {
469 if (!sliceOp.hasUnitStride()) {
470 return failure();
471 }
472
473 SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
474 SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
475
476 if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
477 return failure();
478 }
479
480 auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
481 OpFoldResult sliceSize, int64_t inputDim) {
482 if (!isZeroInteger(offset))
483 return false;
484 ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
485 FailureOr<bool> maybeEqual =
486 ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
487 return llvm::succeeded(maybeEqual) && maybeEqual.value();
488 };
489
490 // Check that the slice is contiguous within each reassociation group.
491 // The slice is contiguous only if after the first dimension where a non
492 // unit slice is taken, the slice size on all subsequent dimensions of the
493 // group is equal to the entire size of the dimension.
494 // Examples of contiguous slices:
495 // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
496 // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
497 // Examples of non contiguous slices:
498 // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
499 // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
500 for (const ReassociationIndices &indices : reassociation) {
501 int64_t i = 0;
502 int64_t e = indices.size();
503 // Find the first expanded dim after the first dim with non-unit extracted
504 // size.
505 for (; i < e; ++i) {
506 if (!isOneInteger(sizes[indices[i]])) {
507 // +1 to skip the first non-unit size dim.
508 i++;
509 break;
510 }
511 }
512
513 // Verify that all subsequent dimensions extract the full size of the
514 // source tensor.
515 for (; i < e; ++i) {
516 int64_t expandedDim = indices[i];
517 if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
518 expandedDim)) {
519 return failure();
520 }
521 }
522 }
523
524 // The tensor.extract_slice before applying the pattern works on the result
525 // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
526 // referring to the state before applying the pattern are named with the
527 // prefix "expanded", and ones referring to the state after applying the
528 // pattern are named with the prefix "collapsed".
529 Location loc = sliceOp.getLoc();
530 SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
531 SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
532 SmallVector<OpFoldResult> expandedShape =
533 getMixedSizes(b, loc, sliceOp.getSource());
534
535 // Helper variables and function for accumulating the size values.
536 AffineExpr d0, d1, d2;
537 bindDims(b.getContext(), d0, d1, d2);
538 // Multiply two integers.
539 auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
540 auto mulMap = AffineMap::get(2, 0, {d0 * d1});
541 return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
542 };
543
544 // Compute new offsets, sizes, and strides for tensor.extract_slice.
545 // The new tensor.extract_slice will work on a tensor that has has a rank of
546 // ReassociationIndices.size(). In the loop a single offset, size, and
547 // stride value is computed per reassociation group.
548 for (const ReassociationIndices &indices : reassociation) {
549 // collapsedSize will hold the size of the single dim that represents the
550 // reassociation group in the non expanded tensor.
551 OpFoldResult collapsedSize = b.getIndexAttr(1);
552 // The reassocGroupSizes and reassocGroupOffsets are used to create an
553 // affine.linearize_index op to linearize the single offset value required
554 // for this reassociation group.
555 SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
556
557 for (long expandedDim : indices) {
558 // reassocGroupSizes and reassocGroupOffsets can be obtained directly
559 // from the expanded state, but the collapsed size requires calculation
560 // as it did not previously exist.
561 reassocGroupSizes.push_back(expandedShape[expandedDim]);
562 reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
563 collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
564 }
565
566 SmallVector<Value> offsetVals =
567 llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
568 return getValueOrCreateConstantIndexOp(b, loc, ofr);
569 });
570 OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
571 b, loc, offsetVals, reassocGroupSizes,
572 /*disjoint=*/true)
573 .getResult();
574 collapsedOffsets.push_back(collapsedOffset);
575 collapsedSizes.push_back(collapsedSize);
576
577 // Only unit stride is supported.
578 collapsedStrides.push_back(b.getIndexAttr(1));
579 }
580 return success();
581}
582
584 OpBuilder &b, tensor::ExtractSliceOp sliceOp,
585 ArrayRef<ReassociationIndices> reassociation,
586 ArrayRef<int64_t> expandedShape,
587 SmallVectorImpl<OpFoldResult> &expandedOffsets,
588 SmallVectorImpl<OpFoldResult> &expandedSizes,
589 SmallVectorImpl<OpFoldResult> &expandedStrides) {
590 if (!sliceOp.hasUnitStride()) {
591 return failure();
592 }
593
594 // The tensor.extract_slice before applying the pattern works on the result
595 // of the tensor.collapse_shape, so variables (i.e. inputs for
596 // ExtractSliceOp) referring to the state before applying the pattern are
597 // named with the prefix "collapsed", and ones referring to the state after
598 // applying the pattern are named with the prefix "expanded".
599 SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
600 SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
601 if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
602 collapsedSizes.size()) {
603 return failure();
604 }
605
606 // Compute new offsets, sizes, and strides for tensor.extract_slice.
607 // The new tensor.extract_slice will work on a tensor that has has a rank
608 // equal to the rank of the src of the collapse_shape. In each iteration of
609 // the loop, the offsets and sizes will be computed per reassociation group.
610 expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
611 for (auto [collapsedSize, collapsedOffset, reassocIndices] :
612 llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
613 // CASE #1 - size and/or offset are dynamic.
614 // In this case, the slice can be represented as a contiguous slice only
615 // if there is a single dimension in the reassociation group that has a
616 // size not equal to 1.
617 if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
618 int nonUnitSizeCount = 0;
619 for (int64_t expandedShapeIdx : reassocIndices) {
620 if (expandedShape[expandedShapeIdx] != 1) {
621 nonUnitSizeCount++;
622 expandedSizes.push_back(collapsedSize);
623 expandedOffsets.push_back(collapsedOffset);
624 continue;
625 }
626
627 expandedSizes.push_back(b.getIndexAttr(1));
628 expandedOffsets.push_back(b.getIndexAttr(0));
629 }
630
631 if (nonUnitSizeCount != 1) {
632 return failure();
633 }
634 continue;
635 }
636
637 // CASE #2 = size and offset are static.
638 // Verify that the slice can be represented as a contiguous slice of the
639 // src of the collapse_shape.
640 // Checking this is done on order of most internal dimensions first,
641 // so traversal is done in reverse order of the reassociation group.
642 // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
643 // ...,An] then we first find the size and offset for n...k+1 then for k
644 // and then for k-1...0.
645
646 // currentCollapsedsize and currentCollapsedOffset are initialized with
647 // the original collapsed size and offset and divided by the expanded
648 // shape size in each dimension as we go along the reassociation group.
649 // In essence we are spreading the original collapsed size and offset over
650 // the various expanded slice dimensions.
651 // The variables are used both to check the validity of the slice and to
652 // compute the expanded sizes and offsets.
653 int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
654 int64_t currentCollapsedOffset =
655 getConstantIntValue(collapsedOffset).value();
656 SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
657 ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
658 reassocIndices.rend());
659 int64_t idx = 0;
660 int64_t reassocGroupSize = reassocIndices.size();
661
662 // First handle the trailing dimensions where the slice size should be
663 // equal to the tensor shape and the offset should be 0 (n...k+1).
664 for (; idx < reassocGroupSize; ++idx) {
665 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
666
667 if (currentCollapsedsize < expandedShapeSize)
668 break;
669
670 // We need to make sure that the slice size can be set to the shape size
671 // and the offset to 0.
672 if ((currentCollapsedsize % expandedShapeSize) != 0 ||
673 (currentCollapsedOffset % expandedShapeSize) != 0) {
674 return failure();
675 }
676
677 groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
678 groupExpandedOffsets.push_back(b.getIndexAttr(0));
679
680 currentCollapsedsize /= expandedShapeSize;
681 currentCollapsedOffset /= expandedShapeSize;
682 }
683
684 // Now handle the first dim where slicing occurs on (k).
685 if (idx < reassocGroupSize) {
686 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
687 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
688 // We need to make sure that the slice size in this dim + offset will
689 // not exceed the shape size.
690 if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
691 return failure();
692 }
693 groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize));
694 groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
695 currentCollapsedOffset /= expandedShapeSize;
696 }
697
698 // Now handle the leading dimensions where the slice size is equal to 1
699 // (k-1...0).
700 // The size for these dimensions must be 1 because of how we constructed
701 // the slice size of the expanded shape. We spread the original collapsed
702 // size over the expanded shape sizes until we reached dimension k where
703 // the remaining size was smaller than the expanded shape size, and spread
704 // the remaining size on it. So, now we are left with only 1s.
705 for (idx++; idx < reassocGroupSize; ++idx) {
706 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
707 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
708 groupExpandedSizes.push_back(b.getIndexAttr(1));
709 groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
710 currentCollapsedOffset /= expandedShapeSize;
711 }
712 expandedSizes.append(groupExpandedSizes.rbegin(),
713 groupExpandedSizes.rend());
714 expandedOffsets.append(groupExpandedOffsets.rbegin(),
715 groupExpandedOffsets.rend());
716 }
717 return success();
718}
719
723 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
724 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
725 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
726 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
727 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
728 patterns.getContext());
729}
730
733 patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
734}
735
738 patterns.add<BubbleUpExtractSliceThroughExpandShape,
739 BubbleUpExtractSliceThroughCollapseShape>(patterns.getContext());
740}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
#define mul(a, b)
Base type for affine expression.
Definition AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
This class represents a single result from folding an operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A variable that can be added to the constraint set as a "column".
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
LogicalResult getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, SmallVectorImpl< OpFoldResult > &collapsedOffsets, SmallVectorImpl< OpFoldResult > &collapsedSizes, SmallVectorImpl< OpFoldResult > &collapsedStrides)
Computes the offsets, sizes, and strides needed to build a collapsed sliceOp.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
LogicalResult getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, ArrayRef< int64_t > expandedShape, SmallVectorImpl< OpFoldResult > &expandedOffsets, SmallVectorImpl< OpFoldResult > &expandedSizes, SmallVectorImpl< OpFoldResult > &expandedStrides)
Computes the offsets, sizes, and strides needed to build an expanded sliceOp.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...