MLIR  22.0.0git
ReshapeOpsUtils.cpp
Go to the documentation of this file.
1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Builders.h"
14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/SmallVector.h"
16 
17 #include <numeric>
18 #include <optional>
19 
20 using namespace mlir;
21 
22 std::optional<SmallVector<ReassociationIndices>>
24  ShapedType targetType) {
25  if (sourceType.getRank() > targetType.getRank())
26  return getReassociationIndicesForCollapse(sourceType.getShape(),
27  targetType.getShape());
28  if (sourceType.getRank() < targetType.getRank())
29  return getReassociationIndicesForCollapse(targetType.getShape(),
30  sourceType.getShape());
31  return std::nullopt;
32 }
33 
34 namespace {
35 /// A simple struct to represent ReassociationIndices as an inclusive interval.
36 /// It's designed to be feasibly minimal, so the call sites should manage the
37 /// validity of the range manually.
38 struct ReassociationIndexRange {
39  /// FIXME: Signed type is used for consistency with ReassociationIndices.
40  /// We should consider refactoring all reassociation utilities to use unsigned
41  /// types.
42  int64_t leftIdx = 0, rightIdx = 0;
43 
44  /// Util for manual checks of the range's validity
45  LogicalResult verify() const {
46  return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
47  }
48 
49  /// Checks range's containment within another range. Treats the edges
50  /// non-exclusively.
51  bool isInRange(const ReassociationIndexRange &outerRange) const {
52  return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
53  }
54 
55  unsigned size() const {
56  assert(succeeded(verify()));
57  return rightIdx - leftIdx + 1;
58  }
59  bool containsSingleIndex() const { return size() == 1; }
60 
61  /// Collects indices that do not overlap between this and another range.
63  getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
64  if (rightIdx < rhs.leftIdx) {
65  // The intervals do not overlap - concatenate the indices from both.
66  auto jointFullIndices = getFullIndices();
67  jointFullIndices.append(rhs.getFullIndices());
68  return jointFullIndices;
69  }
70  ReassociationIndices result;
71  // Handle the chunk left of the overlapping range.
72  int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
73  int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
74  llvm::append_range(result, llvm::seq(leftStart, leftEnd));
75  // Handle the chunk right of the overlapping range. Symmetrically, we should
76  // skip the edge of the overlap AND include the rightmost index.
77  int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
78  int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
79  if (rightStart < rightEnd)
80  llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
81  return result;
82  }
83 
84  /// Converts the range into ReassociationIndices.
85  ReassociationIndices getFullIndices() const {
86  ReassociationIndices result;
87  for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
88  result.push_back(idx);
89  }
90  return result;
91  }
92 };
93 } // namespace
94 
95 /// Starting from `sourceStartIdx`, searches `sourceShape` for the first
96 /// sequence that can be collapsed into a dynamic dimension (at least one must
97 /// be present in the source).
98 /// By default, lazily returns once the first dynamic dimension has been found.
99 /// Setting `matchGreedily` as `true` will also mark all subsequent
100 /// source dimensions for collapsing into the target.
101 static FailureOr<ReassociationIndexRange>
103  int64_t sourceStartIdx,
104  bool matchGreedily = false) {
105  const unsigned numSourceDims = sourceShape.size();
106  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
107  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
108 
109  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
110  for (; iterationRange.isInRange(sourceShapeAsRange);
111  iterationRange.rightIdx++) {
112  int64_t sourceSize = sourceShape[iterationRange.rightIdx];
113  if (sourceSize == ShapedType::kDynamic) {
114  resultRange = iterationRange;
115  break;
116  }
117  }
118  if (!resultRange)
119  return failure();
120  if (matchGreedily)
121  resultRange->rightIdx = sourceShapeAsRange.rightIdx;
122  return *resultRange;
123 }
124 
125 /// Starting from `sourceStartIdx`, searches `sourceShape` for the first
126 /// sequence of static dimensions such that their product matches `targetSize`.
127 /// By default, lazily returns once the product matches the target size. Setting
128 /// `matchGreedily` as `true` will append all neighboring unit dimensions
129 /// (dimensions of 1) to the match.
130 static FailureOr<ReassociationIndexRange>
132  int64_t sourceStartIdx, int64_t targetSize,
133  bool matchGreedily = false) {
134  const unsigned numSourceDims = sourceShape.size();
135  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
136  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
137 
138  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
139  int64_t prodOfCollapsedDims = 1;
140  while (iterationRange.isInRange(sourceShapeAsRange)) {
141  int64_t sourceSize = sourceShape[iterationRange.rightIdx];
142  if (sourceSize == ShapedType::kDynamic) {
143  // Reassociation for a static dim cannot include a dynamic dim. Reset
144  // induction variables to essentially restart the loop from the next
145  // source dimension.
146  prodOfCollapsedDims = 1;
147  iterationRange = {iterationRange.rightIdx + 1,
148  iterationRange.rightIdx + 1};
149  continue;
150  }
151  prodOfCollapsedDims *= sourceSize;
152  // If the target size has been exceeded without matching, we need to shift
153  // the range start right. From the start of the range, roll back the
154  // multiplication until the target size exceeds the product again.
155  while (prodOfCollapsedDims > targetSize &&
156  !iterationRange.containsSingleIndex()) {
157  int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
158  prodOfCollapsedDims /= frontSourceSize;
159  // Shrink the range rightwards
160  iterationRange.leftIdx++;
161  }
162  // We could've reached the target size with the current dimension,
163  // also as a result of the above shift to right.
164  if (prodOfCollapsedDims == targetSize) {
165  resultRange = iterationRange;
166  break;
167  }
168  // Increment the iteration range
169  iterationRange.rightIdx++;
170  }
171  if (!resultRange)
172  return failure();
173  if (matchGreedily) {
174  // We now want to collect all unit dimensions directly after the target
175  // product match. Advance the iterator to avoid OOB when the product match
176  // happens at the last element.
177  iterationRange.rightIdx++;
178  while (iterationRange.isInRange(sourceShapeAsRange) &&
179  sourceShape[iterationRange.rightIdx] == 1) {
180  resultRange = iterationRange;
181  iterationRange.rightIdx++;
182  }
183  }
184  return *resultRange;
185 }
186 
187 /// Attempts to find a valid collapsing reassociation of `sourceShape` into
188 /// `targetShape` through a simple traversal. If successful, an array of source
189 /// index ranges is returned, correspondingly to each dimension in the target
190 /// shape. The resulting indices shall fully cover the `sourceShape` without
191 /// overlaps.
192 ///
193 /// The algorithm is essentially a lazy one, searching for non-greedy matches -
194 /// it will only yield a greedy match for the last target dimension.
195 /// FIXME: The algorithm can only backtrack when it needs to append an offset
196 /// for a static target dimension to the preceding dynamic one (this retains the
197 /// linear complexity). As feasible, consider adding further backtracking
198 /// routines to enable more reassociations, e.g.:
199 /// - ?x2x?x2 into ?x2
200 static FailureOr<SmallVector<ReassociationIndexRange>>
202  ArrayRef<int64_t> targetShape) {
203  unsigned numSourceDims = sourceShape.size(),
204  numTargetDims = targetShape.size();
205  assert(numSourceDims > numTargetDims);
206  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
207 
209  reassocRanges.reserve(numTargetDims);
210  // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
211  // cases, e.g.:
212  // - ?x2x3x5 into ?x15
213  std::optional<int64_t> prevTargetSize = std::nullopt;
214  for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
215  targetDimIdx < numTargetDims; ++targetDimIdx) {
216  int64_t targetSize = targetShape[targetDimIdx];
217  // Simply check if there are any subsequent target dimensions left - if not,
218  // the match must be made greedily.
219  bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
220  FailureOr<ReassociationIndexRange> sourceRange;
221  if (targetSize == ShapedType::kDynamic) {
223  sourceShape, sourceDimIdx, shouldMatchGreedily);
224  } else {
225  sourceRange = findReassociationRangeForSize(
226  sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
227  }
228 
229  // Run sanity checks on the returned index range.
230  if (failed(sourceRange) || failed(sourceRange->verify()) ||
231  !sourceRange->isInRange(sourceShapeAsRange))
232  return failure();
233  if (sourceRange->leftIdx > sourceDimIdx) {
234  // If some source dimensions had to be skipped in order to find a match,
235  // they must be collapsed into the directly preceding dynamic dimension.
236  if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
237  return failure();
238  reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
239  }
240 
241  // Store the gathered information as required for the next iteration.
242  prevTargetSize = targetSize;
243  sourceDimIdx = sourceRange->rightIdx + 1;
244  reassocRanges.push_back(*sourceRange);
245  }
246  // Fail if the source shape wasn't a full match for the target shape. We only
247  // need to check the last recorded index - any other gaps should have been
248  // mended by the main loop.
249  if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
250  return failure();
251  return reassocRanges;
252 }
253 
254 /// A variant of `findReassociationRangesForCollapse(...)` that can also scan
255 /// the shapes right-to-left.
256 static FailureOr<SmallVector<ReassociationIndexRange>>
258  ArrayRef<int64_t> targetShape,
259  bool iterateRightToLeft) {
260  if (!iterateRightToLeft)
261  return findReassociationRangesForCollapse(sourceShape, targetShape);
262  // NB: To iterate right-to-left, we currently reverse the shapes and then
263  // reverse the result back. The reversed shapes must not be temporary, as
264  // we're passing through an ArrayRef.
265  // FIXME: It would be preferable to avoid the expensive copies. At the moment,
266  // this approach is chosen for readability of the main implementation.
267  std::vector<int64_t> sourceToReverse = sourceShape.vec(),
268  targetToReverse = targetShape.vec();
269  std::reverse(sourceToReverse.begin(), sourceToReverse.end());
270  std::reverse(targetToReverse.begin(), targetToReverse.end());
271  auto invertedRanges =
272  findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
273  if (failed(invertedRanges))
274  return failure();
275  SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
276  unsigned numSourceDims = sourceShape.size();
277  // We have received the ranges for inverted shapes. Now we have to invert
278  // the ranges back to correspond with the original source shape.
279  for (auto &range : rangesToInvert) {
280  int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
281  range.leftIdx = numSourceDims - 1 - invRightIdx;
282  range.rightIdx = numSourceDims - 1 - invLeftIdx;
283  }
284  // Also invert the ordering of the ranges to correspond with the original
285  // target shape.
286  std::reverse(rangesToInvert.begin(), rangesToInvert.end());
287  return rangesToInvert;
288 }
289 
290 std::optional<SmallVector<ReassociationIndices>>
292  ArrayRef<int64_t> targetShape) {
293  unsigned numSourceDims = sourceShape.size(),
294  numTargetDims = targetShape.size();
295  // We're supposed to search for a collapsing reassociation. If the sizes
296  // match, there's no actual collapsing taking place - it's either a no-op or a
297  // `tensor.reshape`-style reassociation (that would be beyond the scope of
298  // this utility).
299  if (numSourceDims <= numTargetDims)
300  return std::nullopt;
301  // Early handling for scalar target types. We should report an invalid
302  // reassociation for non-unit static dimensions - no chance to collapse these
303  // into a scalar.
304  if (numTargetDims == 0) {
305  for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
306  ++sourceDimIdx) {
307  int64_t sourceSize = sourceShape[sourceDimIdx];
308  if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
309  return std::nullopt;
310  }
312  }
313 
314  // Collect source ranges by iterating over the target shape left-to-right.
315  FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
316  findReassociationRangesForCollapse(sourceShape, targetShape);
317  if (failed(maybeForwardRanges))
318  return std::nullopt;
319  auto &ranges = *maybeForwardRanges;
320  // Now do the same in reverse. We need to get another valid reassociation
321  // through some other strategy, and then compare the results in order to
322  // disambiguate mixed subshapes, such as:
323  // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
324  // This leads us to lose some of the reassociation opportunities that can only
325  // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
326  // backtracking, the algorithm will fail right-to-left. However, this is the
327  // best way to preserve correctness.
328  FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
329  findReassociationRangesForCollapse(sourceShape, targetShape,
330  /*iterateRightToLeft=*/true);
331  if (failed(maybeReverseRanges))
332  return std::nullopt;
333  auto &reverseRanges = *maybeReverseRanges;
334 
335  if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
336  return std::nullopt;
337  // Now we can check for ambiguity of each target dimension's reassociation. If
338  // successful, we put the full indices into our result map for the target
339  // shape.
340  SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
341  for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
342  ++targetDimIdx) {
343  ReassociationIndexRange &range = ranges[targetDimIdx];
344  ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
345  // Get non-overlapping indices between the ranges
346  ReassociationIndices nonMatchingIndices =
347  range.getNonOverlappingIndicesWith(reverseRange);
348  // Unit dimensions can be collapsed wherever - this is the only ambiguity
349  // that we allow.
350  for (int64_t sourceDimIdx : nonMatchingIndices) {
351  if (sourceShape[sourceDimIdx] != 1)
352  return std::nullopt;
353  }
354  reassociationMap[targetDimIdx] = range.getFullIndices();
355  }
356  return reassociationMap;
357 }
358 
359 std::optional<SmallVector<ReassociationIndices>>
361  ArrayRef<ReassociationIndices> producerReassociations,
362  ArrayRef<ReassociationIndices> consumerReassociations,
363  MLIRContext *context) {
364  SmallVector<ReassociationIndices> composedIndices;
365  // Make the producer the larger sized vector. If they are of same size, the
366  // resulting reshape is not a supported reshape op.
367  if (producerReassociations.size() == consumerReassociations.size())
368  return std::nullopt;
369  if (producerReassociations.size() < consumerReassociations.size())
370  std::swap(producerReassociations, consumerReassociations);
371 
372  // Handle the corner case of the result being a rank 0 shaped type. Return an
373  // empty reassociation.
374  if (consumerReassociations.empty())
375  return composedIndices;
376 
377  size_t consumerDims = std::accumulate(
378  consumerReassociations.begin(), consumerReassociations.end(), 0,
379  [](size_t all, ReassociationIndicesRef indices) {
380  return all + indices.size();
381  });
382  if (producerReassociations.size() != consumerDims)
383  return std::nullopt;
384 
385  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
386  ReassociationIndices reassociations;
387  for (int64_t consumerIndex : consumerIndices) {
388  llvm::append_range(reassociations, producerReassociations[consumerIndex]);
389  }
390  composedIndices.push_back(std::move(reassociations));
391  }
392  return composedIndices;
393 }
394 
397  MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
398  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
399  for (const auto &indices : reassociationIndices) {
400  SmallVector<AffineExpr, 2> reassociationMap;
401  reassociationMap.reserve(indices.size());
402  for (int64_t index : indices)
403  reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
404  reassociationMaps.push_back(std::move(reassociationMap));
405  }
406  return reassociationMaps;
407 }
408 
409 template <typename AffineExprTy>
411  unsigned pos = 0;
412  for (const auto &exprs : exprArrays) {
413  for (auto expr : exprs) {
414  expr.walk([&pos](AffineExpr e) {
415  if (auto d = dyn_cast<AffineExprTy>(e))
416  pos = std::max(pos, d.getPosition());
417  });
418  }
419  }
420  return pos;
421 }
422 
424  Builder &b, ArrayRef<ReassociationIndices> reassociation) {
425  SmallVector<Attribute, 4> reassociationAttr =
426  llvm::to_vector<4>(llvm::map_range(
427  reassociation, [&](const ReassociationIndices &indices) -> Attribute {
428  return cast<Attribute>(b.getI64ArrayAttr(indices));
429  }));
430  return b.getArrayAttr(reassociationAttr);
431 }
432 
434  ArrayRef<ReassociationExprs> reassociationExprs) {
435  SmallVector<ReassociationIndices, 2> reassociationIndices;
436  for (const auto &exprs : reassociationExprs) {
437  ReassociationIndices indices;
438  indices.reserve(exprs.size());
439  for (const auto &expr : exprs)
440  indices.push_back(cast<AffineDimExpr>(expr).getPosition());
441  reassociationIndices.push_back(indices);
442  }
443  return reassociationIndices;
444 }
445 
448  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
449  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
450  "Expected symbol-less expressions");
452  maps.reserve(reassociation.size());
453  for (const auto &exprs : reassociation) {
454  assert(!exprs.empty());
455  maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
456  }
457  return maps;
458 }
459 
461  int *invalidIndex) {
462  if (reassociation.empty())
463  return true;
464  unsigned nDims = reassociation[0].getNumDims();
465  unsigned nextExpectedDim = 0;
466  for (const auto &it : llvm::enumerate(reassociation)) {
467  auto m = it.value();
468  if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
469  if (invalidIndex)
470  *invalidIndex = it.index();
471  return false;
472  }
473  for (auto e : m.getResults()) {
474  auto d = dyn_cast<AffineDimExpr>(e);
475  if (!d || d.getPosition() != nextExpectedDim++) {
476  if (invalidIndex)
477  *invalidIndex = it.index();
478  return false;
479  }
480  }
481  }
482  if (nextExpectedDim != nDims) {
483  if (invalidIndex)
484  *invalidIndex = reassociation.size() - 1;
485  return false;
486  }
487  return true;
488 }
489 
491  function_ref<LogicalResult(const Twine &)> emitError,
492  ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
493  ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
494  unsigned expandedDimStart = 0;
495  for (const auto &map : llvm::enumerate(reassociationMaps)) {
496  bool foundDynamicShape = false;
497  int64_t linearizedStaticShape = 1;
498 
499  for (const auto &dim : llvm::enumerate(
500  expandedShape.slice(expandedDimStart, map.value().size()))) {
501  if (ShapedType::isDynamic(dim.value()))
502  foundDynamicShape = true;
503  else
504  linearizedStaticShape *= dim.value();
505  }
506  if (foundDynamicShape) {
507  if (ShapedType::isStatic(collapsedShape[map.index()])) {
508  return emitError(
509  "expected dimension " + Twine(map.index()) +
510  " of collapsed type to be dynamic since one or more of the "
511  "corresponding dimensions in the expanded type is dynamic");
512  }
513  } else {
514  if (collapsedShape[map.index()] != linearizedStaticShape) {
515  return emitError("expected dimension " + Twine(map.index()) +
516  " of collapsed type to be static value of " +
517  Twine(linearizedStaticShape));
518  }
519  }
520  expandedDimStart += map.value().size();
521  }
522  return success();
523 }
524 
526  if (auto memrefType = dyn_cast<MemRefType>(type))
527  return !memrefType.getLayout().isIdentity();
528  return false;
529 }
530 
531 llvm::SmallBitVector
533  ArrayRef<Range> sliceParams) {
534  assert(sliceParams.size() == sliceInputShape.size() &&
535  "only supports non rank-reducing case");
536  llvm::SmallBitVector mask(sliceInputShape.size());
537  unsigned idx = 0;
538  for (const auto &[offset, size, stride] : sliceParams) {
539  std::optional<int64_t> offsetConst = getConstantIntValue(offset);
540  std::optional<int64_t> strideConst = getConstantIntValue(stride);
541  mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) ||
542  (!strideConst || *strideConst != 1) ||
543  (!offsetConst || *offsetConst != 0);
544  idx++;
545  }
546  return mask;
547 }
548 
549 llvm::SmallBitVector mlir::getLinearizedDimensions(
550  ArrayRef<ReassociationIndices> reassociationIndices) {
551  llvm::SmallBitVector result(reassociationIndices.size());
552  for (const auto &it : llvm::enumerate(reassociationIndices))
553  result[it.index()] = it.value().size() > 1;
554  return result;
555 }
556 
557 SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
558  MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
559  unsigned loopIdx = 0;
560  auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
561  auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
562  SmallVector<Range> offsetsSizesAndStrides;
563  offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
564  for (const auto &it : llvm::enumerate(reassociationIndices)) {
565  // Case 1: Linearized dimensions that have also been sliced. These
566  // are size of 1 because we are iterating over these dimensions. The
567  // offsets are exactly the de-linearized multi-indices.
568  if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
569  llvm::append_range(
570  offsetsSizesAndStrides,
571  llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
572  return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
573  }));
574  continue;
575  }
576 
577  // Case 2: One or possibly multiple combined input dimensions, but we
578  // have proven that these are not sliced. In this case we just take
579  // the full extent of each dimension in the reassociation list.
580  if (linearizedDimensions[it.index()]) {
581  llvm::append_range(offsetsSizesAndStrides,
582  llvm::map_range(it.value(), [&](int64_t idx) -> Range {
583  return {zeroAttr, collapseShapeInputShape[idx],
584  oneAttr};
585  }));
586  continue;
587  }
588 
589  // Case 3: A single index, but it may be sliced.
590  offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
591  }
592  return offsetsSizesAndStrides;
593 }
594 
596 SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
597  ValueRange tileIndices) {
598  auto one = IntegerAttr::get(IndexType::get(ctx), 1);
599  auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
600  SmallVector<Range> insertParams;
601  insertParams.reserve(linearizedDimensions.size());
602  unsigned loopIdx = 0;
603  for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
604  if (linearizedDimensions[i] && slicedDimensions[i]) {
605  insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
606  continue;
607  }
608  insertParams.push_back(Range{zero, sliceParams[i].size, one});
609  }
610  return insertParams;
611 }
612 
613 /// Returns the index of the only non-unit dimension among `indices` of `shape`,
614 /// if such a dimension exists and `indices` has more than one element.
615 /// Otherwise, return std::nullopt.
616 static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
617  ArrayRef<int64_t> shape) {
618  // Return false if more than one of the dimensions in this group are not 1.
619  std::optional<int64_t> dimIndex;
620  if (indices.size() < 2)
621  return std::nullopt;
622  for (int64_t idx : indices) {
623  if (shape[idx] != 1) {
624  if (dimIndex != std::nullopt)
625  return std::nullopt;
626  dimIndex = idx;
627  }
628  }
629  return dimIndex;
630 }
631 
632 // For each segment in the reassociation indices, check whether we can
633 // simplify that segment with a rank-reducing extract slice. We can do this if
634 // all but (exactly) one of the corresponding source dims is 1.
636  RankedTensorType sourceType,
637  ArrayRef<ReassociationIndices> reassociationIndices) {
638  SmallVector<std::optional<int64_t>> trivialSegments;
639  for (const auto &indices : reassociationIndices)
640  trivialSegments.push_back(
641  getUniqueNonUnitDim(indices, sourceType.getShape()));
642  return trivialSegments;
643 }
644 
645 /// Returns true if any of the segments of the reassociation indices for a
646 /// collapsing reshape can be simplified using a rank-reducing slice.
647 static FailureOr<SmallVector<std::optional<int64_t>>>
649  RankedTensorType sourceType,
650  ArrayRef<ReassociationIndices> reassociationIndices) {
651  SmallVector<std::optional<int64_t>> trivialSegments =
652  getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
653  if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) {
654  return idx.has_value();
655  }))
656  return failure();
657  return trivialSegments;
658 }
659 
660 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
661 mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
662  RankedTensorType sourceType,
663  ArrayRef<ReassociationIndices> reassociationIndices) {
664  FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
666  reassociationIndices);
667  if (failed(trivialSegments))
668  return failure();
669 
670  // Create the expected result shape of the rank-reducing slice.
671  SmallVector<int64_t> sliceShape;
672  for (const auto &[nonUnitDim, indices] :
673  llvm::zip(*trivialSegments, reassociationIndices)) {
674  if (nonUnitDim) {
675  sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
676  continue;
677  }
678  llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
679  return sourceType.getDimSize(idx);
680  }));
681  }
682  auto sliceType =
683  RankedTensorType::get(sliceShape, sourceType.getElementType());
684 
685  // If the rank-reducing slice simplified every segment, then we are done.
686  if (sliceShape.size() == reassociationIndices.size())
687  return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
688  std::nullopt};
689 
690  // Otherwise, we need to create a new collapse_shape op for the segments that
691  // weren't covered by the slice. By design, the new reassociation indices has
692  // the same number of groups as the old reassociation indices.
693  SmallVector<ReassociationIndices> newReassociationIndices;
694  SmallVector<int64_t, 2> reassociation;
695  int64_t groupIdx = 0;
696  for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
697  reassociation.push_back(dimIdx);
698  if ((*trivialSegments)[groupIdx] ||
699  reassociation.size() == reassociationIndices[groupIdx].size()) {
700  newReassociationIndices.push_back(reassociation);
701  reassociation.clear();
702  groupIdx++;
703  }
704  }
705 
706  return CollapseShapeRankReducingSliceSimplificationInfo{
707  sliceType, newReassociationIndices};
708 }
709 
710 PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
711  ArrayRef<int64_t> innerDimPos) {
712  PackingMetadata res;
713  res.insertPositions.reserve(innerDimPos.size());
714  // The pack insert position is the position + the number of previously
715  // inserted positions + offset.
716  // The offset controls whether the packing dimension is the first or last.
717  //
718  // Example
719  // =======
720  // Consider packing from a hypothetical ABCD layout to ABCDba whose
721  // pack.inner_dims is [1, 0]. The first step consists in undoing the
722  // permutation and producing AaBbCD. This is achieved purely by computing the
723  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
724  // possibility, is to produce insert positions [2, 0], this would result in an
725  // aAbBCD layout (i.e. offset 0). The other possibility, is to produce insert
726  // positions [3, 1], this would result in an AaBbCD layout (i.e. offset 1).
727  // The latter is what we expect from packing.
728  int64_t offset = 1;
729  for (int64_t pos : innerDimPos) {
730  int64_t numInsertedBefore = llvm::count_if(
731  innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
732  res.insertPositions.push_back(pos + numInsertedBefore + offset);
733  }
734 
735  DenseSet<int64_t> posSet(res.insertPositions.begin(),
736  res.insertPositions.end());
737  res.reassociations.reserve(packedRank);
738  for (int64_t i = 1; i <= packedRank; ++i) {
739  res.outerPositions.push_back(i - 1);
740  if (!posSet.contains(i)) {
741  res.reassociations.push_back(ReassociationIndices{i - 1});
742  continue;
743  }
744  res.reassociations.push_back(ReassociationIndices{i - 1, i});
745  ++i;
746  }
747  return res;
748 }
749 
750 OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
751  TensorType result,
752  std::optional<Attribute> cst) {
753  if (source && source.isSplat() && result.hasStaticShape() &&
754  (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
755  return source.resizeSplat(result);
756 
757  return {};
758 }
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
unsigned getMaxPosOfType(ArrayRef< ReassociationExprs > exprArrays)
static FailureOr< ReassociationIndexRange > findReassociationRangeForSize(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, int64_t targetSize, bool matchGreedily=false)
Starting from sourceStartIdx, searches sourceShape for the first sequence of static dimensions such t...
static SmallVector< std::optional< int64_t > > getCollapseShapeTrivialSegments(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
static FailureOr< ReassociationIndexRange > findReassociationRangeForDynamicDim(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, bool matchGreedily=false)
Starting from sourceStartIdx, searches sourceShape for the first sequence that can be collapsed into ...
static std::optional< int64_t > getUniqueNonUnitDim(ArrayRef< int64_t > indices, ArrayRef< int64_t > shape)
Returns the index of the only non-unit dimension among indices of shape, if such a dimension exists a...
static FailureOr< SmallVector< ReassociationIndexRange > > findReassociationRangesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Attempts to find a valid collapsing reassociation of sourceShape into targetShape through a simple tr...
static FailureOr< SmallVector< std::optional< int64_t > > > canCollapseShapeBeSimplifiedByRankReducingSlice(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
Returns true if any of the segments of the reassociation indices for a collapsing reshape can be simp...
Base type for affine expression.
Definition: AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:276
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayRef< int64_t > ReassociationIndicesRef
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size