MLIR 22.0.0git
ReshapePatterns.cpp
Go to the documentation of this file.
//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::tensor;
20
21namespace {
22/// Fold expand_shape(extract_slice) ops that cancel itself out.
23struct FoldExpandOfRankReducingExtract
24 : public OpRewritePattern<ExpandShapeOp> {
25 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
28 PatternRewriter &rewriter) const override {
29 RankedTensorType resultType = expandShapeOp.getResultType();
30 auto extractSliceOp =
31 expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
32 if (!extractSliceOp)
33 return failure();
34 RankedTensorType srcType = extractSliceOp.getSourceType();
35
36 // Only cases where the ExpandShapeOp can be folded away entirely are
37 // supported. Moreover, only simple cases where the resulting ExtractSliceOp
38 // has no rank-reduction anymore are supported at the moment.
39 RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
40 srcType, extractSliceOp.getStaticSizes());
41 if (nonReducingExtractType != resultType)
42 return failure();
43
44 SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
45 SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
46 SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
47 rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
48 expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
49 mixedStrides);
50 return success();
51 }
52};
53
54/// Fold collapse_shape which only removes static dimensions of size `1`
55/// into extract_slice.
56struct FoldUnPaddingCollapseIntoExtract
57 : public OpRewritePattern<tensor::CollapseShapeOp> {
58 using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
59
60 LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
61 PatternRewriter &rewriter) const override {
62 auto extractSliceOp =
63 collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
64 // Collapse cannot be folded away with multiple users of the extract slice
65 // and it is not necessarily beneficial to only convert the collapse into
66 // another extract slice.
67 if (!extractSliceOp || !extractSliceOp->hasOneUse())
68 return failure();
69
70 // Only fold away simple collapse where all removed dimensions have static
71 // size `1`.
73 collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
74 if (res != SliceVerificationResult::Success)
75 return rewriter.notifyMatchFailure(collapseShapeOp,
76 "expected unpadding collapse");
77
78 Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
79 rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
80 extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
81 extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
82 rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
83 return success();
84 }
85};
86
87/// Fold insert_slice(collapse_shape) ops that cancel itself out.
88template <typename OpTy>
89struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
90 using OpRewritePattern<OpTy>::OpRewritePattern;
91
92 LogicalResult matchAndRewrite(OpTy insertSliceOp,
93 PatternRewriter &rewriter) const override {
94 auto collapseShapeOp =
95 insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
96 if (!collapseShapeOp)
97 return failure();
98 RankedTensorType srcType = collapseShapeOp.getSrcType();
99
100 // Only cases where the CollapseShapeOp can be folded away entirely are
101 // supported. Moreover, only simple cases where the resulting InsertSliceOp
102 // has no rank-reduction anymore are supported at the moment.
103 RankedTensorType nonReducingInsertType =
104 RankedTensorType::get(insertSliceOp.getStaticSizes(),
105 insertSliceOp.getDestType().getElementType());
106 if (nonReducingInsertType != srcType)
107 return failure();
108
109 SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
110 SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
111 SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
112 rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
113 insertSliceOp.getDest(), mixedOffsets,
114 mixedSizes, mixedStrides);
115 return success();
116 }
117};
118
119/// Fold expand_shape which only adds static dimensions of size `1`
120/// into insert_slice.
121template <typename OpTy>
122struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
123 using OpRewritePattern<OpTy>::OpRewritePattern;
124
125 LogicalResult matchAndRewrite(OpTy insertSliceOp,
126 PatternRewriter &rewriter) const override {
127 auto expandShapeOp = insertSliceOp.getSource()
128 .template getDefiningOp<tensor::ExpandShapeOp>();
129 if (!expandShapeOp)
130 return failure();
131
132 // Only fold away simple expansion where all added dimensions have static
133 // size `1`.
135 expandShapeOp.getResultType(), expandShapeOp.getSrcType());
136 if (res != SliceVerificationResult::Success)
137 return rewriter.notifyMatchFailure(insertSliceOp,
138 "expected rank increasing expansion");
139
140 rewriter.modifyOpInPlace(insertSliceOp, [&]() {
141 insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
142 });
143 return success();
144 }
145};
146
147/// Pattern to bubble up a tensor.expand_shape op through a producer
148/// tensor.collapse_shape op that has non intersecting reassociations.
149struct BubbleUpExpandThroughParallelCollapse
150 : public OpRewritePattern<tensor::ExpandShapeOp> {
151 using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
152
153 LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
154 PatternRewriter &rewriter) const override {
155 auto collapseOp =
156 expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
157 if (!collapseOp)
158 return failure();
159 auto expandReInds = expandOp.getReassociationIndices();
160 auto collapseReInds = collapseOp.getReassociationIndices();
161
162 // Special case where the collapsed tensor to expand is a 0-D tensor,
163 // then the reassociation maps will be empty and not produce valid results.
164 if (expandReInds.size() == 0) {
165 return failure();
166 }
167
168 // Reshapes are parallel to each other (by construction the number of
169 // reassociations specified in the collapse and expand are the same), if at
170 // any position
171 // 1. either the reassociation indices are of the same size, or
172 // 2. either the reassociation in the collapse or the expand is of size 1.
173 ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
174 ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
175 for (auto [expandReassociation, collapseReassociation] :
176 llvm::zip_equal(expandReInds, collapseReInds)) {
177 if (collapseReassociation.size() == expandReassociation.size()) {
178 // Even if the reassociations are the same, the collapse/expand should
179 // result in the same dimensions. i.e 4x8x2 into 64 should be expanded
180 // into 4x8x2 again. In presense of dynamic dimensions one can only
181 // verify "equality" when there is only one dynamic dimension present,
182 // and all other static dimensions are equal.
183 ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
184 collapseReassociation.front(), collapseReassociation.size());
185 int64_t numCollapsedDynamic =
186 llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
187 ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
188 expandReassociation.front(), expandReassociation.size());
189 int64_t numExpandedDynamic =
190 llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
191 if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
192 collapsedStaticShapes != expandedStaticShapes) {
193 return failure();
194 }
195 continue;
196 }
197 // If the reassociations are not same, one or the other needs to be of
198 // size one.
199 if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
200 return failure();
201 }
202
203 // Compute new reassociation indices and expanded/collaped shapes.
204 SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
205 Location loc = expandOp->getLoc();
206 SmallVector<OpFoldResult> sourceSizes =
207 tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
208 SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
209 SmallVector<OpFoldResult> newExpandSizes;
210
211 int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
212 resultSizeIndex = 0;
213
214 for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
215 auto &collapseReassociation = collapseReInds[idx];
216 auto &expandReassociation = expandReInds[idx];
217
218 // Case 1. The reassociations are same in the collapse producer
219 // and expand consumer. In the swapped expand, each of the final
220 // dimensions are kept as is in the expand and the collapse. So,
221 // for every element in the `ReassocationIndices` vector add a new
222 // `ReassociationIndices` vector for the swapped expand and collapse
223 // (of size 1).
224 if (collapseReassociation.size() == expandReassociation.size()) {
225 for (size_t i = 0; i < collapseReassociation.size(); ++i) {
226 newCollapseReInds.push_back({newCollapseIndex++});
227 newExpandReInds.push_back({newExpandIndex++});
228 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
229 sourceSizeIndex++;
230 }
231 continue;
232 }
233
234 // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
235 // in the expand is of size == 1). In this case, the original dimensions
236 // are preserved on expansion and collapsed subsequently.
237 if (collapseReassociation.size() != 1) {
238 ReassociationIndices newCollapseReassociation;
239 for (size_t i = 0; i < collapseReassociation.size(); ++i) {
240 newCollapseReassociation.push_back(newCollapseIndex++);
241 newExpandReInds.push_back({newExpandIndex++});
242 newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
243 }
244 resultSizeIndex++;
245 newCollapseReInds.push_back(newCollapseReassociation);
246 continue;
247 }
248
249 // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
250 // in the collapse is of size == 1). In this case, the expansion happens
251 // first and the expanded dimensions are preserved on collapse.
252 ReassociationIndices newExpandReassociation;
253 for (size_t i = 0; i < expandReassociation.size(); ++i) {
254 newExpandReassociation.push_back(newExpandIndex++);
255 newCollapseReInds.push_back({newCollapseIndex++});
256 newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
257 }
258 newExpandReInds.push_back(newExpandReassociation);
259 sourceSizeIndex++;
260 }
261
262 // Swap reshape order.
263 SmallVector<Value> dynamicSizes;
264 SmallVector<int64_t> staticSizes;
265 dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
266 auto expandResultType = expandOp.getResultType().clone(staticSizes);
267 Value newCollapseSrc = collapseOp.getSrc();
268 // If the number of reassociation indices in the new `expand_shape` op
269 // matches the number of dimensions of the result, then the expand_shape
270 // is a no-op.
271 if (newExpandReInds.size() != newExpandSizes.size()) {
272 newCollapseSrc = tensor::ExpandShapeOp::create(
273 rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
274 newExpandSizes);
275 }
276
277 // If the number of reassociation indices in the new `collapse_shape` op
278 // matches the number of dimensions of the source, then the collapse_shape
279 // is a no-op.
280 Value replacement = newCollapseSrc;
281 if (newCollapseReInds.size() != newExpandSizes.size()) {
282 replacement = tensor::CollapseShapeOp::create(
283 rewriter, loc, newCollapseSrc, newCollapseReInds);
284 }
285 rewriter.replaceOp(expandOp, replacement);
286 return success();
287 }
288};
289
290/// Converts `tensor.extract_slice(tensor.expand_shape)` to
291/// `tensor.expand_shape(tensor.extract_slice)`.
292///
293/// For this transformation to be possible, the slice must be fully contiguous
294/// within each reassociation group of the expand_shape. A slice is defined as
295/// fully contiguous within a reassociation group if after flattening the
296/// reassociation group to a single 1D range, then the slice taken out of the
297/// group could be defined as a single contiguous subrange within that range.
298///
299/// Rank reducing slices are not supported.
300///
301/// Example:
302/// The transformation is possible because each reassociation group has a
303/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
304/// ```
305/// BEFORE:
306/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
307/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
308/// %slice = tensor.extract_slice %reshape ...
309/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
310///
311/// AFTER:
312/// %slice = tensor.extract_slice %in ...
313/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
314/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
315/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
316/// ```
317///
318/// Note - this pattern could be extended to be a swap pattern between
319/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
320/// implemented only as a bubble up pattern for `tensor.extract_slice`.
321struct BubbleUpExtractSliceThroughExpandShape
322 : public OpRewritePattern<tensor::ExtractSliceOp> {
323 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
324
325 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
326 PatternRewriter &rewriter) const override {
327 auto expandShapeOp =
328 sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
329 if (!expandShapeOp) {
330 return rewriter.notifyMatchFailure(
331 sliceOp, "tensor.extract_slice source not produced by expand_shape");
332 }
333 SmallVector<ReassociationIndices> reassociation =
334 expandShapeOp.getReassociationIndices();
335
336 SmallVector<OpFoldResult> offsets, sizes, strides;
337 if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation,
338 offsets, sizes, strides)))
339 return failure();
340
341 // The shape of the result can be obtained from the sizes passed in.
342 SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
343 RankedTensorType resultType = sliceOp.getResultType();
344
345 // Create a new ExtractSliceOp and ExpandShapeOp.
346 Location loc = sliceOp.getLoc();
347 Value newSliceOp = tensor::ExtractSliceOp::create(
348 rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
349 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
350 sliceOp, resultType, newSliceOp,
351 expandShapeOp.getReassociationIndices(), expandedSizes);
352 return success();
353 }
354};
355
356/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
357/// `tensor.collapse_shape(tensor.extract_slice)`.
358///
359/// For this transformation to be possible - after bubbling up, the extraction
360/// of the contiguous slice must be representable as a single slice obtained via
361/// tensor.extract_slice within each reassociation group of the src.
362///
363/// In case the size and offset extracted are static then this is possible if
364/// the following conditions are met within each reassociation group:
365/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
366/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
367/// shape of a desired slice. A slice of shape S can be extracted as a
368/// contiguous span of elements if and only if there exists an index k in {0, 1,
369/// ..., n} such that:
370/// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
371/// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
372/// one dimension),
373/// S_i = A_i for all i > k (that is, all trailing dimensions are preserved
374/// in full).
375/// In other words, the slice shape S must be of the form:
376/// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ]
377///
378/// In case the size and/or offset extracted are dynamic then this is possible
379/// only if there is single dimension in the reassociation group that has a size
380/// not equal to 1.
381/// In other words, the tensor shape must be of the form:
382/// [ 1, 1, ..., 1, A, 1, ...,1 ]
383/// Note - it might be possible to enable this pattern for more cases when the
384/// size/offset are dynamic via performing an analysis of the possible values
385/// that could be given to the size/offset.
386///
387/// Example:
388/// The transformation is possible because each reassociation group can be
389/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
390/// [20->10]).
391/// ```
392/// BEFORE:
393/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
394/// tensor<8x16x1x7x20f32> to tensor<128x7x20xf32>
395/// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1]
396/// tensor<128x7x20xf32> to tensor<32x?x10xf32>
397///
398/// AFTER:
399/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
400// [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32>
401/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
402/// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
403/// ```
404///
405/// Negative example:
406/// The transformation is not possible because we cannot use a single slice to
407/// represent the reassociation group [2x3x10->???]. If we would want the
408/// collapse to be after the extraction, we would need to extract multiple
409/// slices and concat them together.
410/// ```
411/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
412/// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
413/// tensor<60xf32> to tensor<15xf32>
414/// ```
415/// If we would want the collapse to be after the extraction, a possible
416/// alternate transformation could be to extract multiple slices and concat them
417/// together:
418/// ```
419/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
420/// tensor<2x3x10xf32> to tensor <1x1x10xf32>
421/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
422/// tensor<2x3x10xf32> to tensor <1x1x5xf32>
423/// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
424/// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
425/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
426/// to tensor<15xf32>
427/// ```
428/// But this is not the intended purpose of the transformation.
429struct BubbleUpExtractSliceThroughCollapseShape
430 : public OpRewritePattern<tensor::ExtractSliceOp> {
431 using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
432
433 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
434 PatternRewriter &rewriter) const override {
435 auto collapseShapeOp =
436 sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
437 if (!collapseShapeOp) {
438 return rewriter.notifyMatchFailure(
439 sliceOp,
440 "tensor.extract_slice source not produced by tensor.collapse_shape");
441 }
442
443 SmallVector<OpFoldResult> offsets, sizes, strides;
445 rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
446 collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
447 return failure();
448
449 Value newSliceOp = tensor::ExtractSliceOp::create(
450 rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
451 sizes, strides);
452 rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
453 sliceOp, sliceOp.getResultType(), newSliceOp,
454 collapseShapeOp.getReassociationIndices());
455
456 return success();
457 }
458};
459
460} // namespace
461
463 OpBuilder &b, tensor::ExtractSliceOp sliceOp,
464 ArrayRef<ReassociationIndices> reassociation,
465 SmallVectorImpl<OpFoldResult> &collapsedOffsets,
466 SmallVectorImpl<OpFoldResult> &collapsedSizes,
467 SmallVectorImpl<OpFoldResult> &collapsedStrides) {
468 if (!sliceOp.hasUnitStride()) {
469 return failure();
470 }
471
472 SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
473 SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
474
475 if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
476 return failure();
477 }
478
479 auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
480 OpFoldResult sliceSize, int64_t inputDim) {
481 if (!isZeroInteger(offset))
482 return false;
483 ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
484 FailureOr<bool> maybeEqual =
485 ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
486 return llvm::succeeded(maybeEqual) && maybeEqual.value();
487 };
488
489 // Check that the slice is contiguous within each reassociation group.
490 // The slice is contiguous only if after the first dimension where a non
491 // unit slice is taken, the slice size on all subsequent dimensions of the
492 // group is equal to the entire size of the dimension.
493 // Examples of contiguous slices:
494 // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
495 // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
496 // Examples of non contiguous slices:
497 // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
498 // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
499 for (const ReassociationIndices &indices : reassociation) {
500 int64_t i = 0;
501 int64_t e = indices.size();
502 // Find the first expanded dim after the first dim with non-unit extracted
503 // size.
504 for (; i < e; ++i) {
505 if (!isOneInteger(sizes[indices[i]])) {
506 // +1 to skip the first non-unit size dim.
507 i++;
508 break;
509 }
510 }
511
512 // Verify that all subsequent dimensions extract the full size of the
513 // source tensor.
514 for (; i < e; ++i) {
515 int64_t expandedDim = indices[i];
516 if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
517 expandedDim)) {
518 return failure();
519 }
520 }
521 }
522
523 // The tensor.extract_slice before applying the pattern works on the result
524 // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
525 // referring to the state before applying the pattern are named with the
526 // prefix "expanded", and ones referring to the state after applying the
527 // pattern are named with the prefix "collapsed".
528 Location loc = sliceOp.getLoc();
529 SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
530 SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
531 SmallVector<OpFoldResult> expandedShape =
532 getMixedSizes(b, loc, sliceOp.getSource());
533
534 // Helper variables and function for accumulating the size values.
535 AffineExpr d0, d1;
536 bindDims(b.getContext(), d0, d1);
537 // Multiply two integers.
538 auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
539 auto mulMap = AffineMap::get(2, 0, {d0 * d1});
540 return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
541 };
542
543 // Compute new offsets, sizes, and strides for tensor.extract_slice.
544 // The new tensor.extract_slice will work on a tensor that has has a rank of
545 // ReassociationIndices.size(). In the loop a single offset, size, and
546 // stride value is computed per reassociation group.
547 for (const ReassociationIndices &indices : reassociation) {
548 // collapsedSize will hold the size of the single dim that represents the
549 // reassociation group in the non expanded tensor.
550 OpFoldResult collapsedSize = b.getIndexAttr(1);
551 // The reassocGroupSizes and reassocGroupOffsets are used to create an
552 // affine.linearize_index op to linearize the single offset value required
553 // for this reassociation group.
554 SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
555
556 for (long expandedDim : indices) {
557 // reassocGroupSizes and reassocGroupOffsets can be obtained directly
558 // from the expanded state, but the collapsed size requires calculation
559 // as it did not previously exist.
560 reassocGroupSizes.push_back(expandedShape[expandedDim]);
561 reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
562 collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
563 }
564
565 SmallVector<Value> offsetVals =
566 llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
567 return getValueOrCreateConstantIndexOp(b, loc, ofr);
568 });
569 OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
570 b, loc, offsetVals, reassocGroupSizes,
571 /*disjoint=*/true)
572 .getResult();
573 collapsedOffsets.push_back(collapsedOffset);
574 collapsedSizes.push_back(collapsedSize);
575
576 // Only unit stride is supported.
577 collapsedStrides.push_back(b.getIndexAttr(1));
578 }
579 return success();
580}
581
583 OpBuilder &b, tensor::ExtractSliceOp sliceOp,
584 ArrayRef<ReassociationIndices> reassociation,
585 ArrayRef<int64_t> expandedShape,
586 SmallVectorImpl<OpFoldResult> &expandedOffsets,
587 SmallVectorImpl<OpFoldResult> &expandedSizes,
588 SmallVectorImpl<OpFoldResult> &expandedStrides) {
589 if (!sliceOp.hasUnitStride()) {
590 return failure();
591 }
592
593 // The tensor.extract_slice before applying the pattern works on the result
594 // of the tensor.collapse_shape, so variables (i.e. inputs for
595 // ExtractSliceOp) referring to the state before applying the pattern are
596 // named with the prefix "collapsed", and ones referring to the state after
597 // applying the pattern are named with the prefix "expanded".
598 SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
599 SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
600 if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
601 collapsedSizes.size()) {
602 return failure();
603 }
604
605 // Compute new offsets, sizes, and strides for tensor.extract_slice.
606 // The new tensor.extract_slice will work on a tensor that has has a rank
607 // equal to the rank of the src of the collapse_shape. In each iteration of
608 // the loop, the offsets and sizes will be computed per reassociation group.
609 expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
610 for (auto [collapsedSize, collapsedOffset, reassocIndices] :
611 llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
612 // CASE #1 - size and/or offset are dynamic.
613 // In this case, the slice can be represented as a contiguous slice only
614 // if there is a single dimension in the reassociation group that has a
615 // size not equal to 1.
616 if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
617 int nonUnitSizeCount = 0;
618 for (int64_t expandedShapeIdx : reassocIndices) {
619 if (expandedShape[expandedShapeIdx] != 1) {
620 nonUnitSizeCount++;
621 expandedSizes.push_back(collapsedSize);
622 expandedOffsets.push_back(collapsedOffset);
623 continue;
624 }
625
626 expandedSizes.push_back(b.getIndexAttr(1));
627 expandedOffsets.push_back(b.getIndexAttr(0));
628 }
629
630 if (nonUnitSizeCount != 1) {
631 return failure();
632 }
633 continue;
634 }
635
636 // CASE #2 = size and offset are static.
637 // Verify that the slice can be represented as a contiguous slice of the
638 // src of the collapse_shape.
639 // Checking this is done on order of most internal dimensions first,
640 // so traversal is done in reverse order of the reassociation group.
641 // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
642 // ...,An] then we first find the size and offset for n...k+1 then for k
643 // and then for k-1...0.
644
645 // currentCollapsedsize and currentCollapsedOffset are initialized with
646 // the original collapsed size and offset and divided by the expanded
647 // shape size in each dimension as we go along the reassociation group.
648 // In essence we are spreading the original collapsed size and offset over
649 // the various expanded slice dimensions.
650 // The variables are used both to check the validity of the slice and to
651 // compute the expanded sizes and offsets.
652 int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
653 int64_t currentCollapsedOffset =
654 getConstantIntValue(collapsedOffset).value();
655 SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
656 ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
657 reassocIndices.rend());
658 int64_t idx = 0;
659 int64_t reassocGroupSize = reassocIndices.size();
660
661 // First handle the trailing dimensions where the slice size should be
662 // equal to the tensor shape and the offset should be 0 (n...k+1).
663 for (; idx < reassocGroupSize; ++idx) {
664 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
665
666 if (currentCollapsedsize < expandedShapeSize)
667 break;
668
669 // We need to make sure that the slice size can be set to the shape size
670 // and the offset to 0.
671 if ((currentCollapsedsize % expandedShapeSize) != 0 ||
672 (currentCollapsedOffset % expandedShapeSize) != 0) {
673 return failure();
674 }
675
676 groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
677 groupExpandedOffsets.push_back(b.getIndexAttr(0));
678
679 currentCollapsedsize /= expandedShapeSize;
680 currentCollapsedOffset /= expandedShapeSize;
681 }
682
683 // Now handle the first dim where slicing occurs on (k).
684 if (idx < reassocGroupSize) {
685 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
686 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
687 // We need to make sure that the slice size in this dim + offset will
688 // not exceed the shape size.
689 if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
690 return failure();
691 }
692 groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize));
693 groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
694 currentCollapsedOffset /= expandedShapeSize;
695 }
696
697 // Now handle the leading dimensions where the slice size is equal to 1
698 // (k-1...0).
699 // The size for these dimensions must be 1 because of how we constructed
700 // the slice size of the expanded shape. We spread the original collapsed
701 // size over the expanded shape sizes until we reached dimension k where
702 // the remaining size was smaller than the expanded shape size, and spread
703 // the remaining size on it. So, now we are left with only 1s.
704 for (idx++; idx < reassocGroupSize; ++idx) {
705 int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
706 int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
707 groupExpandedSizes.push_back(b.getIndexAttr(1));
708 groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
709 currentCollapsedOffset /= expandedShapeSize;
710 }
711 expandedSizes.append(groupExpandedSizes.rbegin(),
712 groupExpandedSizes.rend());
713 expandedOffsets.append(groupExpandedOffsets.rbegin(),
714 groupExpandedOffsets.rend());
715 }
716 return success();
717}
718
722 .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
723 FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
724 FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
725 FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
726 FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
727 patterns.getContext());
728}
729
732 patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
733}
734
737 patterns.add<BubbleUpExtractSliceThroughExpandShape,
738 BubbleUpExtractSliceThroughCollapseShape>(patterns.getContext());
739}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
#define mul(a, b)
Base type for affine expression.
Definition AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
This class represents a single result from folding an operation.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A variable that can be added to the constraint set as a "column".
static FailureOr< bool > areEqual(const Variable &var1, const Variable &var2)
Compute whether the given variables are equal.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
LogicalResult getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, SmallVectorImpl< OpFoldResult > &collapsedOffsets, SmallVectorImpl< OpFoldResult > &collapsedSizes, SmallVectorImpl< OpFoldResult > &collapsedStrides)
Computes the offsets, sizes, and strides needed to build a collapsed sliceOp.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns)
Populates patterns with patterns that bubble up tensor.expand_shape through tensor....
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
LogicalResult getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp, ArrayRef< ReassociationIndices > reassociation, ArrayRef< int64_t > expandedShape, SmallVectorImpl< OpFoldResult > &expandedOffsets, SmallVectorImpl< OpFoldResult > &expandedSizes, SmallVectorImpl< OpFoldResult > &expandedStrides)
Computes the offsets, sizes, and strides needed to build an expanded sliceOp.
Include the generated interface declarations.
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...