MLIR  17.0.0git
ReshapeOpsUtils.h
Go to the documentation of this file.
1 //===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities and common canonicalization patterns for
10 // reshape operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
16 
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Support/LLVM.h"
21 #include "llvm/ADT/StringRef.h"
22 #include <optional>
23 
24 namespace mlir {
25 
29 
30 /// Attribute name for the ArrayAttr which encodes reassociation indices.
31 constexpr StringRef getReassociationAttrName() { return "reassociation"; }
32 
33 /// Compose reassociation maps that are used in pair of reshape ops where one
34 /// is a producer and other is the consumer. Only valid to use this method when
35 /// both the producer and consumer are collapsing dimensions or both are
36 /// expanding dimensions.
37 ///
38 /// For example,
39 /// producerReassociation = [[0, 1], [2], [3, 4]]
40 /// consumerReassociation = [[0, 1], [2]]
41 ///
42 /// is folded into
43 ///
44 /// result = [[0, 1, 2], [3, 4]].
45 std::optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
46  ArrayRef<ReassociationIndices> producerReassociations,
47  ArrayRef<ReassociationIndices> consumerReassociations,
48  MLIRContext *context);
49 
50 /// Convert reassociation indices to affine expressions.
51 SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
52  MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
53 
54 /// Constructs affine maps out of Array<Array<AffineExpr>>.
55 SmallVector<AffineMap, 4>
56 getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);
57 
58 /// Wraps a list of reassociations in an ArrayAttr.
59 ArrayAttr
61  ArrayRef<ReassociationIndices> reassociation);
62 
63 /// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
64 SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
65  OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
66 
67 /// Return the reassociations maps to use to reshape given the source type and
68 /// the target type when possible. Return std::nullopt when this computation
69 /// failed.
70 std::optional<SmallVector<ReassociationIndices>>
71 getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
72 
73 /// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
74 /// possible.
75 std::optional<SmallVector<ReassociationIndices>>
76 getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
77  ArrayRef<int64_t> targetShape);
78 
79 /// Return true if the reassociation specification is valid, false otherwise.
80 /// When false, the `invalidIndex` integer pointer is optionally filled with the
81 /// index of the offending reassociation map.
82 bool isReassociationValid(ArrayRef<AffineMap> reassociation,
83  int *invalidIndex = nullptr);
84 
85 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
86 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
87  ArrayRef<Attribute> operands) {
88  // Fold producer-consumer reshape ops that where the operand type of the
89  // producer is same as the return type of the consumer.
90  auto reshapeSrcOp =
91  reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
92  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
93  return reshapeSrcOp.getSrc();
94  // Reshape of a constant can be replaced with a new constant.
95  if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
96  return elements.reshape(
97  reshapeOp.getResult().getType().template cast<ShapedType>());
98  }
99  return nullptr;
100 }
101 
102 /// Common verifier for reshape-like types. Fills `expandedType` and
103 ///`collapsedType` with the proper `src` or `result` type.
104 template <typename Op, typename T>
105 static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
106  T collapsedType, bool isExpansion) {
107  unsigned expandedRank = expandedType.getRank();
108  unsigned collapsedRank = collapsedType.getRank();
109  if (expandedRank < collapsedRank)
110  return op.emitOpError("expected the type ")
111  << expandedType
112  << " to have higher rank than the type = " << collapsedType;
113  if (expandedRank == 0)
114  return op.emitOpError("expected non-zero memref ranks");
115  if (expandedRank == collapsedRank)
116  return op.emitOpError("expected to collapse or expand dims");
117 
118  if (collapsedRank == 0) {
119  // If collapsed rank is 0, then expanded type must be static shaped and of
120  // sizes 1.
121  if (llvm::any_of(expandedType.getShape(),
122  [](int64_t dim) -> bool { return dim != 1; }))
123  return op.emitOpError("invalid to reshape tensor/memref with non-unit "
124  "extent dimensions to zero-rank tensor/memref");
125  return success();
126  }
127  if (collapsedRank != op.getReassociation().size())
128  return op.emitOpError("expected rank of the collapsed type(")
129  << collapsedRank << ") to be the number of reassociation maps("
130  << op.getReassociation().size() << ")";
131  auto maps = op.getReassociationMaps();
132  for (auto it : llvm::enumerate(maps))
133  if (it.value().getNumDims() != expandedRank)
134  return op.emitOpError("expected reassociation map #")
135  << it.index() << " of same rank as expanded memref("
136  << expandedRank << "), but got " << it.value().getNumDims();
137  int invalidIdx = 0;
138  if (!isReassociationValid(maps, &invalidIdx))
139  return op.emitOpError("expected reassociation map #")
140  << invalidIdx << " to be valid and contiguous";
141  return verifyReshapeLikeShapes(op, collapsedType, expandedType, isExpansion);
142 }
143 
144 /// Verify the shapes of the reshaped types using the following rules:
145 /// 1) if a dimension in the collapsed type is static, then the corresponding
146 /// dimensions in the expanded shape should be
147 /// a) static
148 /// b) the product should be same as the collapsed shape.
149 /// 2) if a dimension in the collapsed type is dynamic, one and only one of the
150 /// corresponding dimensions in the expanded type should be dynamic. This
151 /// rule is only needed with reshape operations that are expanding.
152 LogicalResult reshapeLikeShapesAreCompatible(
153  function_ref<LogicalResult(const Twine &)> emitError,
154  ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
155  ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
156 
157 template <typename OpTy>
158 static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
159  ShapedType expandedType,
160  bool isExpandingReshape) {
162  [&](const Twine &msg) { return op->emitOpError(msg); },
163  collapsedType.getShape(), expandedType.getShape(),
164  op.getReassociationIndices(), isExpandingReshape);
165 }
166 
167 /// Returns true iff the type is a MemRefType and has a non-identity layout.
168 bool hasNonIdentityLayout(Type type);
169 
170 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
171 /// dimensions or are both expanding dimensions.
172 template <typename ReshapeOpTy>
173 struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
175  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
176  PatternRewriter &rewriter) const override {
177  auto srcReshapeOp =
178  reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
179  if (!srcReshapeOp)
180  return failure();
181 
182  ShapedType resultType = reshapeOp.getResultType();
183 
184  if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
185  hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
186  hasNonIdentityLayout(reshapeOp.getResult().getType()))
187  return failure();
188 
189  std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
190  composeReassociationIndices(srcReshapeOp.getReassociationIndices(),
191  reshapeOp.getReassociationIndices(),
192  rewriter.getContext());
193  if (!reassociationIndices)
194  return failure();
195  rewriter.replaceOpWithNewOp<ReshapeOpTy>(
196  reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
197  return success();
198  }
199 };
200 
201 /// Pattern to compose
202 /// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
203 /// In that case both `srcType` and `resultType` can be expressed as a function
204 /// of `intermediateType`.
205 /// In order to demonstrate the approach, let's assume that `rank(srcType) >
206 /// `rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
207 /// In that case, we can iterate over every set of indices in `reassociation_2`
208 /// and try to find ids of sets of indices in `reassociation_1` that cover it
209 /// completely.
210 ///
211 /// Example:
212 ///
213 /// %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
214 /// : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
215 /// %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
216 /// : tensor<?x?x?x1xi64> into tensor<?x?xi64>
217 ///
218 /// can be canonicalized into
219 ///
220 /// %0 = tensor.collapse_shape %arg [[0, 1], [2]]
221 /// : tensor<?x?x?xi64> into tensor<?x?xi64>
222 ///
223 /// because [0] and [1] from `expand_shape` reassociation cover completely
224 /// `[0, 1]` from `collapse_shape`. If it is impossible to find such union of
225 /// indices, then we fail.
226 //
227 /// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
228 /// `reassociation_2` and produce `expand_shape`.
229 template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
230 struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
232  LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
233  PatternRewriter &rewriter) const override {
234  auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
235  if (!expandOp)
236  return failure();
237 
238  ShapedType srcType = expandOp.getSrcType();
239  ShapedType resultType = collapseOp.getResultType();
240 
241  if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
242  hasNonIdentityLayout(expandOp.getSrc().getType()) ||
243  hasNonIdentityLayout(expandOp.getResult().getType()))
244  return failure();
245 
246  int64_t srcRank = srcType.getRank();
247  int64_t resultRank = resultType.getRank();
248  if (srcType == resultType)
249  return failure();
250 
251  SmallVector<ReassociationIndices, 4> higherRankReassociation,
252  lowerRankReassociation;
253 
254  if (srcRank > resultRank) {
255  higherRankReassociation = expandOp.getReassociationIndices();
256  lowerRankReassociation = collapseOp.getReassociationIndices();
257  } else {
258  higherRankReassociation = collapseOp.getReassociationIndices();
259  lowerRankReassociation = expandOp.getReassociationIndices();
260  }
261 
262  size_t higherRankIndicesID = 0;
263  SmallVector<ReassociationIndices, 4> composedReassociation;
264  for (const auto &lowerRankIndices : lowerRankReassociation) {
265  ReassociationIndices composedIndices;
266  while (higherRankIndicesID < higherRankReassociation.size()) {
267  auto rightmostIndex =
268  higherRankReassociation[higherRankIndicesID].back();
269  if (rightmostIndex > lowerRankIndices.back())
270  return failure();
271  composedIndices.push_back(higherRankIndicesID++);
272  if (rightmostIndex == lowerRankIndices.back())
273  break;
274  }
275  composedReassociation.push_back(composedIndices);
276  }
277  if (srcRank > resultRank) {
278  rewriter.replaceOpWithNewOp<CollapseOpTy>(
279  collapseOp, resultType, expandOp.getSrc(), composedReassociation);
280  } else if (srcRank < resultRank) {
281  rewriter.replaceOpWithNewOp<ExpandOpTy>(
282  collapseOp, resultType, expandOp.getSrc(), composedReassociation);
283  } else {
284  // Collapses/expansions that do not change the rank are not allowed. Use
285  // a cast instead.
286  assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
287  "expected same shape");
288  rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
289  expandOp.getSrc());
290  }
291  return success();
292  }
293 };
294 
295 template <typename ExpandOpTy, typename CollapseOpTy>
296 struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
298  LogicalResult matchAndRewrite(ExpandOpTy expandOp,
299  PatternRewriter &rewriter) const override {
300  auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
301  if (!collapseOp)
302  return failure();
303 
304  ShapedType srcType = collapseOp.getSrcType();
305  ShapedType resultType = expandOp.getResultType();
306 
307  if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
308  hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
309  hasNonIdentityLayout(collapseOp.getResult().getType()))
310  return failure();
311 
312  int64_t srcRank = srcType.getRank();
313  int64_t resultRank = resultType.getRank();
314  if (srcType == resultType)
315  return failure();
316 
317  auto srcReassociation = collapseOp.getReassociationIndices();
318  auto resultReassociation = expandOp.getReassociationIndices();
319  if (srcRank > resultRank) {
320  auto composedReassociation = findCollapsingReassociation(
321  srcReassociation, resultReassociation, srcType.getShape(),
322  resultType.getShape());
323  if (!composedReassociation)
324  return failure();
325 
326  rewriter.replaceOpWithNewOp<CollapseOpTy>(
327  expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
328  return success();
329  }
330  auto composedReassociation =
331  findCollapsingReassociation(resultReassociation, srcReassociation,
332  resultType.getShape(), srcType.getShape());
333  if (!composedReassociation)
334  return failure();
335 
336  rewriter.replaceOpWithNewOp<ExpandOpTy>(
337  expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
338  return success();
339  }
340 
341 private:
342  // Attempts to find a way to collapse `srcShape` to `resultShape` by
343  // collapsing subshapes defined by the reassociation indices.
344  std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
345  ArrayRef<ReassociationIndices> srcReassociation,
346  ArrayRef<ReassociationIndices> resultReassociation,
347  ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
348  SmallVector<ReassociationIndices, 4> composedReassociation;
349 
350  if (srcReassociation.empty())
351  return {getReassociationIndicesForCollapse(srcShape, resultShape)};
352 
353  for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
354  auto &srcIndices = std::get<0>(item);
355  auto &resultIndices = std::get<1>(item);
356  auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
357  auto resultSubShape =
358  resultShape.slice(resultIndices.front(), resultIndices.size());
359 
360  if (srcSubShape.size() == resultSubShape.size()) {
361  if (srcSubShape == resultSubShape)
362  composedReassociation.push_back(srcIndices);
363  else
364  return std::nullopt;
365  }
366 
367  // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
368  auto subShapeReassociation =
369  getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
370  if (!subShapeReassociation)
371  return std::nullopt;
372 
373  // Remap the subshape indices back to the original srcShape.
374  for (auto &subshape_indices : *subShapeReassociation) {
375  ReassociationIndices shape_indices;
376  for (int64_t index : subshape_indices)
377  shape_indices.push_back(srcIndices.front() + index);
378  composedReassociation.push_back(shape_indices);
379  }
380  }
381  return {std::move(composedReassociation)};
382  }
383 };
384 
385 /// The input parameters `offsets`, `sizes`, `strides` specify a rectangular
386 /// non rank-reducing slice of the collapse_shape output. Try to find which
387 /// dimensions have been sliced and which dimensions are not sliced (offset = 0,
388 /// size = dim, stride = 1). Note that this is conservative as it cannot detect
389 /// if a dynamic size corresponds to the full tensor dimension or not.
390 llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
391  ArrayRef<Range> sliceParams);
392 
393 /// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
394 /// inspecting its reassociation indices.
395 llvm::SmallBitVector
396 getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
397 
398 /// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
399 /// chain and reified source and result shapes of the CollapseShapeOp, this
400 /// class provides two functions that assist with directly forming the result
401 /// of the extract slice by "tiling the CollapseShapeOp by 1".
402 //// Example:
403 // clang-format off
404 /// ```
405 /// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
406 /// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<341x10xf32>
407 /// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
408 /// ```
409 /// This class helps build the below IR to replace %2:
410 /// ```
411 /// %dest = tensor.empty() : tensor<10x10xf32>
412 /// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
413 /// %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 11)>(%iv)
414 /// %3:3 = arith.delinearize_index %iv into (3, 7, 11)
415 ///
416 /// // This function takes %3 (multiIndices) and the parameters for the slice below.
417 /// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
418 /// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
419 ///
420 /// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
421 /// tensor<1x1x1x10xf32> into tensor<1x10xf32>
422 /// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
423 /// tensor<1x10xf32> into tensor<10x10xf32>
424 /// scf.yield %6 : tensor<10x10xf32>
425 /// }
426 /// ```
427 // clang-format on
428 class SliceFromCollapseHelper {
429 public:
430  SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
431  ArrayRef<OpFoldResult> collapseShapeInputShape,
432  ArrayRef<OpFoldResult> collapseShapeOutputShape,
433  ArrayRef<Range> extractSliceParams)
434  : reassociationIndices(reassociationIndices),
435  collapseShapeInputShape(collapseShapeInputShape),
436  collapseShapeOutputShape(collapseShapeOutputShape),
437  sliceParams(extractSliceParams),
438  linearizedDimensions(getLinearizedDimensions(reassociationIndices)),
439  slicedDimensions(getSlicedDimensions(collapseShapeOutputShape,
440  extractSliceParams)) {}
441 
442  /// This function takes multi-indices and maps them to ExtractSlice parameters
443  /// in the index space of the CollapseShape's source tensor. This function's
444  /// signature can be described by `(D_0, D_1,.. D_{n-1}) -> (offsets, sizes,
445  /// strides)` where `n` the number of "tiled dimensions", which are the
446  /// dimensions of the output that are linearized by the collapse shape op and
447  /// are also sliced. Each `D_i` is a tuple that must represent a valid
448  /// multi-index for the `i-th` tiled dimension. In the example above, there is
449  /// only one tiled dimension (D_0) and `arith.delinearize_index` produces the
450  /// multi-index (%3) that would be passed to this function to generate the
451  /// parameters for the `tensor.extract_slice` op (%4).
452  SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
453  ArrayRef<ValueRange> multiIndices);
454 
455  /// This function takes indices in the index space of the "tiled dimensions"
456  /// described above and returns a set of Range variables that describe how the
457  /// slice should be inserted into the destination. In the example above, `%iv`
458  /// would be passed to this function to generate the parameters for the
459  /// `tensor.insert_slice` op producing %6.
460  SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
461  ValueRange tileIndices);
462 
463 private:
464  SmallVector<ReassociationIndices> reassociationIndices;
465  SmallVector<OpFoldResult> collapseShapeInputShape;
466  SmallVector<OpFoldResult> collapseShapeOutputShape;
467  SmallVector<Range> sliceParams;
468  llvm::SmallBitVector linearizedDimensions;
469  llvm::SmallBitVector slicedDimensions;
470 };
471 
472 /// Parameters required to simplify a collapsing reshape op with a rank-reducing
473 /// slice operation. See `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
474 struct CollapseShapeRankReducingSliceSimplificationInfo {
475  /// The shape of the output of the rank-reducing slice.
476  RankedTensorType sliceResultType;
477  /// The reassociation indices for the new collapse shape op, if required. If
478  /// `None`, the slice should replace the collapse shape op.
479  std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
480 };
481 
482 /// A collapsing reshape operation can sometimes be simplified or eliminated by
483 /// inserting a single rank-reducing slice operation between it and the source
484 /// tensor. The slice op will either take the place of the source, allowing for
485 /// a new, simpler reshape op to replace the original, or the reshape op will be
486 /// completely replaced by the slice result.
487 ///
488 /// This function returns the parameters required to implement this pattern. If
489 /// the pattern is not applicable, then failure is returned.
490 ///
491 /// ### Example:
492 /// ```
493 /// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
494 /// : tensor<?x1x30x10xf32> to tensor<?x300xf32>
495 /// ```
496 /// can be transformed to
497 /// ```
498 /// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
499 /// [%dim0, 1, 30, 10]
500 /// [1, 1, 1, 1]
501 /// : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
502 /// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
503 /// : tensor<?x30x10xf32> to tensor<?x300xf32>
504 /// ```
505 ///
506 /// ### Example:
507 /// ```
508 /// %result = tensor.collapse_shape %1 [[0, 1], [2]]
509 /// : tensor<?x1x30xf32> to tensor<?x30xf32>
510 /// ```
511 /// can be transformed to
512 /// ```
513 /// %result = tensor.extract_slice %1 [0, 0, 0]
514 /// [%dim2, 1, 30]
515 /// [1, 1, 1]
516 /// : tensor<?x1x30xf32> to tensor<?x30xf32>
517 /// ```
518 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
519 getSimplifyCollapseShapeWithRankReducingSliceInfo(
520  RankedTensorType sourceType,
521  ArrayRef<ReassociationIndices> reassociationIndices);
522 
523 } // namespace mlir
524 
525 #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
MLIRContext * getContext() const
Definition: Builders.h:55
An attribute that represents a reference to a dense vector or tensor object.
This class represents a single result from folding an operation.
Definition: OpDefinition.h:233
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:637
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:621
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:148
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rules 1) if a dimension in the collapsed typ...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
SmallVector< int64_t, 2 > ReassociationIndices
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType, ShapedType expandedType, bool isExpandingReshape)
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(OpBuilder &b, ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357