MLIR 23.0.0git
ReshapeOpsUtils.h
Go to the documentation of this file.
//===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities and common canonicalization patterns for
// reshape operations.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H

// NOTE(review): the extracted listing skipped original lines 17-20, which
// likely held additional #include directives — confirm against upstream.
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include <optional>

namespace mlir {
30
31/// Attribute name for the ArrayAttr which encodes reassociation indices.
32constexpr StringRef getReassociationAttrName() { return "reassociation"; }
33
34/// Compose reassociation maps that are used in pair of reshape ops where one
35/// is a producer and other is the consumer. Only valid to use this method when
36/// both the producer and consumer are collapsing dimensions or both are
37/// expanding dimensions.
38///
39/// For example,
40/// producerReassociation = [[0, 1], [2], [3, 4]]
41/// consumerReassociation = [[0, 1], [2]]
42///
43/// is folded into
44///
45/// result = [[0, 1, 2], [3, 4]].
46std::optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
47 ArrayRef<ReassociationIndices> producerReassociations,
48 ArrayRef<ReassociationIndices> consumerReassociations,
49 MLIRContext *context);
50
51/// Convert reassociation indices to affine expressions.
52SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
53 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
54
55/// Constructs affine maps out of Array<Array<AffineExpr>>.
56SmallVector<AffineMap, 4>
57getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);
58
59/// Wraps a list of reassociations in an ArrayAttr.
62 ArrayRef<ReassociationIndices> reassociation);
63
64/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
65SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
66 ArrayRef<ReassociationExprs> reassociationExprs);
67
68/// Return the reassociations maps to use to reshape given the source type and
69/// the target type when possible. Return std::nullopt when this computation
70/// failed.
71std::optional<SmallVector<ReassociationIndices>>
72getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
73
74/// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
75/// possible.
76std::optional<SmallVector<ReassociationIndices>>
77getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
78 ArrayRef<int64_t> targetShape);
79
80/// Return true if the reassociation specification is valid, false otherwise.
81/// When false, the `invalidIndex` integer pointer is optionally filled with the
82/// index of the offending reassociation map.
83bool isReassociationValid(ArrayRef<AffineMap> reassociation,
84 int *invalidIndex = nullptr);
85
86template <typename ReshapeOpTy, typename InverseReshapeOpTy>
87static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
88 ArrayRef<Attribute> operands) {
89 // Fold identity reshape.
90 if (reshapeOp.getSrcType() == reshapeOp.getType())
91 return reshapeOp.getSrc();
92
93 // Reshape of a constant can be replaced with a new constant, but only when
94 // the result type has a static shape. DenseElementsAttr::reshape requires
95 // a static shape to preserve the element count invariant.
96 if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
97 auto resultType = cast<ShapedType>(reshapeOp.getResult().getType());
98 if (resultType.hasStaticShape())
99 return elements.reshape(resultType);
100 }
101
102 // Fold if the producer reshape source has the same shape with at most 1
103 // dynamic dimension.
104 auto reshapeSrcOp =
105 reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
106 if (!reshapeSrcOp)
107 return nullptr;
108 auto srcType = reshapeSrcOp.getSrcType();
109 auto resultType = reshapeOp.getResultType();
110 if (srcType != resultType)
111 return nullptr;
112
113 if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
114 return reshapeSrcOp.getSrc();
115 }
116
117 // Fold producer-consumer reshape ops when they are perfect inverses of each
118 // other:
119 // 1) Reassociation indices are equivalent.
120 // 2) Boundary types are equivalent.
121 // 3) No reassociations have more than 1 dynamic dimension, and reassociated
122 // shapes are equal for each reassociation.
123 auto reassociations = reshapeOp.getReassociationIndices();
124 if (reassociations != reshapeSrcOp.getReassociationIndices())
125 return nullptr;
126 // If the reshapes are expanding and then collapsing, the ops can be folded
127 // despite multiple dynamic dimensions.
128 if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
129 return reshapeSrcOp.getSrc();
130 if (llvm::all_of(reassociations, [&](auto reInd) {
131 ArrayRef<int64_t> srcSlice =
132 srcType.getShape().slice(reInd.front(), reInd.size());
133 return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
134 })) {
135 return reshapeSrcOp.getSrc();
136 }
137 return nullptr;
138}
139
140/// Common verifier for reshape-like types. Fills `expandedType` and
141///`collapsedType` with the proper `src` or `result` type.
142template <typename Op, typename T>
143static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
144 T collapsedType, bool isExpansion) {
145
146 unsigned expandedRank = expandedType.getRank();
147 unsigned collapsedRank = collapsedType.getRank();
148 if (expandedRank < collapsedRank)
149 return op.emitOpError("expected the expanded type, ")
150 << expandedType << " to have a higher (or same) rank "
151 << "than the collapsed type, " << collapsedType << '.';
152
153 if (collapsedRank != op.getReassociation().size())
154 return op.emitOpError("expected collapsed rank (")
155 << collapsedRank << ") to equal the number of reassociation maps ("
156 << op.getReassociation().size() << ").";
157
158 auto maps = op.getReassociationMaps();
159 for (auto it : llvm::enumerate(maps))
160 if (it.value().getNumDims() != expandedRank)
161 return op.emitOpError("expected reassociation map #")
162 << it.index() << " to have size equal to the expanded rank ("
163 << expandedRank << "), but it is " << it.value().getNumDims()
164 << '.';
165
166 int invalidIdx = 0;
167 if (!isReassociationValid(maps, &invalidIdx))
168 return op.emitOpError("expected reassociation map #")
169 << invalidIdx << " to be valid and contiguous.";
170
171 return reshapeLikeShapesAreCompatible(
172 [&](const Twine &msg) { return op->emitOpError(msg); },
173 collapsedType.getShape(), expandedType.getShape(),
174 op.getReassociationIndices(), isExpansion);
175}
176
177/// Verify that shapes of the reshaped types using following rule:
178/// if a dimension in the collapsed type is static, then the corresponding
179/// dimensions in the expanded shape should be
180/// a) static
181/// b) the product should be same as the collaped shape.
182LogicalResult reshapeLikeShapesAreCompatible(
183 function_ref<LogicalResult(const Twine &)> emitError,
184 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
185 ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
186
187/// Returns true iff the type is a MemRefType and has a non-identity layout.
188bool hasNonIdentityLayout(Type type);
189
190enum class ReshapeOpKind { kExpand, kCollapse };
191
192/// Pattern to collapse producer/consumer reshape ops that are both collapsing
193/// dimensions or are both expanding dimensions.
194template <typename ReshapeOpTy, ReshapeOpKind opKind>
195struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
196 using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
197 LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
198 PatternRewriter &rewriter) const override {
199 auto srcReshapeOp =
200 reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
201 if (!srcReshapeOp)
202 return failure();
203
204 ShapedType resultType = reshapeOp.getResultType();
205
206 if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
207 hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
208 hasNonIdentityLayout(reshapeOp.getResult().getType()))
209 return failure();
210
211 std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
212 composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
213 reshapeOp.getReassociationIndices(),
214 rewriter.getContext());
215 if (!reassociationIndices)
216 return failure();
217
218 if constexpr (opKind == ReshapeOpKind::kExpand) {
219 SmallVector<OpFoldResult> outputShape(
220 getMixedValues(reshapeOp.getStaticOutputShape(),
221 reshapeOp.getOutputShape(), rewriter));
222 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
223 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
224 outputShape);
225 } else {
226 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
227 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
228 }
229 return success();
230 }
231};
232
233/// Pattern to compose
234/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
235/// In that case both `srcType` and `resultType` can be expressed as a function
236/// of `intermediateType`.
237/// In order to demonstrate the approach, let's assume that `rank(srcType) >
238/// `rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
239/// In that case, we can iterate over every set of indices in `reassociation_2`
240/// and try to find ids of sets of indices in `reassociation_1` that cover it
241/// completely.
242///
243/// Example:
244///
245/// %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
246/// : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
247/// %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
248/// : tensor<?x?x?x1xi64> into tensor<?x?xi64>
249///
250/// can be canonicalized into
251///
252/// %0 = tensor.collapse_shape %arg [[0, 1], [2]]
253/// : tensor<?x?x?xi64> into tensor<?x?xi64>
254///
255/// because [0] and [1] from `expand_shape` reassociation cover completely
256/// `[0, 1]` from `collapse_shape`. If it is impossible to find such union of
257/// indices, then we fail.
258//
259/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
260/// `reassociation_2` and produce `expand_shape`.
261template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
262 typename DimOpTy, typename TensorTy>
263struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
264 using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
265 LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
266 PatternRewriter &rewriter) const override {
267 auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
268 if (!expandOp)
269 return failure();
270
271 ShapedType srcType = expandOp.getSrcType();
272 ShapedType resultType = collapseOp.getResultType();
273
274 if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
275 hasNonIdentityLayout(expandOp.getSrc().getType()) ||
276 hasNonIdentityLayout(expandOp.getResult().getType()))
277 return failure();
278
279 int64_t srcRank = srcType.getRank();
280 int64_t resultRank = resultType.getRank();
281 if (srcType == resultType)
282 return failure();
283
284 SmallVector<ReassociationIndices, 4> higherRankReassociation,
285 lowerRankReassociation;
286
287 if (srcRank > resultRank) {
288 higherRankReassociation = expandOp.getReassociationIndices();
289 lowerRankReassociation = collapseOp.getReassociationIndices();
290 } else {
291 higherRankReassociation = collapseOp.getReassociationIndices();
292 lowerRankReassociation = expandOp.getReassociationIndices();
293 }
294
295 size_t higherRankIndicesID = 0;
296 SmallVector<ReassociationIndices, 4> composedReassociation;
297 for (const auto &lowerRankIndices : lowerRankReassociation) {
298 ReassociationIndices composedIndices;
299 while (higherRankIndicesID < higherRankReassociation.size()) {
300 auto rightmostIndex =
301 higherRankReassociation[higherRankIndicesID].back();
302 if (rightmostIndex > lowerRankIndices.back())
303 return failure();
304 composedIndices.push_back(higherRankIndicesID++);
305 if (rightmostIndex == lowerRankIndices.back())
306 break;
307 }
308 composedReassociation.push_back(composedIndices);
309 }
310 if (srcRank > resultRank) {
311 rewriter.replaceOpWithNewOp<CollapseOpTy>(
312 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
313 } else if (srcRank < resultRank) {
314 // Compute the dynamic output shape for the new expand_shape op.
315 Location loc = collapseOp.getLoc();
316 SmallVector<OpFoldResult> origOutputShape =
317 expandOp.getMixedOutputShape();
318 SmallVector<OpFoldResult> newOutputShape;
319 for (const ReassociationIndices &indices :
320 collapseOp.getReassociationIndices()) {
321 int64_t numStaticElems = 1;
322 SmallVector<Value> dynamicSizes;
323 for (int64_t idx : indices) {
324 OpFoldResult size = origOutputShape[idx];
325 if (std::optional<int64_t> maybeCst = getConstantIntValue(size)) {
326 numStaticElems *= maybeCst.value();
327 continue;
328 }
329 dynamicSizes.push_back(cast<Value>(size));
330 }
331 if (dynamicSizes.empty()) {
332 newOutputShape.push_back(rewriter.getIndexAttr(numStaticElems));
333 continue;
334 }
335
336 // There is at least one dynamic size, so we can initialize `result` to
337 // the first dynamic size.
338 Value result = dynamicSizes[0];
339 for (Value v : llvm::drop_begin(dynamicSizes))
340 result = arith::MulIOp::create(rewriter, loc, result, v,
341 arith::IntegerOverflowFlags::nsw);
342 if (numStaticElems != 1) {
343 result = arith::MulIOp::create(
344 rewriter, loc, result,
345 arith::ConstantIndexOp::create(rewriter, loc, numStaticElems),
346 arith::IntegerOverflowFlags::nsw);
347 }
348 newOutputShape.push_back(result);
349 }
350 rewriter.replaceOpWithNewOp<ExpandOpTy>(
351 collapseOp, resultType, expandOp.getSrc(), composedReassociation,
352 newOutputShape);
353 } else {
354 // Collapses/expansions that do not change the rank are not allowed. Use
355 // a cast instead.
356 assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
357 "expected same shape");
358 rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
359 expandOp.getSrc());
360 }
361 return success();
362 }
363};
364
365template <typename ExpandOpTy, typename CollapseOpTy, typename CastOpTy>
366struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
367 using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
368 LogicalResult matchAndRewrite(ExpandOpTy expandOp,
369 PatternRewriter &rewriter) const override {
370 auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
371 if (!collapseOp)
372 return failure();
373
374 ShapedType srcType = collapseOp.getSrcType();
375 ShapedType resultType = expandOp.getResultType();
376
377 if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
378 hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
379 hasNonIdentityLayout(collapseOp.getResult().getType())) {
380 if (srcType.hasStaticShape() &&
381 CastOpTy::areCastCompatible(srcType, resultType)) {
382 rewriter.replaceOpWithNewOp<CastOpTy>(expandOp, resultType,
383 collapseOp.getSrc());
384 return success();
385 }
386 return failure();
387 }
388
389 int64_t srcRank = srcType.getRank();
390 int64_t resultRank = resultType.getRank();
391 if (srcRank == resultRank)
392 return failure();
393
394 auto srcReassociation = collapseOp.getReassociationIndices();
395 auto resultReassociation = expandOp.getReassociationIndices();
396 if (srcRank > resultRank) {
397 auto composedReassociation = findCollapsingReassociation(
398 srcReassociation, resultReassociation, srcType.getShape(),
399 resultType.getShape());
400 if (!composedReassociation)
401 return failure();
402
403 rewriter.replaceOpWithNewOp<CollapseOpTy>(
404 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
405 return success();
406 }
407 auto composedReassociation =
408 findCollapsingReassociation(resultReassociation, srcReassociation,
409 resultType.getShape(), srcType.getShape());
410 if (!composedReassociation)
411 return failure();
412
414 expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
415 rewriter.replaceOpWithNewOp<ExpandOpTy>(
416 expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
417 outputShape);
418 return success();
419 }
420
421private:
422 // Attempts to find a way to collapse `srcShape` to `resultShape` by
423 // collapsing subshapes defined by the reassociation indices.
424 std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
425 ArrayRef<ReassociationIndices> srcReassociation,
426 ArrayRef<ReassociationIndices> resultReassociation,
427 ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
428 SmallVector<ReassociationIndices, 4> composedReassociation;
429
430 if (srcReassociation.empty())
431 return {getReassociationIndicesForCollapse(srcShape, resultShape)};
432
433 for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
434 auto &srcIndices = std::get<0>(item);
435 auto &resultIndices = std::get<1>(item);
436 auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
437 auto resultSubShape =
438 resultShape.slice(resultIndices.front(), resultIndices.size());
439
440 if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 &&
441 llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2)
442 return std::nullopt;
443
444 if (srcSubShape.size() == resultSubShape.size()) {
445 if (srcSubShape != resultSubShape)
446 return std::nullopt;
447
448 for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
449 composedReassociation.emplace_back(1, srcIndices.front() + index);
450 }
451 continue;
452 }
453
454 // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
455 auto subShapeReassociation =
456 getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
457 if (!subShapeReassociation)
458 return std::nullopt;
459
460 // Remap the subshape indices back to the original srcShape.
461 for (auto &subshapeIndices : *subShapeReassociation) {
462 ReassociationIndices shapeIndices;
463 for (int64_t index : subshapeIndices)
464 shapeIndices.push_back(srcIndices.front() + index);
465 composedReassociation.push_back(shapeIndices);
466 }
467 }
468 return {std::move(composedReassociation)};
469 }
470};
471
472/// The input parameters `offsets`, `sizes`, `strides` specify a rectangular
473/// non rank-reducing slice of the collapse_shape output. Try to find which
474/// dimensions have been sliced and which dimensions are not sliced (offset = 0,
475/// size = dim, size = 1). Note that this conservative as it cannot detect if a
476/// dynamic size corresponds to the full tensor dimension or not.
477llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
478 ArrayRef<Range> sliceParams);
479
480/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
481/// inspecting its reassociation indices.
482llvm::SmallBitVector
484
485/// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
486/// chain and reified source and result shapes of the CollapseShapeOp, this
487/// class provides two functions that assist with directly forming the result
488/// of the extract slice by "tiling the CollapseShapeOp by 1".
489//// Example:
490// clang-format off
491/// ```
492/// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
493/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<341x10xf32>
494/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
495/// ```
496/// This class helps build the below IR to replace %2:
497/// ```
498/// %dest = tensor.empty() : tensor<10x10xf32>
499/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
500/// %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 11)>(%iv)
501/// %3:3 = arith.delinearize_index %iv into (3, 7, 11)
502///
503/// // This function takes %3 (multiIndices) and the parameters for the slice below.
504/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
505/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
506///
507/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
508/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
509/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
510/// tensor<1x10xf32> into tensor<10x10xf32>
511/// scf.yield %6 : tensor<10x10xf32>
512/// }
513/// ```
514// clang-format on
515class SliceFromCollapseHelper {
516public:
517 SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
518 ArrayRef<OpFoldResult> collapseShapeInputShape,
519 ArrayRef<OpFoldResult> collapseShapeOutputShape,
520 ArrayRef<Range> extractSliceParams)
521 : reassociationIndices(reassociationIndices),
522 collapseShapeInputShape(collapseShapeInputShape),
523 collapseShapeOutputShape(collapseShapeOutputShape),
524 sliceParams(extractSliceParams),
525 linearizedDimensions(getLinearizedDimensions(reassociationIndices)),
526 slicedDimensions(getSlicedDimensions(collapseShapeOutputShape,
527 extractSliceParams)) {}
528
529 /// This function takes multi-indices and maps them to ExtractSlice parameters
530 /// in the index space of the CollapseShape's source tensor. This function's
531 /// signature can be described by `(D_0, D_1,.. D_{n-1}) -> (offsets, sizes,
532 /// strides)` where `n` the number of "tiled dimensions", which are the
533 /// dimensions of the output that are linearized by the collapse shape op and
534 /// are also sliced. Each `D_i` is a tuple that must represent a valid
535 /// multi-index for the `i-th` tiled dimension. In the example above, there is
536 /// only one tiled dimension (D_0) and `arith.delinearize_index` produces the
537 /// multi-index (%3) that would be passed to this function to generate the
538 /// parameters for the `tensor.extract_slice` op (%4).
539 SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
540 ArrayRef<ValueRange> multiIndices);
541
542 /// This function takes indices in the index space of the "tiled dimensions"
543 /// described above and returns a set of Range variables that describe how the
544 /// slice should be inserted into the destination. In the example above, `%iv`
545 /// would be passed to this function to generate the parameters for the
546 /// `tensor.insert_slice` op producing %6.
547 SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
548 ValueRange tileIndices);
549
550private:
551 SmallVector<ReassociationIndices> reassociationIndices;
552 SmallVector<OpFoldResult> collapseShapeInputShape;
553 SmallVector<OpFoldResult> collapseShapeOutputShape;
554 SmallVector<Range> sliceParams;
555 llvm::SmallBitVector linearizedDimensions;
556 llvm::SmallBitVector slicedDimensions;
557};
558
559/// Parameters required to simplify a collapsing reshape op with a rank-reducing
560/// slice operation. See `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
561struct CollapseShapeRankReducingSliceSimplificationInfo {
562 /// The shape of the output of the rank-reducing slice.
563 RankedTensorType sliceResultType;
564 /// The reassociation indices for the new collapse shape op, if required. If
565 /// `std::nullopt`, the slice should replace the collapse shape op.
566 std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
567};
568
569/// A collapsing reshape operation can sometimes be simplified or eliminated by
570/// inserting a single rank-reducing slice operation between it and the source
571/// tensor. The slice op will either take the place of the source, allowing for
572/// a new, simpler reshape op to replace the original, or the reshape op will be
573/// completely replaced by the slice result.
574///
575/// This function returns the parameters required to implement this pattern. If
576/// the pattern is not applicable, then failure is returned.
577///
578/// ### Example:
579/// ```
580/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
581/// : tensor<?x1x30x10xf32> to tensor<?x300xf32>
582/// ```
583/// can be transformed to
584/// ```
585/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
586/// [0, %dim1, 30, 30]
587/// [1, 1, 1 1]
588/// : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
589/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
590/// : tensor<?x30x10xf32> to tensor<?x300xf32>
591/// ```
592///
593/// ### Example:
594/// ```
595/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
596/// : tensor<?x1x30xf32> to tensor<?x30xf32>
597/// ```
598/// can be transformed to
599/// ```
600/// %result = tensor.extract_slice %1 [0, 0, 0]
601/// [%dim2, 1, 30]
602/// [1, 1, 1]
603/// : tensor<?x1x30xf32> to tensor<?x30xf32>
604/// ```
605FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
606getSimplifyCollapseShapeWithRankReducingSliceInfo(
607 RankedTensorType sourceType,
608 ArrayRef<ReassociationIndices> reassociationIndices);
609
610struct PackingMetadata {
611 SmallVector<int64_t> insertPositions;
612 SmallVector<int64_t> outerPositions;
613 SmallVector<ReassociationIndices> reassociations;
614};
615
616/// Given a vector of `positions` indices representing desired packing insertion
617/// points into a target vector (i.e. pack/unpack.inner_dim_pos), compute the
618/// final positions in the target shape as well as the reshape reassociations.
619// Note: This should not be called with a large positions array (or the
620// implementation needs to be updated to use an N.log N sort instead of
621// repeated N^2 counts).
622PackingMetadata computePackingMetadata(int64_t packedRank,
623 ArrayRef<int64_t> innerDimPos);
624
625/// Try to remove a tensor operation if it would only reshape a constant.
626/// Removes the op and replaces the constant with a new constant of the result
627/// shape. When an optional cst attribute is passed, it is reshaped only if the
628/// splat value matches the value in the attribute.
629OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
630 std::optional<Attribute> cst = std::nullopt);
} // namespace mlir

#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
An attribute that represents a reference to a dense vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
ArrayRef< int64_t > ReassociationIndicesRef
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
SmallVector< AffineExpr, 2 > ReassociationExprs
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
Common verifier for reshape-like types.
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})