MLIR  21.0.0git
ReshapeOpsUtils.h
Go to the documentation of this file.
1 //===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities and common canonicalization patterns for
10 // reshape operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
16 
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Support/LLVM.h"
22 #include "llvm/ADT/StringRef.h"
23 #include <optional>
24 
25 namespace mlir {
26 
27 using ReassociationIndices = SmallVector<int64_t, 2>;
30 
31 /// Attribute name for the ArrayAttr which encodes reassociation indices.
32 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
33 
34 /// Compose reassociation maps that are used in pair of reshape ops where one
35 /// is a producer and other is the consumer. Only valid to use this method when
36 /// both the producer and consumer are collapsing dimensions or both are
37 /// expanding dimensions.
38 ///
39 /// For example,
40 /// producerReassociation = [[0, 1], [2], [3, 4]]
41 /// consumerReassociation = [[0, 1], [2]]
42 ///
43 /// is folded into
44 ///
45 /// result = [[0, 1, 2], [3, 4]].
46 std::optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
47  ArrayRef<ReassociationIndices> producerReassociations,
48  ArrayRef<ReassociationIndices> consumerReassociations,
49  MLIRContext *context);
50 
51 /// Convert reassociation indices to affine expressions.
52 SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
53  MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
54 
55 /// Constructs affine maps out of Array<Array<AffineExpr>>.
56 SmallVector<AffineMap, 4>
57 getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);
58 
59 /// Wraps a list of reassociations in an ArrayAttr.
60 ArrayAttr
62  ArrayRef<ReassociationIndices> reassociation);
63 
64 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
65 SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
66  ArrayRef<ReassociationExprs> reassociationExprs);
67 
68 /// Return the reassociations maps to use to reshape given the source type and
69 /// the target type when possible. Return std::nullopt when this computation
70 /// failed.
71 std::optional<SmallVector<ReassociationIndices>>
72 getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
73 
74 /// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
75 /// possible.
76 std::optional<SmallVector<ReassociationIndices>>
77 getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
78  ArrayRef<int64_t> targetShape);
79 
80 /// Return true if the reassociation specification is valid, false otherwise.
81 /// When false, the `invalidIndex` integer pointer is optionally filled with the
82 /// index of the offending reassociation map.
83 bool isReassociationValid(ArrayRef<AffineMap> reassociation,
84  int *invalidIndex = nullptr);
85 
86 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
87 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
88  ArrayRef<Attribute> operands) {
89  // Fold identity reshape.
90  if (reshapeOp.getSrcType() == reshapeOp.getType())
91  return reshapeOp.getSrc();
92 
93  // Reshape of a constant can be replaced with a new constant.
94  if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
95  return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
96 
97  // Fold if the producer reshape source has the same shape with at most 1
98  // dynamic dimension.
99  auto reshapeSrcOp =
100  reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
101  if (!reshapeSrcOp)
102  return nullptr;
103  auto srcType = reshapeSrcOp.getSrcType();
104  auto resultType = reshapeOp.getResultType();
105  if (srcType != resultType)
106  return nullptr;
107 
108  if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
109  return reshapeSrcOp.getSrc();
110  }
111 
112  // Fold producer-consumer reshape ops when they are perfect inverses of each
113  // other:
114  // 1) Reassociation indices are equivalent.
115  // 2) Boundary types are equivalent.
116  // 3) No reassociations have more than 1 dynamic dimension, and reassociated
117  // shapes are equal for each reassociation.
118  auto reassociations = reshapeOp.getReassociationIndices();
119  if (reassociations != reshapeSrcOp.getReassociationIndices())
120  return nullptr;
121  // If the reshapes are expanding and then collapsing, the ops can be folded
122  // despite multiple dynamic dimensions.
123  if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
124  return reshapeSrcOp.getSrc();
125  if (llvm::all_of(reassociations, [&](auto reInd) {
126  ArrayRef<int64_t> srcSlice =
127  srcType.getShape().slice(reInd.front(), reInd.size());
128  return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
129  })) {
130  return reshapeSrcOp.getSrc();
131  }
132  return nullptr;
133 }
134 
135 /// Common verifier for reshape-like types. Fills `expandedType` and
136 ///`collapsedType` with the proper `src` or `result` type.
137 template <typename Op, typename T>
138 static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
139  T collapsedType, bool isExpansion) {
140 
141  unsigned expandedRank = expandedType.getRank();
142  unsigned collapsedRank = collapsedType.getRank();
143  if (expandedRank < collapsedRank)
144  return op.emitOpError("expected the expanded type, ")
145  << expandedType << " to have a higher (or same) rank "
146  << "than the collapsed type, " << collapsedType << '.';
147 
148  if (collapsedRank != op.getReassociation().size())
149  return op.emitOpError("expected collapsed rank (")
150  << collapsedRank << ") to equal the number of reassociation maps ("
151  << op.getReassociation().size() << ").";
152 
153  auto maps = op.getReassociationMaps();
154  for (auto it : llvm::enumerate(maps))
155  if (it.value().getNumDims() != expandedRank)
156  return op.emitOpError("expected reassociation map #")
157  << it.index() << " to have size equal to the expanded rank ("
158  << expandedRank << "), but it is " << it.value().getNumDims()
159  << '.';
160 
161  int invalidIdx = 0;
162  if (!isReassociationValid(maps, &invalidIdx))
163  return op.emitOpError("expected reassociation map #")
164  << invalidIdx << " to be valid and contiguous.";
165 
167  [&](const Twine &msg) { return op->emitOpError(msg); },
168  collapsedType.getShape(), expandedType.getShape(),
169  op.getReassociationIndices(), isExpansion);
170 }
171 
172 /// Verify that shapes of the reshaped types using following rule:
173 /// if a dimension in the collapsed type is static, then the corresponding
174 /// dimensions in the expanded shape should be
175 /// a) static
176 /// b) the product should be same as the collaped shape.
177 LogicalResult reshapeLikeShapesAreCompatible(
178  function_ref<LogicalResult(const Twine &)> emitError,
179  ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
180  ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
181 
182 /// Returns true iff the type is a MemRefType and has a non-identity layout.
183 bool hasNonIdentityLayout(Type type);
184 
186 
187 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
188 /// dimensions or are both expanding dimensions.
189 template <typename ReshapeOpTy, ReshapeOpKind opKind>
190 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
192  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
193  PatternRewriter &rewriter) const override {
194  auto srcReshapeOp =
195  reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
196  if (!srcReshapeOp)
197  return failure();
198 
199  ShapedType resultType = reshapeOp.getResultType();
200 
201  if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
202  hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
203  hasNonIdentityLayout(reshapeOp.getResult().getType()))
204  return failure();
205 
206  std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
207  composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
208  reshapeOp.getReassociationIndices(),
209  rewriter.getContext());
210  if (!reassociationIndices)
211  return failure();
212 
213  if constexpr (opKind == ReshapeOpKind::kExpand) {
214  SmallVector<OpFoldResult> outputShape(
215  getMixedValues(reshapeOp.getStaticOutputShape(),
216  reshapeOp.getOutputShape(), rewriter));
217  rewriter.replaceOpWithNewOp<ReshapeOpTy>(
218  reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
219  outputShape);
220  } else {
221  rewriter.replaceOpWithNewOp<ReshapeOpTy>(
222  reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
223  }
224  return success();
225  }
226 };
227 
228 /// Pattern to compose
229 /// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
230 /// In that case both `srcType` and `resultType` can be expressed as a function
231 /// of `intermediateType`.
232 /// In order to demonstrate the approach, let's assume that `rank(srcType) >
233 /// `rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
234 /// In that case, we can iterate over every set of indices in `reassociation_2`
235 /// and try to find ids of sets of indices in `reassociation_1` that cover it
236 /// completely.
237 ///
238 /// Example:
239 ///
240 /// %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
241 /// : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
242 /// %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
243 /// : tensor<?x?x?x1xi64> into tensor<?x?xi64>
244 ///
245 /// can be canonicalized into
246 ///
247 /// %0 = tensor.collapse_shape %arg [[0, 1], [2]]
248 /// : tensor<?x?x?xi64> into tensor<?x?xi64>
249 ///
250 /// because [0] and [1] from `expand_shape` reassociation cover completely
251 /// `[0, 1]` from `collapse_shape`. If it is impossible to find such union of
252 /// indices, then we fail.
253 //
254 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
255 /// `reassociation_2` and produce `expand_shape`.
256 template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy,
257  typename DimOpTy, typename TensorTy>
258 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
260  LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
261  PatternRewriter &rewriter) const override {
262  auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
263  if (!expandOp)
264  return failure();
265 
266  ShapedType srcType = expandOp.getSrcType();
267  ShapedType resultType = collapseOp.getResultType();
268 
269  if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
270  hasNonIdentityLayout(expandOp.getSrc().getType()) ||
271  hasNonIdentityLayout(expandOp.getResult().getType()))
272  return failure();
273 
274  int64_t srcRank = srcType.getRank();
275  int64_t resultRank = resultType.getRank();
276  if (srcType == resultType)
277  return failure();
278 
279  SmallVector<ReassociationIndices, 4> higherRankReassociation,
280  lowerRankReassociation;
281 
282  if (srcRank > resultRank) {
283  higherRankReassociation = expandOp.getReassociationIndices();
284  lowerRankReassociation = collapseOp.getReassociationIndices();
285  } else {
286  higherRankReassociation = collapseOp.getReassociationIndices();
287  lowerRankReassociation = expandOp.getReassociationIndices();
288  }
289 
290  size_t higherRankIndicesID = 0;
291  SmallVector<ReassociationIndices, 4> composedReassociation;
292  for (const auto &lowerRankIndices : lowerRankReassociation) {
293  ReassociationIndices composedIndices;
294  while (higherRankIndicesID < higherRankReassociation.size()) {
295  auto rightmostIndex =
296  higherRankReassociation[higherRankIndicesID].back();
297  if (rightmostIndex > lowerRankIndices.back())
298  return failure();
299  composedIndices.push_back(higherRankIndicesID++);
300  if (rightmostIndex == lowerRankIndices.back())
301  break;
302  }
303  composedReassociation.push_back(composedIndices);
304  }
305  if (srcRank > resultRank) {
306  rewriter.replaceOpWithNewOp<CollapseOpTy>(
307  collapseOp, resultType, expandOp.getSrc(), composedReassociation);
308  } else if (srcRank < resultRank) {
309  // Compute the dynamic output shape for the new expand_shape op.
310  Location loc = collapseOp.getLoc();
311  SmallVector<OpFoldResult> origOutputShape =
312  expandOp.getMixedOutputShape();
313  SmallVector<OpFoldResult> newOutputShape;
314  for (const ReassociationIndices &indices :
315  collapseOp.getReassociationIndices()) {
316  int64_t numStaticElems = 1;
317  SmallVector<Value> dynamicSizes;
318  for (int64_t idx : indices) {
319  OpFoldResult size = origOutputShape[idx];
320  if (std::optional<int64_t> maybeCst = getConstantIntValue(size)) {
321  numStaticElems *= maybeCst.value();
322  continue;
323  }
324  dynamicSizes.push_back(cast<Value>(size));
325  }
326  if (dynamicSizes.empty()) {
327  newOutputShape.push_back(rewriter.getIndexAttr(numStaticElems));
328  continue;
329  }
330 
331  // There is at least one dynamic size, so we can initialize `result` to
332  // the first dynamic size.
333  Value result = dynamicSizes[0];
334  for (Value v : llvm::drop_begin(dynamicSizes))
335  result = rewriter.create<arith::MulIOp>(loc, result, v);
336  if (numStaticElems != 1) {
337  result = rewriter.create<arith::MulIOp>(
338  loc, result,
339  rewriter.create<arith::ConstantIndexOp>(loc, numStaticElems));
340  }
341  newOutputShape.push_back(result);
342  }
343  rewriter.replaceOpWithNewOp<ExpandOpTy>(
344  collapseOp, resultType, expandOp.getSrc(), composedReassociation,
345  newOutputShape);
346  } else {
347  // Collapses/expansions that do not change the rank are not allowed. Use
348  // a cast instead.
349  assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
350  "expected same shape");
351  rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
352  expandOp.getSrc());
353  }
354  return success();
355  }
356 };
357 
358 template <typename ExpandOpTy, typename CollapseOpTy>
359 struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
361  LogicalResult matchAndRewrite(ExpandOpTy expandOp,
362  PatternRewriter &rewriter) const override {
363  auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
364  if (!collapseOp)
365  return failure();
366 
367  ShapedType srcType = collapseOp.getSrcType();
368  ShapedType resultType = expandOp.getResultType();
369 
370  if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
371  hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
372  hasNonIdentityLayout(collapseOp.getResult().getType()))
373  return failure();
374 
375  int64_t srcRank = srcType.getRank();
376  int64_t resultRank = resultType.getRank();
377  if (srcRank == resultRank)
378  return failure();
379 
380  auto srcReassociation = collapseOp.getReassociationIndices();
381  auto resultReassociation = expandOp.getReassociationIndices();
382  if (srcRank > resultRank) {
383  auto composedReassociation = findCollapsingReassociation(
384  srcReassociation, resultReassociation, srcType.getShape(),
385  resultType.getShape());
386  if (!composedReassociation)
387  return failure();
388 
389  rewriter.replaceOpWithNewOp<CollapseOpTy>(
390  expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
391  return success();
392  }
393  auto composedReassociation =
394  findCollapsingReassociation(resultReassociation, srcReassociation,
395  resultType.getShape(), srcType.getShape());
396  if (!composedReassociation)
397  return failure();
398 
400  expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
401  rewriter.replaceOpWithNewOp<ExpandOpTy>(
402  expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
403  outputShape);
404  return success();
405  }
406 
407 private:
408  // Attempts to find a way to collapse `srcShape` to `resultShape` by
409  // collapsing subshapes defined by the reassociation indices.
410  std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
411  ArrayRef<ReassociationIndices> srcReassociation,
412  ArrayRef<ReassociationIndices> resultReassociation,
413  ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
414  SmallVector<ReassociationIndices, 4> composedReassociation;
415 
416  if (srcReassociation.empty())
417  return {getReassociationIndicesForCollapse(srcShape, resultShape)};
418 
419  for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
420  auto &srcIndices = std::get<0>(item);
421  auto &resultIndices = std::get<1>(item);
422  auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
423  auto resultSubShape =
424  resultShape.slice(resultIndices.front(), resultIndices.size());
425 
426  if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 &&
427  llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2)
428  return std::nullopt;
429 
430  if (srcSubShape.size() == resultSubShape.size()) {
431  if (srcSubShape != resultSubShape)
432  return std::nullopt;
433 
434  for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
435  composedReassociation.emplace_back(1, srcIndices.front() + index);
436  }
437  continue;
438  }
439 
440  // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
441  auto subShapeReassociation =
442  getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
443  if (!subShapeReassociation)
444  return std::nullopt;
445 
446  // Remap the subshape indices back to the original srcShape.
447  for (auto &subshapeIndices : *subShapeReassociation) {
448  ReassociationIndices shapeIndices;
449  for (int64_t index : subshapeIndices)
450  shapeIndices.push_back(srcIndices.front() + index);
451  composedReassociation.push_back(shapeIndices);
452  }
453  }
454  return {std::move(composedReassociation)};
455  }
456 };
457 
458 /// The input parameters `offsets`, `sizes`, `strides` specify a rectangular
459 /// non rank-reducing slice of the collapse_shape output. Try to find which
460 /// dimensions have been sliced and which dimensions are not sliced (offset = 0,
461 /// size = dim, size = 1). Note that this conservative as it cannot detect if a
462 /// dynamic size corresponds to the full tensor dimension or not.
463 llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
464  ArrayRef<Range> sliceParams);
465 
466 /// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
467 /// inspecting its reassociation indices.
468 llvm::SmallBitVector
469 getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
470 
471 /// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
472 /// chain and reified source and result shapes of the CollapseShapeOp, this
473 /// class provides two functions that assist with directly forming the result
474 /// of the extract slice by "tiling the CollapseShapeOp by 1".
475 //// Example:
476 // clang-format off
477 /// ```
478 /// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
479 /// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<341x10xf32>
480 /// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
481 /// ```
482 /// This class helps build the below IR to replace %2:
483 /// ```
484 /// %dest = tensor.empty() : tensor<10x10xf32>
485 /// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
486 /// %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 11)>(%iv)
487 /// %3:3 = arith.delinearize_index %iv into (3, 7, 11)
488 ///
489 /// // This function takes %3 (multiIndices) and the parameters for the slice below.
490 /// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
491 /// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
492 ///
493 /// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
494 /// tensor<1x1x1x10xf32> into tensor<1x10xf32>
495 /// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
496 /// tensor<1x10xf32> into tensor<10x10xf32>
497 /// scf.yield %6 : tensor<10x10xf32>
498 /// }
499 /// ```
500 // clang-format on
501 class SliceFromCollapseHelper {
502 public:
503  SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
504  ArrayRef<OpFoldResult> collapseShapeInputShape,
505  ArrayRef<OpFoldResult> collapseShapeOutputShape,
506  ArrayRef<Range> extractSliceParams)
507  : reassociationIndices(reassociationIndices),
508  collapseShapeInputShape(collapseShapeInputShape),
509  collapseShapeOutputShape(collapseShapeOutputShape),
510  sliceParams(extractSliceParams),
511  linearizedDimensions(getLinearizedDimensions(reassociationIndices)),
512  slicedDimensions(getSlicedDimensions(collapseShapeOutputShape,
513  extractSliceParams)) {}
514 
515  /// This function takes multi-indices and maps them to ExtractSlice parameters
516  /// in the index space of the CollapseShape's source tensor. This function's
517  /// signature can be described by `(D_0, D_1,.. D_{n-1}) -> (offsets, sizes,
518  /// strides)` where `n` the number of "tiled dimensions", which are the
519  /// dimensions of the output that are linearized by the collapse shape op and
520  /// are also sliced. Each `D_i` is a tuple that must represent a valid
521  /// multi-index for the `i-th` tiled dimension. In the example above, there is
522  /// only one tiled dimension (D_0) and `arith.delinearize_index` produces the
523  /// multi-index (%3) that would be passed to this function to generate the
524  /// parameters for the `tensor.extract_slice` op (%4).
525  SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
526  ArrayRef<ValueRange> multiIndices);
527 
528  /// This function takes indices in the index space of the "tiled dimensions"
529  /// described above and returns a set of Range variables that describe how the
530  /// slice should be inserted into the destination. In the example above, `%iv`
531  /// would be passed to this function to generate the parameters for the
532  /// `tensor.insert_slice` op producing %6.
533  SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
534  ValueRange tileIndices);
535 
536 private:
537  SmallVector<ReassociationIndices> reassociationIndices;
538  SmallVector<OpFoldResult> collapseShapeInputShape;
539  SmallVector<OpFoldResult> collapseShapeOutputShape;
540  SmallVector<Range> sliceParams;
541  llvm::SmallBitVector linearizedDimensions;
542  llvm::SmallBitVector slicedDimensions;
543 };
544 
545 /// Parameters required to simplify a collapsing reshape op with a rank-reducing
546 /// slice operation. See `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
547 struct CollapseShapeRankReducingSliceSimplificationInfo {
548  /// The shape of the output of the rank-reducing slice.
549  RankedTensorType sliceResultType;
550  /// The reassociation indices for the new collapse shape op, if required. If
551  /// `std::nullopt`, the slice should replace the collapse shape op.
552  std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
553 };
554 
555 /// A collapsing reshape operation can sometimes be simplified or eliminated by
556 /// inserting a single rank-reducing slice operation between it and the source
557 /// tensor. The slice op will either take the place of the source, allowing for
558 /// a new, simpler reshape op to replace the original, or the reshape op will be
559 /// completely replaced by the slice result.
560 ///
561 /// This function returns the parameters required to implement this pattern. If
562 /// the pattern is not applicable, then failure is returned.
563 ///
564 /// ### Example:
565 /// ```
566 /// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
567 /// : tensor<?x1x30x10xf32> to tensor<?x300xf32>
568 /// ```
569 /// can be transformed to
570 /// ```
571 /// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
572 /// [0, %dim1, 30, 30]
573 /// [1, 1, 1 1]
574 /// : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
575 /// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
576 /// : tensor<?x30x10xf32> to tensor<?x300xf32>
577 /// ```
578 ///
579 /// ### Example:
580 /// ```
581 /// %result = tensor.collapse_shape %1 [[0, 1], [2]]
582 /// : tensor<?x1x30xf32> to tensor<?x30xf32>
583 /// ```
584 /// can be transformed to
585 /// ```
586 /// %result = tensor.extract_slice %1 [0, 0, 0]
587 /// [%dim2, 1, 30]
588 /// [1, 1, 1]
589 /// : tensor<?x1x30xf32> to tensor<?x30xf32>
590 /// ```
591 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
592 getSimplifyCollapseShapeWithRankReducingSliceInfo(
593  RankedTensorType sourceType,
594  ArrayRef<ReassociationIndices> reassociationIndices);
595 
596 struct PackingMetadata {
597  SmallVector<int64_t> insertPositions;
598  SmallVector<int64_t> outerPositions;
599  SmallVector<ReassociationIndices> reassociations;
600 };
601 
602 /// Given a vector of `positions` indices representing desired packing insertion
603 /// points into a target vector (i.e. pack/unpack.inner_dim_pos), compute the
604 /// final positions in the target shape as well as the reshape reassociations.
605 // Note: This should not be called with a large positions array (or the
606 // implementation needs to be updated to use an N.log N sort instead of
607 // repeated N^2 counts).
608 PackingMetadata computePackingMetadata(int64_t packedRank,
609  ArrayRef<int64_t> innerDimPos);
610 
611 /// Try to remove a tensor operation if it would only reshape a constant.
612 /// Removes the op and replaces the constant with a new constant of the result
613 /// shape. When an optional cst attribute is passed, it is reshaped only if the
614 /// splat value matches the value in the attribute.
615 OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
616  std::optional<Attribute> cst = std::nullopt);
617 } // namespace mlir
618 
619 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
Definition: MeshOps.cpp:1191
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:836
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:97
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:152
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayRef< int64_t > ReassociationIndicesRef
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< int64_t, 2 > ReassociationIndices
Definition: Utils.h:27
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314