MLIR  21.0.0git
ReshapeOpsUtils.cpp
Go to the documentation of this file.
1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"

#include <numeric>
#include <optional>
20 
21 using namespace mlir;
22 
23 std::optional<SmallVector<ReassociationIndices>>
25  ShapedType targetType) {
26  if (sourceType.getRank() > targetType.getRank())
27  return getReassociationIndicesForCollapse(sourceType.getShape(),
28  targetType.getShape());
29  if (sourceType.getRank() < targetType.getRank())
30  return getReassociationIndicesForCollapse(targetType.getShape(),
31  sourceType.getShape());
32  return std::nullopt;
33 }
34 
35 namespace {
36 /// A simple struct to represent ReassociationIndices as an inclusive interval.
37 /// It's designed to be feasibly minimal, so the call sites should manage the
38 /// validity of the range manually.
39 struct ReassociationIndexRange {
40  /// FIXME: Signed type is used for consistency with ReassociationIndices.
41  /// We should consider refactoring all reassociation utilities to use unsigned
42  /// types.
43  int64_t leftIdx = 0, rightIdx = 0;
44 
45  /// Util for manual checks of the range's validity
46  LogicalResult verify() const {
47  return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
48  }
49 
50  /// Checks range's containment within another range. Treats the edges
51  /// non-exclusively.
52  bool isInRange(const ReassociationIndexRange &outerRange) const {
53  return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
54  }
55 
56  unsigned size() const {
57  assert(succeeded(verify()));
58  return rightIdx - leftIdx + 1;
59  }
60  bool containsSingleIndex() const { return size() == 1; }
61 
62  /// Collects indices that do not overlap between this and another range.
64  getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
65  if (rightIdx < rhs.leftIdx) {
66  // The intervals do not overlap - concatenate the indices from both.
67  auto jointFullIndices = getFullIndices();
68  jointFullIndices.append(rhs.getFullIndices());
69  return jointFullIndices;
70  }
71  ReassociationIndices result;
72  // Handle the chunk left of the overlapping range.
73  int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
74  int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
75  llvm::append_range(result, llvm::seq(leftStart, leftEnd));
76  // Handle the chunk right of the overlapping range. Symmetrically, we should
77  // skip the edge of the overlap AND include the rightmost index.
78  int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
79  int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
80  if (rightStart < rightEnd)
81  llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
82  return result;
83  }
84 
85  /// Converts the range into ReassociationIndices.
86  ReassociationIndices getFullIndices() const {
87  ReassociationIndices result;
88  for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
89  result.push_back(idx);
90  }
91  return result;
92  }
93 };
94 } // namespace
95 
96 /// Starting from `sourceStartIdx`, searches `sourceShape` for the first
97 /// sequence that can be collapsed into a dynamic dimension (at least one must
98 /// be present in the source).
99 /// By default, lazily returns once the first dynamic dimension has been found.
100 /// Setting `matchGreedily` as `true` will also mark all subsequent
101 /// source dimensions for collapsing into the target.
102 static FailureOr<ReassociationIndexRange>
104  int64_t sourceStartIdx,
105  bool matchGreedily = false) {
106  const unsigned numSourceDims = sourceShape.size();
107  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
108  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
109 
110  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
111  for (; iterationRange.isInRange(sourceShapeAsRange);
112  iterationRange.rightIdx++) {
113  int64_t sourceSize = sourceShape[iterationRange.rightIdx];
114  if (sourceSize == ShapedType::kDynamic) {
115  resultRange = iterationRange;
116  break;
117  }
118  }
119  if (!resultRange)
120  return failure();
121  if (matchGreedily)
122  resultRange->rightIdx = sourceShapeAsRange.rightIdx;
123  return *resultRange;
124 }
125 
126 /// Starting from `sourceStartIdx`, searches `sourceShape` for the first
127 /// sequence of static dimensions such that their product matches `targetSize`.
128 /// By default, lazily returns once the product matches the target size. Setting
129 /// `matchGreedily` as `true` will append all neighboring unit dimensions
130 /// (dimensions of 1) to the match.
131 static FailureOr<ReassociationIndexRange>
133  int64_t sourceStartIdx, int64_t targetSize,
134  bool matchGreedily = false) {
135  const unsigned numSourceDims = sourceShape.size();
136  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
137  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
138 
139  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
140  int64_t prodOfCollapsedDims = 1;
141  while (iterationRange.isInRange(sourceShapeAsRange)) {
142  int64_t sourceSize = sourceShape[iterationRange.rightIdx];
143  if (sourceSize == ShapedType::kDynamic) {
144  // Reassociation for a static dim cannot include a dynamic dim. Reset
145  // induction variables to essentially restart the loop from the next
146  // source dimension.
147  prodOfCollapsedDims = 1;
148  iterationRange = {iterationRange.rightIdx + 1,
149  iterationRange.rightIdx + 1};
150  continue;
151  }
152  prodOfCollapsedDims *= sourceSize;
153  // If the target size has been exceeded without matching, we need to shift
154  // the range start right. From the start of the range, roll back the
155  // multiplication until the target size exceeds the product again.
156  while (prodOfCollapsedDims > targetSize &&
157  !iterationRange.containsSingleIndex()) {
158  int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
159  prodOfCollapsedDims /= frontSourceSize;
160  // Shrink the range rightwards
161  iterationRange.leftIdx++;
162  }
163  // We could've reached the target size with the current dimension,
164  // also as a result of the above shift to right.
165  if (prodOfCollapsedDims == targetSize) {
166  resultRange = iterationRange;
167  break;
168  }
169  // Increment the iteration range
170  iterationRange.rightIdx++;
171  }
172  if (!resultRange)
173  return failure();
174  if (matchGreedily) {
175  // We now want to collect all unit dimensions directly after the target
176  // product match. Advance the iterator to avoid OOB when the product match
177  // happens at the last element.
178  iterationRange.rightIdx++;
179  while (iterationRange.isInRange(sourceShapeAsRange) &&
180  sourceShape[iterationRange.rightIdx] == 1) {
181  resultRange = iterationRange;
182  iterationRange.rightIdx++;
183  }
184  }
185  return *resultRange;
186 }
187 
188 /// Attempts to find a valid collapsing reassociation of `sourceShape` into
189 /// `targetShape` through a simple traversal. If successful, an array of source
190 /// index ranges is returned, correspondingly to each dimension in the target
191 /// shape. The resulting indices shall fully cover the `sourceShape` without
192 /// overlaps.
193 ///
194 /// The algorithm is essentially a lazy one, searching for non-greedy matches -
195 /// it will only yield a greedy match for the last target dimension.
196 /// FIXME: The algorithm can only backtrack when it needs to append an offset
197 /// for a static target dimension to the preceding dynamic one (this retains the
198 /// linear complexity). As feasible, consider adding further backtracking
199 /// routines to enable more reassociations, e.g.:
200 /// - ?x2x?x2 into ?x2
201 static FailureOr<SmallVector<ReassociationIndexRange>>
203  ArrayRef<int64_t> targetShape) {
204  unsigned numSourceDims = sourceShape.size(),
205  numTargetDims = targetShape.size();
206  assert(numSourceDims > numTargetDims);
207  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
208 
210  reassocRanges.reserve(numTargetDims);
211  // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
212  // cases, e.g.:
213  // - ?x2x3x5 into ?x15
214  std::optional<int64_t> prevTargetSize = std::nullopt;
215  for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
216  targetDimIdx < numTargetDims; ++targetDimIdx) {
217  int64_t targetSize = targetShape[targetDimIdx];
218  // Simply check if there are any subsequent target dimensions left - if not,
219  // the match must be made greedily.
220  bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
221  FailureOr<ReassociationIndexRange> sourceRange;
222  if (targetSize == ShapedType::kDynamic) {
224  sourceShape, sourceDimIdx, shouldMatchGreedily);
225  } else {
226  sourceRange = findReassociationRangeForSize(
227  sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
228  }
229 
230  // Run sanity checks on the returned index range.
231  if (failed(sourceRange) || failed(sourceRange->verify()) ||
232  !sourceRange->isInRange(sourceShapeAsRange))
233  return failure();
234  if (sourceRange->leftIdx > sourceDimIdx) {
235  // If some source dimensions had to be skipped in order to find a match,
236  // they must be collapsed into the directly preceding dynamic dimension.
237  if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
238  return failure();
239  reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
240  }
241 
242  // Store the gathered information as required for the next iteration.
243  prevTargetSize = targetSize;
244  sourceDimIdx = sourceRange->rightIdx + 1;
245  reassocRanges.push_back(*sourceRange);
246  }
247  // Fail if the source shape wasn't a full match for the target shape. We only
248  // need to check the last recorded index - any other gaps should have been
249  // mended by the main loop.
250  if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
251  return failure();
252  return reassocRanges;
253 }
254 
255 /// A variant of `findReassociationRangesForCollapse(...)` that can also scan
256 /// the shapes right-to-left.
257 static FailureOr<SmallVector<ReassociationIndexRange>>
259  ArrayRef<int64_t> targetShape,
260  bool iterateRightToLeft) {
261  if (!iterateRightToLeft)
262  return findReassociationRangesForCollapse(sourceShape, targetShape);
263  // NB: To iterate right-to-left, we currently reverse the shapes and then
264  // reverse the result back. The reversed shapes must not be temporary, as
265  // we're passing through an ArrayRef.
266  // FIXME: It would be preferable to avoid the expensive copies. At the moment,
267  // this approach is chosen for readability of the main implementation.
268  std::vector<int64_t> sourceToReverse = sourceShape.vec(),
269  targetToReverse = targetShape.vec();
270  std::reverse(sourceToReverse.begin(), sourceToReverse.end());
271  std::reverse(targetToReverse.begin(), targetToReverse.end());
272  auto invertedRanges =
273  findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
274  if (failed(invertedRanges))
275  return failure();
276  SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
277  unsigned numSourceDims = sourceShape.size();
278  // We have received the ranges for inverted shapes. Now we have to invert
279  // the ranges back to correspond with the original source shape.
280  for (auto &range : rangesToInvert) {
281  int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
282  range.leftIdx = numSourceDims - 1 - invRightIdx;
283  range.rightIdx = numSourceDims - 1 - invLeftIdx;
284  }
285  // Also invert the ordering of the ranges to correspond with the original
286  // target shape.
287  std::reverse(rangesToInvert.begin(), rangesToInvert.end());
288  return rangesToInvert;
289 }
290 
291 std::optional<SmallVector<ReassociationIndices>>
293  ArrayRef<int64_t> targetShape) {
294  unsigned numSourceDims = sourceShape.size(),
295  numTargetDims = targetShape.size();
296  // We're supposed to search for a collapsing reassociation. If the sizes
297  // match, there's no actual collapsing taking place - it's either a no-op or a
298  // `tensor.reshape`-style reassociation (that would be beyond the scope of
299  // this utility).
300  if (numSourceDims <= numTargetDims)
301  return std::nullopt;
302  // Early handling for scalar target types. We should report an invalid
303  // reassociation for non-unit static dimensions - no chance to collapse these
304  // into a scalar.
305  if (numTargetDims == 0) {
306  for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
307  ++sourceDimIdx) {
308  int64_t sourceSize = sourceShape[sourceDimIdx];
309  if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
310  return std::nullopt;
311  }
313  }
314 
315  // Collect source ranges by iterating over the target shape left-to-right.
316  FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
317  findReassociationRangesForCollapse(sourceShape, targetShape);
318  if (failed(maybeForwardRanges))
319  return std::nullopt;
320  auto &ranges = *maybeForwardRanges;
321  // Now do the same in reverse. We need to get another valid reassociation
322  // through some other strategy, and then compare the results in order to
323  // disambiguate mixed subshapes, such as:
324  // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
325  // This leads us to lose some of the reassociation opportunities that can only
326  // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
327  // backtracking, the algorithm will fail right-to-left. However, this is the
328  // best way to preserve correctness.
329  FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
330  findReassociationRangesForCollapse(sourceShape, targetShape,
331  /*iterateRightToLeft=*/true);
332  if (failed(maybeReverseRanges))
333  return std::nullopt;
334  auto &reverseRanges = *maybeReverseRanges;
335 
336  if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
337  return std::nullopt;
338  // Now we can check for ambiguity of each target dimension's reassociation. If
339  // successful, we put the full indices into our result map for the target
340  // shape.
341  SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
342  for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
343  ++targetDimIdx) {
344  ReassociationIndexRange &range = ranges[targetDimIdx];
345  ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
346  // Get non-overlapping indices between the ranges
347  ReassociationIndices nonMatchingIndices =
348  range.getNonOverlappingIndicesWith(reverseRange);
349  // Unit dimensions can be collapsed wherever - this is the only ambiguity
350  // that we allow.
351  for (int64_t sourceDimIdx : nonMatchingIndices) {
352  if (sourceShape[sourceDimIdx] != 1)
353  return std::nullopt;
354  }
355  reassociationMap[targetDimIdx] = range.getFullIndices();
356  }
357  return reassociationMap;
358 }
359 
360 std::optional<SmallVector<ReassociationIndices>>
362  ArrayRef<ReassociationIndices> producerReassociations,
363  ArrayRef<ReassociationIndices> consumerReassociations,
364  MLIRContext *context) {
365  SmallVector<ReassociationIndices> composedIndices;
366  // Make the producer the larger sized vector. If they are of same size, the
367  // resulting reshape is not a supported reshape op.
368  if (producerReassociations.size() == consumerReassociations.size())
369  return std::nullopt;
370  if (producerReassociations.size() < consumerReassociations.size())
371  std::swap(producerReassociations, consumerReassociations);
372 
373  // Handle the corner case of the result being a rank 0 shaped type. Return an
374  // empty reassociation.
375  if (consumerReassociations.empty())
376  return composedIndices;
377 
378  size_t consumerDims = std::accumulate(
379  consumerReassociations.begin(), consumerReassociations.end(), 0,
380  [](size_t all, ReassociationIndicesRef indices) {
381  return all + indices.size();
382  });
383  if (producerReassociations.size() != consumerDims)
384  return std::nullopt;
385 
386  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
387  ReassociationIndices reassociations;
388  for (int64_t consumerIndex : consumerIndices) {
389  llvm::append_range(reassociations, producerReassociations[consumerIndex]);
390  }
391  composedIndices.push_back(std::move(reassociations));
392  }
393  return composedIndices;
394 }
395 
398  MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
399  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
400  for (const auto &indices : reassociationIndices) {
401  SmallVector<AffineExpr, 2> reassociationMap;
402  reassociationMap.reserve(indices.size());
403  for (int64_t index : indices)
404  reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
405  reassociationMaps.push_back(std::move(reassociationMap));
406  }
407  return reassociationMaps;
408 }
409 
410 template <typename AffineExprTy>
412  unsigned pos = 0;
413  for (const auto &exprs : exprArrays) {
414  for (auto expr : exprs) {
415  expr.walk([&pos](AffineExpr e) {
416  if (auto d = dyn_cast<AffineExprTy>(e))
417  pos = std::max(pos, d.getPosition());
418  });
419  }
420  }
421  return pos;
422 }
423 
425  Builder &b, ArrayRef<ReassociationIndices> reassociation) {
426  SmallVector<Attribute, 4> reassociationAttr =
427  llvm::to_vector<4>(llvm::map_range(
428  reassociation, [&](const ReassociationIndices &indices) -> Attribute {
429  return cast<Attribute>(b.getI64ArrayAttr(indices));
430  }));
431  return b.getArrayAttr(reassociationAttr);
432 }
433 
435  ArrayRef<ReassociationExprs> reassociationExprs) {
436  SmallVector<ReassociationIndices, 2> reassociationIndices;
437  for (const auto &exprs : reassociationExprs) {
438  ReassociationIndices indices;
439  indices.reserve(exprs.size());
440  for (const auto &expr : exprs)
441  indices.push_back(cast<AffineDimExpr>(expr).getPosition());
442  reassociationIndices.push_back(indices);
443  }
444  return reassociationIndices;
445 }
446 
449  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
450  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
451  "Expected symbol-less expressions");
453  maps.reserve(reassociation.size());
454  for (const auto &exprs : reassociation) {
455  assert(!exprs.empty());
456  maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
457  }
458  return maps;
459 }
460 
462  int *invalidIndex) {
463  if (reassociation.empty())
464  return true;
465  unsigned nDims = reassociation[0].getNumDims();
466  unsigned nextExpectedDim = 0;
467  for (const auto &it : llvm::enumerate(reassociation)) {
468  auto m = it.value();
469  if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
470  if (invalidIndex)
471  *invalidIndex = it.index();
472  return false;
473  }
474  for (auto e : m.getResults()) {
475  auto d = dyn_cast<AffineDimExpr>(e);
476  if (!d || d.getPosition() != nextExpectedDim++) {
477  if (invalidIndex)
478  *invalidIndex = it.index();
479  return false;
480  }
481  }
482  }
483  if (nextExpectedDim != nDims) {
484  if (invalidIndex)
485  *invalidIndex = reassociation.size() - 1;
486  return false;
487  }
488  return true;
489 }
490 
492  function_ref<LogicalResult(const Twine &)> emitError,
493  ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
494  ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
495  unsigned expandedDimStart = 0;
496  for (const auto &map : llvm::enumerate(reassociationMaps)) {
497  bool foundDynamicShape = false;
498  int64_t linearizedStaticShape = 1;
499 
500  for (const auto &dim : llvm::enumerate(
501  expandedShape.slice(expandedDimStart, map.value().size()))) {
502  if (ShapedType::isDynamic(dim.value()))
503  foundDynamicShape = true;
504  else
505  linearizedStaticShape *= dim.value();
506  }
507  if (foundDynamicShape) {
508  if (ShapedType::isStatic(collapsedShape[map.index()])) {
509  return emitError(
510  "expected dimension " + Twine(map.index()) +
511  " of collapsed type to be dynamic since one or more of the "
512  "corresponding dimensions in the expanded type is dynamic");
513  }
514  } else {
515  if (collapsedShape[map.index()] != linearizedStaticShape) {
516  return emitError("expected dimension " + Twine(map.index()) +
517  " of collapsed type to be static value of " +
518  Twine(linearizedStaticShape));
519  }
520  }
521  expandedDimStart += map.value().size();
522  }
523  return success();
524 }
525 
527  if (auto memrefType = dyn_cast<MemRefType>(type))
528  return !memrefType.getLayout().isIdentity();
529  return false;
530 }
531 
532 llvm::SmallBitVector
534  ArrayRef<Range> sliceParams) {
535  assert(sliceParams.size() == sliceInputShape.size() &&
536  "only supports non rank-reducing case");
537  llvm::SmallBitVector mask(sliceInputShape.size());
538  unsigned idx = 0;
539  for (const auto &[offset, size, stride] : sliceParams) {
540  std::optional<int64_t> offsetConst = getConstantIntValue(offset);
541  std::optional<int64_t> strideConst = getConstantIntValue(stride);
542  mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) ||
543  (!strideConst || *strideConst != 1) ||
544  (!offsetConst || *offsetConst != 0);
545  idx++;
546  }
547  return mask;
548 }
549 
550 llvm::SmallBitVector mlir::getLinearizedDimensions(
551  ArrayRef<ReassociationIndices> reassociationIndices) {
552  llvm::SmallBitVector result(reassociationIndices.size());
553  for (const auto &it : llvm::enumerate(reassociationIndices))
554  result[it.index()] = it.value().size() > 1;
555  return result;
556 }
557 
558 SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
559  MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
560  unsigned loopIdx = 0;
561  auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
562  auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
563  SmallVector<Range> offsetsSizesAndStrides;
564  offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
565  for (const auto &it : llvm::enumerate(reassociationIndices)) {
566  // Case 1: Linearized dimensions that have also been sliced. These
567  // are size of 1 because we are iterating over these dimensions. The
568  // offsets are exactly the de-linearized multi-indices.
569  if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
570  llvm::append_range(
571  offsetsSizesAndStrides,
572  llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
573  return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
574  }));
575  continue;
576  }
577 
578  // Case 2: One or possibly multiple combined input dimensions, but we
579  // have proven that these are not sliced. In this case we just take
580  // the full extent of each dimension in the reassociation list.
581  if (linearizedDimensions[it.index()]) {
582  llvm::append_range(offsetsSizesAndStrides,
583  llvm::map_range(it.value(), [&](int64_t idx) -> Range {
584  return {zeroAttr, collapseShapeInputShape[idx],
585  oneAttr};
586  }));
587  continue;
588  }
589 
590  // Case 3: A single index, but it may be sliced.
591  offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
592  }
593  return offsetsSizesAndStrides;
594 }
595 
597 SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
598  ValueRange tileIndices) {
599  auto one = IntegerAttr::get(IndexType::get(ctx), 1);
600  auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
601  SmallVector<Range> insertParams;
602  insertParams.reserve(linearizedDimensions.size());
603  unsigned loopIdx = 0;
604  for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
605  if (linearizedDimensions[i] && slicedDimensions[i]) {
606  insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
607  continue;
608  }
609  insertParams.push_back(Range{zero, sliceParams[i].size, one});
610  }
611  return insertParams;
612 }
613 
614 /// Returns the index of the only non-unit dimension among `indices` of `shape`,
615 /// if such a dimension exists and `indices` has more than one element.
616 /// Otherwise, return std::nullopt.
617 static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
618  ArrayRef<int64_t> shape) {
619  // Return false if more than one of the dimensions in this group are not 1.
620  std::optional<int64_t> dimIndex;
621  if (indices.size() < 2)
622  return std::nullopt;
623  for (int64_t idx : indices) {
624  if (shape[idx] != 1) {
625  if (dimIndex != std::nullopt)
626  return std::nullopt;
627  dimIndex = idx;
628  }
629  }
630  return dimIndex;
631 }
632 
633 // For each segment in the reassociation indices, check whether we can
634 // simplify that segment with a rank-reducing extract slice. We can do this if
635 // all but (exactly) one of the corresponding source dims is 1.
637  RankedTensorType sourceType,
638  ArrayRef<ReassociationIndices> reassociationIndices) {
639  SmallVector<std::optional<int64_t>> trivialSegments;
640  for (const auto &indices : reassociationIndices)
641  trivialSegments.push_back(
642  getUniqueNonUnitDim(indices, sourceType.getShape()));
643  return trivialSegments;
644 }
645 
646 /// Returns true if any of the segments of the reassociation indices for a
647 /// collapsing reshape can be simplified using a rank-reducing slice.
648 static FailureOr<SmallVector<std::optional<int64_t>>>
650  RankedTensorType sourceType,
651  ArrayRef<ReassociationIndices> reassociationIndices) {
652  SmallVector<std::optional<int64_t>> trivialSegments =
653  getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
654  if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) {
655  return idx.has_value();
656  }))
657  return failure();
658  return trivialSegments;
659 }
660 
661 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
662 mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
663  RankedTensorType sourceType,
664  ArrayRef<ReassociationIndices> reassociationIndices) {
665  FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
667  reassociationIndices);
668  if (failed(trivialSegments))
669  return failure();
670 
671  // Create the expected result shape of the rank-reducing slice.
672  SmallVector<int64_t> sliceShape;
673  for (const auto &[nonUnitDim, indices] :
674  llvm::zip(*trivialSegments, reassociationIndices)) {
675  if (nonUnitDim) {
676  sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
677  continue;
678  }
679  llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
680  return sourceType.getDimSize(idx);
681  }));
682  }
683  auto sliceType =
684  RankedTensorType::get(sliceShape, sourceType.getElementType());
685 
686  // If the rank-reducing slice simplified every segment, then we are done.
687  if (sliceShape.size() == reassociationIndices.size())
688  return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
689  std::nullopt};
690 
691  // Otherwise, we need to create a new collapse_shape op for the segments that
692  // weren't covered by the slice. By design, the new reassociation indices has
693  // the same number of groups as the old reassociation indices.
694  SmallVector<ReassociationIndices> newReassociationIndices;
695  SmallVector<int64_t, 2> reassociation;
696  int64_t groupIdx = 0;
697  for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
698  reassociation.push_back(dimIdx);
699  if ((*trivialSegments)[groupIdx] ||
700  reassociation.size() == reassociationIndices[groupIdx].size()) {
701  newReassociationIndices.push_back(reassociation);
702  reassociation.clear();
703  groupIdx++;
704  }
705  }
706 
707  return CollapseShapeRankReducingSliceSimplificationInfo{
708  sliceType, newReassociationIndices};
709 }
710 
711 PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
712  ArrayRef<int64_t> innerDimPos) {
713  PackingMetadata res;
714  res.insertPositions.reserve(innerDimPos.size());
715  // The pack insert position is the position + the number of previously
716  // inserted positions + offset.
717  // The offset controls whether the packing dimension is the first or last.
718  //
719  // Example
720  // =======
721  // Consider packing from a hypothetical ABCD layout to ABCDba whose
722  // pack.inner_dims is [1, 0]. The first step consists in undoing the
723  // permutation and producing AaBbCD. This is achieved purely by computing the
724  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
725  // possibility, is to produce insert positions [2, 0], this would result in an
726  // aAbBCD layout (i.e. offset 0). The other possibility, is to produce insert
727  // positions [3, 1], this would result in an AaBbCD layout (i.e. offset 1).
728  // The latter is what we expect from packing.
729  int64_t offset = 1;
730  for (int64_t pos : innerDimPos) {
731  int64_t numInsertedBefore = llvm::count_if(
732  innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
733  res.insertPositions.push_back(pos + numInsertedBefore + offset);
734  }
735 
736  DenseSet<int64_t> posSet(res.insertPositions.begin(),
737  res.insertPositions.end());
738  res.reassociations.reserve(packedRank);
739  for (int64_t i = 1; i <= packedRank; ++i) {
740  res.outerPositions.push_back(i - 1);
741  if (!posSet.contains(i)) {
742  res.reassociations.push_back(ReassociationIndices{i - 1});
743  continue;
744  }
745  res.reassociations.push_back(ReassociationIndices{i - 1, i});
746  ++i;
747  }
748  return res;
749 }
750 
751 OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
752  TensorType result,
753  std::optional<Attribute> cst) {
754  if (source && source.isSplat() && result.hasStaticShape() &&
755  (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
756  return source.resizeSplat(result);
757 
758  return {};
759 }
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
unsigned getMaxPosOfType(ArrayRef< ReassociationExprs > exprArrays)
static FailureOr< ReassociationIndexRange > findReassociationRangeForSize(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, int64_t targetSize, bool matchGreedily=false)
Starting from sourceStartIdx, searches sourceShape for the first sequence of static dimensions such t...
static SmallVector< std::optional< int64_t > > getCollapseShapeTrivialSegments(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
static FailureOr< ReassociationIndexRange > findReassociationRangeForDynamicDim(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, bool matchGreedily=false)
Starting from sourceStartIdx, searches sourceShape for the first sequence that can be collapsed into ...
static std::optional< int64_t > getUniqueNonUnitDim(ArrayRef< int64_t > indices, ArrayRef< int64_t > shape)
Returns the index of the only non-unit dimension among indices of shape, if such a dimension exists a...
static FailureOr< SmallVector< ReassociationIndexRange > > findReassociationRangesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Attempts to find a valid collapsing reassociation of sourceShape into targetShape through a simple tr...
static FailureOr< SmallVector< std::optional< int64_t > > > canCollapseShapeBeSimplifiedByRankReducingSlice(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
Returns true if any of the segments of the reassociation indices for a collapsing reshape can be simp...
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:276
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayRef< int64_t > ReassociationIndicesRef
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size