//===- ReshapePatterns.cpp - Patterns related to rank reductions ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {
/// Fold expand_shape(extract_slice) ops that cancel each other out.
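///
/// A minimal illustrative example (shapes chosen for illustration only):
/// the rank-reducing extract_slice drops a unit dimension that the
/// expand_shape immediately reintroduces, so both fold into a single
/// non-rank-reducing extract_slice.
/// ```
/// BEFORE:
/// %extract = tensor.extract_slice %t[0, 0, 0][1, 16, 32][1, 1, 1]
///     : tensor<2x16x32xf32> to tensor<16x32xf32>
/// %expand = tensor.expand_shape %extract [[0, 1], [2]]
///     output_shape [1, 16, 32]
///     : tensor<16x32xf32> into tensor<1x16x32xf32>
///
/// AFTER:
/// %expand = tensor.extract_slice %t[0, 0, 0][1, 16, 32][1, 1, 1]
///     : tensor<2x16x32xf32> to tensor<1x16x32xf32>
/// ```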
struct FoldExpandOfRankReducingExtract
    : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = expandShapeOp.getResultType();
    auto extractSliceOp =
        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
    if (!extractSliceOp)
      return failure();
    RankedTensorType srcType = extractSliceOp.getSourceType();

    // Only cases where the ExpandShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // ExtractSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
        srcType, extractSliceOp.getStaticSizes());
    if (nonReducingExtractType != resultType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
        mixedStrides);
    return success();
  }
};

/// Fold collapse_shape which only removes static dimensions of size `1`
/// into extract_slice.
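///
/// A minimal illustrative example (shapes chosen for illustration only): the
/// collapse only drops the unit dimension produced by the slice, so the
/// slice can produce the collapsed type directly as a rank-reducing
/// extract_slice.
/// ```
/// BEFORE:
/// %extract = tensor.extract_slice %t[0, 0][1, 32][1, 1]
///     : tensor<8x32xf32> to tensor<1x32xf32>
/// %collapse = tensor.collapse_shape %extract [[0, 1]]
///     : tensor<1x32xf32> into tensor<32xf32>
///
/// AFTER:
/// %collapse = tensor.extract_slice %t[0, 0][1, 32][1, 1]
///     : tensor<8x32xf32> to tensor<32xf32>
/// ```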
struct FoldUnPaddingCollapseIntoExtract
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto extractSliceOp =
        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
    // The collapse cannot be folded away when the extract_slice has multiple
    // users, and it is not necessarily beneficial to only convert the
    // collapse into another extract_slice.
    if (!extractSliceOp || !extractSliceOp->hasOneUse())
      return failure();

    // Only fold away simple collapses where all removed dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(collapseShapeOp,
                                         "expected unpadding collapse");

    Value unPaddedExtractSlice = tensor::ExtractSliceOp::create(
        rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
    return success();
  }
};

/// Fold insert_slice(collapse_shape) ops that cancel each other out.
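///
/// A minimal illustrative example (shapes chosen for illustration only): the
/// collapse only drops a unit dimension that the rank-reducing insert_slice
/// would drop anyway, so the collapse source can be inserted directly.
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1], [2]]
///     : tensor<1x16x32xf32> into tensor<16x32xf32>
/// %insert = tensor.insert_slice %collapse into
///     %dst[0, 0, 0][1, 16, 32][1, 1, 1]
///     : tensor<16x32xf32> into tensor<4x16x32xf32>
///
/// AFTER:
/// %insert = tensor.insert_slice %src into
///     %dst[0, 0, 0][1, 16, 32][1, 1, 1]
///     : tensor<1x16x32xf32> into tensor<4x16x32xf32>
/// ```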
template <typename OpTy>
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();
    RankedTensorType srcType = collapseShapeOp.getSrcType();

    // Only cases where the CollapseShapeOp can be folded away entirely are
    // supported. Moreover, only simple cases where the resulting
    // InsertSliceOp has no rank-reduction anymore are supported at the
    // moment.
    RankedTensorType nonReducingInsertType =
        RankedTensorType::get(insertSliceOp.getStaticSizes(),
                              insertSliceOp.getDestType().getElementType());
    if (nonReducingInsertType != srcType)
      return failure();

    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
                                      insertSliceOp.getDest(), mixedOffsets,
                                      mixedSizes, mixedStrides);
    return success();
  }
};

/// Fold expand_shape which only adds static dimensions of size `1`
/// into insert_slice.
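///
/// A minimal illustrative example (shapes chosen for illustration only): the
/// expand only adds a unit dimension, which the insert_slice can absorb by
/// treating the insertion as rank-reducing.
/// ```
/// BEFORE:
/// %expand = tensor.expand_shape %src [[0, 1]] output_shape [1, 32]
///     : tensor<32xf32> into tensor<1x32xf32>
/// %insert = tensor.insert_slice %expand into %dst[0, 0][1, 32][1, 1]
///     : tensor<1x32xf32> into tensor<8x32xf32>
///
/// AFTER:
/// %insert = tensor.insert_slice %src into %dst[0, 0][1, 32][1, 1]
///     : tensor<32xf32> into tensor<8x32xf32>
/// ```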
template <typename OpTy>
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = insertSliceOp.getSource()
                             .template getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only fold away simple expansions where all added dimensions have
    // static size `1`.
    SliceVerificationResult res = isRankReducedType(
        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
    if (res != SliceVerificationResult::Success)
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "expected rank increasing expansion");

    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
      insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
    });
    return success();
  }
};

/// Pattern to bubble up a tensor.expand_shape op through a producer
/// tensor.collapse_shape op that has non-intersecting reassociations.
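///
/// A minimal illustrative example (shapes chosen for illustration only): the
/// collapse only touches the leading reassociation group and the expand only
/// touches the trailing group, so the two reshapes can be swapped.
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1], [2]]
///     : tensor<2x4x16xf32> into tensor<8x16xf32>
/// %expand = tensor.expand_shape %collapse [[0], [1, 2]]
///     output_shape [8, 4, 4]
///     : tensor<8x16xf32> into tensor<8x4x4xf32>
///
/// AFTER:
/// %expand = tensor.expand_shape %src [[0], [1], [2, 3]]
///     output_shape [2, 4, 4, 4]
///     : tensor<2x4x16xf32> into tensor<2x4x4x4xf32>
/// %collapse = tensor.collapse_shape %expand [[0, 1], [2], [3]]
///     : tensor<2x4x4x4xf32> into tensor<8x4x4xf32>
/// ```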
struct BubbleUpExpandThroughParallelCollapse
    : public OpRewritePattern<tensor::ExpandShapeOp> {
  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto collapseOp =
        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    auto expandReInds = expandOp.getReassociationIndices();
    auto collapseReInds = collapseOp.getReassociationIndices();

    // Special case: when the collapsed tensor to expand is a 0-D tensor,
    // the reassociation maps are empty and do not produce valid results.
    if (expandReInds.size() == 0) {
      return failure();
    }

    // The reshapes are parallel to each other (by construction, the number
    // of reassociations specified in the collapse and the expand is the
    // same) if, at every position, either
    // 1. the reassociation indices are of the same size, or
    // 2. the reassociation in the collapse or the expand is of size 1.
    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
    for (auto [expandReassociation, collapseReassociation] :
         llvm::zip_equal(expandReInds, collapseReInds)) {
      if (collapseReassociation.size() == expandReassociation.size()) {
        // Even if the reassociations are the same, the collapse/expand should
        // result in the same dimensions, i.e. 4x8x2 collapsed into 64 should
        // be expanded into 4x8x2 again. In the presence of dynamic dimensions
        // one can only verify "equality" when there is only one dynamic
        // dimension present, and all other static dimensions are equal.
        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
            collapseReassociation.front(), collapseReassociation.size());
        int64_t numCollapsedDynamic =
            llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic);
        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
            expandReassociation.front(), expandReassociation.size());
        int64_t numExpandedDynamic =
            llvm::count_if(expandedStaticShapes, ShapedType::isDynamic);
        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
            collapsedStaticShapes != expandedStaticShapes) {
          return failure();
        }
        continue;
      }
      // If the reassociations are not the same, one or the other needs to be
      // of size one.
      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
        return failure();
    }

    // Compute new reassociation indices and expanded/collapsed shapes.
    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
    Location loc = expandOp->getLoc();
    SmallVector<OpFoldResult> sourceSizes =
        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
    SmallVector<OpFoldResult> newExpandSizes;

    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
            resultSizeIndex = 0;

    for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
      auto &collapseReassociation = collapseReInds[idx];
      auto &expandReassociation = expandReInds[idx];

      // Case 1. The reassociations are the same in the collapse producer
      // and the expand consumer. In the swapped expand, each of the final
      // dimensions is kept as is in both the expand and the collapse. So,
      // for every element in the `ReassociationIndices` vector, add a new
      // `ReassociationIndices` vector (of size 1) for the swapped expand and
      // collapse.
      if (collapseReassociation.size() == expandReassociation.size()) {
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReInds.push_back({newCollapseIndex++});
          newExpandReInds.push_back({newExpandIndex++});
          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
          sourceSizeIndex++;
        }
        continue;
      }

      // Case 2. The `ReassociationIndices` in the collapse is of size > 1
      // (and in the expand is of size == 1). In this case, the original
      // dimensions are preserved on expansion and collapsed subsequently.
      if (collapseReassociation.size() != 1) {
        ReassociationIndices newCollapseReassociation;
        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
          newCollapseReassociation.push_back(newCollapseIndex++);
          newExpandReInds.push_back({newExpandIndex++});
          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
        }
        resultSizeIndex++;
        newCollapseReInds.push_back(newCollapseReassociation);
        continue;
      }

      // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
      // in the collapse is of size == 1). In this case, the expansion happens
      // first and the expanded dimensions are preserved on collapse.
      ReassociationIndices newExpandReassociation;
      for (size_t i = 0; i < expandReassociation.size(); ++i) {
        newExpandReassociation.push_back(newExpandIndex++);
        newCollapseReInds.push_back({newCollapseIndex++});
        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
      }
      newExpandReInds.push_back(newExpandReassociation);
      sourceSizeIndex++;
    }

    // Swap reshape order.
    SmallVector<Value> dynamicSizes;
    SmallVector<int64_t> staticSizes;
    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
    auto expandResultType = expandOp.getResultType().clone(staticSizes);
    Value newCollapseSrc = collapseOp.getSrc();
    // If the number of reassociation indices in the new `expand_shape` op
    // matches the number of dimensions of the result, then the expand_shape
    // is a no-op.
    if (newExpandReInds.size() != newExpandSizes.size()) {
      newCollapseSrc = tensor::ExpandShapeOp::create(
          rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds,
          newExpandSizes);
    }

    // If the number of reassociation indices in the new `collapse_shape` op
    // matches the number of dimensions of the source, then the collapse_shape
    // is a no-op.
    Value replacement = newCollapseSrc;
    if (newCollapseReInds.size() != newExpandSizes.size()) {
      replacement = tensor::CollapseShapeOp::create(
          rewriter, loc, newCollapseSrc, newCollapseReInds);
    }
    rewriter.replaceOp(expandOp, replacement);
    return success();
  }
};

/// Converts `tensor.extract_slice(tensor.expand_shape)` to
/// `tensor.expand_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible, the slice must be fully contiguous
/// within each reassociation group of the expand_shape. A slice is defined as
/// fully contiguous within a reassociation group if, after flattening the
/// reassociation group to a single 1D range, the slice taken out of the group
/// can be defined as a single contiguous subrange within that range.
///
/// Rank-reducing slices are not supported.
///
/// Example:
/// The transformation is possible because each reassociation group has a
/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
/// ```
/// BEFORE:
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
/// %slice = tensor.extract_slice %reshape ...
///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %in ...
///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
/// ```
///
/// Note: this pattern could be extended to be a swap pattern between
/// `tensor.expand_shape` and `tensor.extract_slice`, but it is currently
/// implemented only as a bubble-up pattern for `tensor.extract_slice`.
struct BubbleUpExtractSliceThroughExpandShape
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp, "tensor.extract_slice source not produced by expand_shape");
    }
    SmallVector<ReassociationIndices> reassociation =
        expandShapeOp.getReassociationIndices();

    SmallVector<OpFoldResult> offsets, sizes, strides;
    if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation,
                                            offsets, sizes, strides)))
      return failure();

    // The shape of the result can be obtained from the sizes passed in.
    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
    RankedTensorType resultType = sliceOp.getResultType();

    // Create a new ExtractSliceOp and ExpandShapeOp.
    Location loc = sliceOp.getLoc();
    Value newSliceOp = tensor::ExtractSliceOp::create(
        rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSliceOp,
        expandShapeOp.getReassociationIndices(), expandedSizes);
    return success();
  }
};

/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
/// `tensor.collapse_shape(tensor.extract_slice)`.
///
/// For this transformation to be possible, after bubbling up, the extraction
/// of the contiguous slice must be representable as a single slice obtained
/// via tensor.extract_slice within each reassociation group of the src.
///
/// In case the size and offset extracted are static, this is possible if the
/// following conditions are met within each reassociation group:
/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be
/// the shape of a desired slice. A slice of shape S can be extracted as a
/// contiguous span of elements if and only if there exists an index k in
/// {0, 1, ..., n} such that:
///   S_i = 1 for all i < k (that is, all leading dimensions are singleton),
///   1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly
///                    one dimension),
///   S_i = A_i for all i > k (that is, all trailing dimensions are preserved
///                            in full).
/// In other words, the slice shape S must be of the form:
///   [ 1, 1, ..., 1, S_k, A_{k+1}, A_{k+2}, ..., A_n ]
///
/// In case the size and/or offset extracted are dynamic, this is possible
/// only if there is a single dimension in the reassociation group that has a
/// size not equal to 1.
/// In other words, the tensor shape must be of the form:
///   [ 1, 1, ..., 1, A, 1, ..., 1 ]
/// Note: it might be possible to enable this pattern for more cases when the
/// size/offset are dynamic by performing an analysis of the possible values
/// that could be given to the size/offset.
///
/// Example:
/// The transformation is possible because each reassociation group can be
/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
/// [20->10]).
/// ```
/// BEFORE:
/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
///
/// AFTER:
/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
/// ```
///
/// Negative example:
/// The transformation is not possible because we cannot use a single slice
/// to represent the reassociation group [2x3x10->???]. If we wanted the
/// collapse to happen after the extraction, we would need to extract
/// multiple slices and concat them together.
/// ```
/// %collapse = tensor.collapse_shape %src [[0, 1, 2]]
///     : tensor<2x3x10xf32> into tensor<60xf32>
/// %extract = tensor.extract_slice %collapse[0][15][1]
///     : tensor<60xf32> to tensor<15xf32>
/// ```
/// If we wanted the collapse to happen after the extraction, a possible
/// alternative transformation could be to extract multiple slices and concat
/// them together:
/// ```
/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10]
///     : tensor<2x3x10xf32> to tensor<1x1x10xf32>
/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5]
///     : tensor<2x3x10xf32> to tensor<1x1x5xf32>
/// %concat = tosa.concat %extract_1, %extract_2 {axis = 2 : i32}
///     : (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]]
///     : tensor<1x1x15xf32> to tensor<15xf32>
/// ```
/// But this is not the intended purpose of the transformation.
struct BubbleUpExtractSliceThroughCollapseShape
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp =
        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseShapeOp) {
      return rewriter.notifyMatchFailure(
          sliceOp,
          "tensor.extract_slice source not produced by tensor.collapse_shape");
    }

    SmallVector<OpFoldResult> offsets, sizes, strides;
    if (failed(getExpandedExtractSliceInfo(
            rewriter, sliceOp, collapseShapeOp.getReassociationIndices(),
            collapseShapeOp.getSrc(), offsets, sizes, strides)))
      return failure();

    Value newSliceOp = tensor::ExtractSliceOp::create(
        rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
        sizes, strides);
    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        sliceOp, sliceOp.getResultType(), newSliceOp,
        collapseShapeOp.getReassociationIndices());

    return success();
  }
};

} // namespace

LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
    OpBuilder &b, tensor::ExtractSliceOp sliceOp,
    ArrayRef<ReassociationIndices> reassociation,
    SmallVectorImpl<OpFoldResult> &collapsedOffsets,
    SmallVectorImpl<OpFoldResult> &collapsedSizes,
    SmallVectorImpl<OpFoldResult> &collapsedStrides) {
  if (!sliceOp.hasUnitStride()) {
    return failure();
  }

  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();

  if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
    return failure();
  }

  auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
                                     OpFoldResult sliceSize, int64_t inputDim) {
    if (!isZeroInteger(offset))
      return false;
    ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
    FailureOr<bool> maybeEqual =
        ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
    return llvm::succeeded(maybeEqual) && maybeEqual.value();
  };

  // Check that the slice is contiguous within each reassociation group.
  // The slice is contiguous only if, after the first dimension where a
  // non-unit slice is taken, the slice size on all subsequent dimensions of
  // the group is equal to the entire size of the dimension.
  // Examples of contiguous slices:
  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
  //   full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
  // Examples of non-contiguous slices:
  //   full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
  //   full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
  for (const ReassociationIndices &indices : reassociation) {
    int64_t i = 0;
    int64_t e = indices.size();
    // Find the first expanded dim after the first dim with non-unit extracted
    // size.
    for (; i < e; ++i) {
      if (!isOneInteger(sizes[indices[i]])) {
        // +1 to skip the first non-unit size dim.
        i++;
        break;
      }
    }

    // Verify that all subsequent dimensions extract the full size of the
    // source tensor.
    for (; i < e; ++i) {
      int64_t expandedDim = indices[i];
      if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
                                   expandedDim)) {
        return failure();
      }
    }
  }

  // The tensor.extract_slice before applying the pattern works on the result
  // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
  // referring to the state before applying the pattern are named with the
  // prefix "expanded", and ones referring to the state after applying the
  // pattern are named with the prefix "collapsed".
  Location loc = sliceOp.getLoc();
  SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
  SmallVector<OpFoldResult> expandedShape =
      getMixedSizes(b, loc, sliceOp.getSource());

  // Helper variables and function for accumulating the size values.
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  // Multiply two integers.
  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
    auto mulMap = AffineMap::get(2, 0, {d0 * d1});
    return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
  };

  // Compute new offsets, sizes, and strides for tensor.extract_slice.
  // The new tensor.extract_slice will work on a tensor that has a rank of
  // ReassociationIndices.size(). In the loop, a single offset, size, and
  // stride value is computed per reassociation group.
  for (const ReassociationIndices &indices : reassociation) {
    // collapsedSize will hold the size of the single dim that represents the
    // reassociation group in the non-expanded tensor.
    OpFoldResult collapsedSize = b.getIndexAttr(1);
    // The reassocGroupSizes and reassocGroupOffsets are used to create an
    // affine.linearize_index op that linearizes the single offset value
    // required for this reassociation group.
    SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
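    // For example (an illustrative sketch): for a group with expanded sizes
    // [4, 8] and expanded offsets [%i, %j], the collapsed size accumulated
    // below is the product of the extracted sizes, and the collapsed offset
    // linearized below is %i * 8 + %j.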

    for (long expandedDim : indices) {
      // reassocGroupSizes and reassocGroupOffsets can be obtained directly
      // from the expanded state, but the collapsed size requires calculation
      // as it did not previously exist.
      reassocGroupSizes.push_back(expandedShape[expandedDim]);
      reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
      collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
    }

    SmallVector<Value> offsetVals =
        llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
          return getValueOrCreateConstantIndexOp(b, loc, ofr);
        });
    OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
                                       b, loc, offsetVals, reassocGroupSizes,
                                       /*disjoint=*/true)
                                       .getResult();
    collapsedOffsets.push_back(collapsedOffset);
    collapsedSizes.push_back(collapsedSize);

    // Only unit stride is supported.
    collapsedStrides.push_back(b.getIndexAttr(1));
  }
  return success();
}

// Checks if `ofr` is a multiple of `factor`.
// Handles both static integers and dynamic values
// where the value is the result of an affine.apply.
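//
// For example (an illustrative sketch): a static value of 8 is a multiple of
// 4, and so is a dynamic value defined by
//   %v = affine.apply affine_map<(d0) -> (d0 * 4)>(%i)
// because the (composed and simplified) affine expression d0 * 4 is a
// multiple of 4.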
static bool isMultipleOf(OpFoldResult ofr, int64_t factor) {
  std::optional<int64_t> staticValue = getConstantIntValue(ofr);
  if (staticValue.has_value())
    return staticValue.value() % factor == 0;

  Value value = dyn_cast<Value>(ofr);
  if (!value)
    return false;
  auto applyOp = value.getDefiningOp<affine::AffineApplyOp>();
  if (!applyOp)
    return false;
  AffineMap map = applyOp.getAffineMap();
  SmallVector<Value> operands(applyOp.getOperands());
  affine::fullyComposeAffineMapAndOperands(&map, &operands);
  map = simplifyAffineMap(map);
  if (map.getNumResults() != 1)
    return false;
  return map.getResult(0).isMultipleOf(factor);
}

/// Given a `collapsedOffset` and `collapsedSize`, this function
/// validates that the slice is representable as a contiguous slice
/// in the `expandedShape` and computes the corresponding expanded sizes.
/// Returns failure if the slice cannot be guaranteed to be contiguous.
/// On success, populates `groupSizes` with the expanded sizes for each
/// dimension in the reassociation group.
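///
/// Worked example (an illustrative sketch): with expandedShape = [4, 5, 10],
/// a static collapsed slice of size 20 at offset 50 yields
/// groupSizes = [1, 2, 10]: the trailing dim takes its full size 10
/// (20 / 10 = 2 remains), the next dim takes the remaining size 2, and all
/// leading dims take size 1. The caller later delinearizes the offset.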
static LogicalResult computeExpandedSliceInfoForReassocGroup(
    OpBuilder &b, OpFoldResult collapsedSize, OpFoldResult collapsedOffset,
    const ReassociationIndices &reassocIndices, ArrayRef<int64_t> expandedShape,
    SmallVectorImpl<OpFoldResult> &groupSizes) {
  assert(groupSizes.empty() && "Group sizes must be empty");
  // The first case is when there's only one non-unit dimension in the
  // reassociation group.
  // When there's only one non-unit dimension, the slice is trivially
  // contiguous: the offset and size go directly on that dimension.
  // This works for both dynamic sizes and dynamic offsets.
  int nonUnitSizeCount = llvm::count_if(
      reassocIndices, [&expandedShape](int64_t expandedShapeIdx) {
        return expandedShape[expandedShapeIdx] != 1;
      });
  if (nonUnitSizeCount == 1) {
    for (int64_t expandedShapeIdx : reassocIndices) {
      if (expandedShape[expandedShapeIdx] != 1)
        groupSizes.push_back(collapsedSize);
      else
        groupSizes.push_back(b.getIndexAttr(1));
    }
    return success();
  }

  // A dynamic extracted size would require additional complex analysis to
  // guarantee contiguous slicing.
  if (isa<Value>(collapsedSize))
    return failure();

  std::optional<int64_t> staticSize = getConstantIntValue(collapsedSize);
  assert(staticSize.has_value() && "Expected static size");

  // The extracted size is only one element; the offset may be static or
  // dynamic. This is a trivial case where we can always guarantee contiguous
  // slicing.
  if (staticSize.value() == 1) {
    for (size_t i = 0; i < reassocIndices.size(); ++i)
      groupSizes.push_back(b.getIndexAttr(1));

    return success();
  }

  // The size is static and greater than 1; the offset may be static or
  // dynamic. Use a traversal to find the dimension k where slicing occurs,
  // and verify that the slice can be represented as a contiguous slice of
  // the src of the collapse_shape.
  // The check proceeds from the innermost dimensions outward, so the
  // traversal is done in reverse order of the reassociation group.
  // If the expected slice shape is [1, 1, ..., 1, Sk, Ak+1, Ak+2, ..., An],
  // we first find the size and offset for dims n...k+1, then for dim k, and
  // then for dims k-1...0.

  // currentCollapsedsize is initialized with the original collapsed size
  // and divided by the expanded shape size in each dimension as we go along
  // the reassociation group. In essence, we are spreading the original
  // collapsed size over the various expanded slice dimensions.
  // currentOffsetDivisor is initialized with 1 and multiplied by the expanded
  // shape size in each dimension as we go along the reassociation group.
  // These variables are used both to check the validity of the slice and to
  // compute the expanded sizes and offsets.
  assert(staticSize.value() > 1 && "Expected size to be greater than 1");
  int64_t currentCollapsedsize = staticSize.value();
  int64_t currentOffsetDivisor = 1;

  ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
                                              reassocIndices.rend());
  int64_t idx = 0;
  int64_t reassocGroupSize = reassocIndices.size();

  // First handle the trailing dimensions where the slice size should be
  // equal to the tensor shape and the offset should be 0 (n...k+1).
  for (; idx < reassocGroupSize; ++idx) {
    int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
    if (expandedShapeSize == ShapedType::kDynamic)
      return failure();

    if (currentCollapsedsize < expandedShapeSize)
      break;

    // Check size divisibility.
    if ((currentCollapsedsize % expandedShapeSize) != 0)
      return failure();

    // Check dynamic/static offset divisibility.
    currentOffsetDivisor *= expandedShapeSize;
    if (!isMultipleOf(collapsedOffset, currentOffsetDivisor))
      return failure();

    // Trailing dims get the full shape and a zero offset.
    groupSizes.push_back(b.getIndexAttr(expandedShapeSize));
    currentCollapsedsize /= expandedShapeSize;
  }

  // Now handle the first dim where slicing occurs (k).
  if (idx < reassocGroupSize) {
    int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
    std::optional<int64_t> staticOffset = getConstantIntValue(collapsedOffset);

    if (staticOffset.has_value()) {
      // Static offset: check that offset + size doesn't exceed the dimension.
      int64_t offsetInDim =
          (staticOffset.value() / currentOffsetDivisor) % expandedShapeSize;
      if ((currentCollapsedsize + offsetInDim) > expandedShapeSize)
        return failure();
    } else {
      // If the offset is dynamic, we apply more restrictive conditions to
      // guarantee contiguous slicing: the dimension must be divisible by the
      // slice size and the offset must be a multiple of the slice size. For
      // more complex cases, the ValueBoundsOpInterface could be used to check
      // the validity of the range.
      if ((expandedShapeSize % currentCollapsedsize) != 0)
        return failure();
      if (!isMultipleOf(collapsedOffset, staticSize.value()))
        return failure();
    }
    // The slicing dimension gets the remaining collapsed size.
    groupSizes.push_back(b.getIndexAttr(currentCollapsedsize));
  }

  // Now handle the leading dimensions where the slice size is equal to 1
  // (k-1...0).
  // The size for these dimensions must be 1 because of how we constructed
  // the slice size of the expanded shape. We spread the original collapsed
  // size over the expanded shape sizes until we reached dimension k, where
  // the remaining size was smaller than the expanded shape size, and spread
  // the remaining size on it. So now we are left with only 1s.
  for (idx++; idx < reassocGroupSize; ++idx)
    groupSizes.push_back(b.getIndexAttr(1));

  // Sizes were built in reverse order, so reverse them.
  groupSizes = llvm::to_vector(llvm::reverse(groupSizes));
  return success();
}

LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
    OpBuilder &b, tensor::ExtractSliceOp sliceOp,
    ArrayRef<ReassociationIndices> reassociation, Value expandedValue,
    SmallVectorImpl<OpFoldResult> &expandedOffsets,
    SmallVectorImpl<OpFoldResult> &expandedSizes,
    SmallVectorImpl<OpFoldResult> &expandedStrides) {
  if (!sliceOp.hasUnitStride()) {
    return failure();
  }

  // The tensor.extract_slice before applying the pattern works on the result
  // of the tensor.collapse_shape, so variables (i.e. inputs for
  // ExtractSliceOp) referring to the state before applying the pattern are
  // named with the prefix "collapsed", and ones referring to the state after
  // applying the pattern are named with the prefix "expanded".
  SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
  if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
      collapsedSizes.size()) {
    return failure();
  }

  // Compute new offsets, sizes, and strides for tensor.extract_slice.
  // The new tensor.extract_slice will work on a tensor that has a rank
  // equal to the rank of the src of the collapse_shape. In each iteration of
  // the loop, the offsets and sizes are computed per reassociation group.
  ArrayRef<int64_t> expandedShape =
      cast<RankedTensorType>(expandedValue.getType()).getShape();
  SmallVector<SmallVector<OpFoldResult>> groupResults;
  for (auto [collapsedSize, collapsedOffset, reassocIndices] :
       llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {

    SmallVector<OpFoldResult> groupSizes;
    LogicalResult result = computeExpandedSliceInfoForReassocGroup(
        b, collapsedSize, collapsedOffset, reassocIndices, expandedShape,
        groupSizes);
    if (failed(result))
      return failure();
    groupResults.emplace_back(groupSizes);
  }

  expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
  for (auto [groupIdx, reassocIndices] : llvm::enumerate(reassociation)) {
    auto &sizes = groupResults[groupIdx];
    expandedSizes.append(sizes);

    SmallVector<OpFoldResult> basis;
    for (int64_t expandedShapeIdx : reassocIndices)
      basis.push_back(tensor::getMixedSize(b, sliceOp.getLoc(), expandedValue,
                                           expandedShapeIdx));
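
    // For example (an illustrative sketch): delinearizing a collapsed offset
    // of 50 over the basis [4, 5, 10] yields the expanded offsets [1, 0, 0],
    // since 50 = 1 * (5 * 10) + 0 * 10 + 0.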
    OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
    Value offsetVal =
        getValueOrCreateConstantIndexOp(b, sliceOp.getLoc(), collapsedOffset);
    auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
        b, sliceOp.getLoc(), offsetVal, basis, /*hasOuterBound=*/true);
    for (OpResult result : delinearizeOp.getResults())
      expandedOffsets.push_back(result);
  }
  return success();
}

void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
          patterns.getContext());
}

void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
}

void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<BubbleUpExtractSliceThroughExpandShape,
               BubbleUpExtractSliceThroughCollapseShape>(
      patterns.getContext());
}
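
// A minimal usage sketch (illustrative, not part of this file; `op` and `ctx`
// are assumed to be a suitable Operation* and MLIRContext*): these populate
// functions are typically combined with the greedy pattern driver:
//
//   RewritePatternSet patterns(ctx);
//   tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));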