ReshapeOpsUtils.cpp
//===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

#include <numeric>
#include <optional>

using namespace mlir;

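// Example usage (illustrative): for sourceType = tensor<2x3x4xf32> and
// targetType = tensor<6x4xf32>, the source rank is greater, so the collapse
// search below yields the reassociation [[0, 1], [2]]; for the expanding case
// (swapped ranks), the same indices are computed over the swapped shapes.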
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForReshape(ShapedType sourceType,
                                        ShapedType targetType) {
  if (sourceType.getRank() > targetType.getRank())
    return getReassociationIndicesForCollapse(sourceType.getShape(),
                                              targetType.getShape());
  if (sourceType.getRank() < targetType.getRank())
    return getReassociationIndicesForCollapse(targetType.getShape(),
                                              sourceType.getShape());
  return std::nullopt;
}

namespace {
/// A simple struct to represent ReassociationIndices as an inclusive interval.
/// It is deliberately minimal, so the call sites should manage the validity of
/// the range manually.
struct ReassociationIndexRange {
  /// FIXME: Signed type is used for consistency with ReassociationIndices.
  /// We should consider refactoring all reassociation utilities to use
  /// unsigned types.
  int64_t leftIdx = 0, rightIdx = 0;

  /// Util for manual checks of the range's validity.
  LogicalResult verify() const {
    return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
  }

  /// Checks the range's containment within another range. Treats the edges
  /// inclusively.
  bool isInRange(const ReassociationIndexRange &outerRange) const {
    return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
  }

  unsigned size() const {
    assert(succeeded(verify()));
    return rightIdx - leftIdx + 1;
  }
  bool containsSingleIndex() const { return size() == 1; }

  /// Collects indices that do not overlap between this and another range.
  ReassociationIndices
  getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
    if (rightIdx < rhs.leftIdx) {
      // The intervals do not overlap - concatenate the indices from both.
      auto jointFullIndices = getFullIndices();
      jointFullIndices.append(rhs.getFullIndices());
      return jointFullIndices;
    }
    ReassociationIndices result;
    // Handle the chunk left of the overlapping range.
    int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
    int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
    llvm::append_range(result, llvm::seq(leftStart, leftEnd));
    // Handle the chunk right of the overlapping range. Symmetrically, we
    // should skip the edge of the overlap AND include the rightmost index.
    int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
    int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
    if (rightStart < rightEnd)
      llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
    return result;
  }

  /// Converts the range into ReassociationIndices.
  ReassociationIndices getFullIndices() const {
    ReassociationIndices result;
    for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
      result.push_back(idx);
    }
    return result;
  }
};
} // namespace

/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
/// sequence that can be collapsed into a dynamic dimension (at least one must
/// be present in the source).
/// By default, lazily returns once the first dynamic dimension has been found.
/// Setting `matchGreedily` as `true` will also mark all subsequent
/// source dimensions for collapsing into the target.
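/// For example (illustrative), given sourceShape = [2, ?, 3, ?] and
/// sourceStartIdx = 0, the lazy search returns the range [0, 1], which ends at
/// the first dynamic dimension; with `matchGreedily` set, the range is
/// extended to [0, 3].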
static FailureOr<ReassociationIndexRange>
findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
                                    int64_t sourceStartIdx,
                                    bool matchGreedily = false) {
  const unsigned numSourceDims = sourceShape.size();
  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
  std::optional<ReassociationIndexRange> resultRange = std::nullopt;

  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
  for (; iterationRange.isInRange(sourceShapeAsRange);
       iterationRange.rightIdx++) {
    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
    if (sourceSize == ShapedType::kDynamic) {
      resultRange = iterationRange;
      break;
    }
  }
  if (!resultRange)
    return failure();
  if (matchGreedily)
    resultRange->rightIdx = sourceShapeAsRange.rightIdx;
  return *resultRange;
}

/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
/// sequence of static dimensions such that their product matches `targetSize`.
/// By default, lazily returns once the product matches the target size.
/// Setting `matchGreedily` as `true` will append all neighboring unit
/// dimensions (dimensions of 1) to the match.
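/// For example (illustrative), with sourceShape = [2, 3, 5, 1, 1],
/// sourceStartIdx = 0 and targetSize = 15, the window grows to [0, 2]
/// (product 30), is trimmed from the left to [1, 2] (3 * 5 = 15), and is
/// returned as the match; with `matchGreedily` set, the trailing unit
/// dimensions are appended, yielding [1, 4].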
static FailureOr<ReassociationIndexRange>
findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
                              int64_t sourceStartIdx, int64_t targetSize,
                              bool matchGreedily = false) {
  const unsigned numSourceDims = sourceShape.size();
  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
  std::optional<ReassociationIndexRange> resultRange = std::nullopt;

  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
  int64_t prodOfCollapsedDims = 1;
  while (iterationRange.isInRange(sourceShapeAsRange)) {
    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
    if (sourceSize == ShapedType::kDynamic) {
      // Reassociation for a static dim cannot include a dynamic dim. Reset
      // induction variables to essentially restart the loop from the next
      // source dimension.
      prodOfCollapsedDims = 1;
      iterationRange = {iterationRange.rightIdx + 1,
                        iterationRange.rightIdx + 1};
      continue;
    }
    prodOfCollapsedDims *= sourceSize;
    // If the target size has been exceeded without matching, we need to shift
    // the range start right. From the start of the range, roll back the
    // multiplication until the product no longer exceeds the target size.
    while (prodOfCollapsedDims > targetSize &&
           !iterationRange.containsSingleIndex()) {
      int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
      prodOfCollapsedDims /= frontSourceSize;
      // Shrink the range by advancing its left edge.
      iterationRange.leftIdx++;
    }
    // We may have reached the target size with the current dimension, possibly
    // as a result of the shift above.
    if (prodOfCollapsedDims == targetSize) {
      resultRange = iterationRange;
      break;
    }
    // Increment the iteration range.
    iterationRange.rightIdx++;
  }
  if (!resultRange)
    return failure();
  if (matchGreedily) {
    // We now want to collect all unit dimensions directly after the target
    // product match. Advance the iterator to avoid OOB when the product match
    // happens at the last element.
    iterationRange.rightIdx++;
    while (iterationRange.isInRange(sourceShapeAsRange) &&
           sourceShape[iterationRange.rightIdx] == 1) {
      resultRange = iterationRange;
      iterationRange.rightIdx++;
    }
  }
  return *resultRange;
}

/// Attempts to find a valid collapsing reassociation of `sourceShape` into
/// `targetShape` through a simple traversal. If successful, an array of source
/// index ranges is returned, corresponding to each dimension in the target
/// shape. The resulting indices shall fully cover the `sourceShape` without
/// overlaps.
///
/// The algorithm is essentially a lazy one, searching for non-greedy matches -
/// it will only yield a greedy match for the last target dimension.
/// FIXME: The algorithm can only backtrack when it needs to append an offset
/// for a static target dimension to the preceding dynamic one (this retains
/// the linear complexity). As feasible, consider adding further backtracking
/// routines to enable more reassociations, e.g.:
/// - ?x2x?x2 into ?x2
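/// For example (illustrative), collapsing ?x2x3x5 into ?x15 yields the ranges
/// [0, 1] (the dynamic target dimension absorbs the skipped `2`) and [2, 3]
/// (3 * 5 = 15).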
static FailureOr<SmallVector<ReassociationIndexRange>>
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
                                   ArrayRef<int64_t> targetShape) {
  unsigned numSourceDims = sourceShape.size(),
           numTargetDims = targetShape.size();
  assert(numSourceDims > numTargetDims);
  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};

  SmallVector<ReassociationIndexRange> reassocRanges;
  reassocRanges.reserve(numTargetDims);
  // Keep track of the previously matched target dimension. This enables
  // pseudo-backtracking for simple cases: source dimensions that have to be
  // skipped to match a static target size are folded into a directly
  // preceding dynamic target dimension, e.g.:
  // - ?x2x3x5 into ?x15
  std::optional<int64_t> prevTargetSize = std::nullopt;
  for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
       targetDimIdx < numTargetDims; ++targetDimIdx) {
    int64_t targetSize = targetShape[targetDimIdx];
    // Simply check if there are any subsequent target dimensions left - if
    // not, the match must be made greedily.
    bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
    FailureOr<ReassociationIndexRange> sourceRange;
    if (targetSize == ShapedType::kDynamic) {
      sourceRange = findReassociationRangeForDynamicDim(
          sourceShape, sourceDimIdx, shouldMatchGreedily);
    } else {
      sourceRange = findReassociationRangeForSize(
          sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
    }

    // Run sanity checks on the returned index range.
    if (failed(sourceRange) || failed(sourceRange->verify()) ||
        !sourceRange->isInRange(sourceShapeAsRange))
      return failure();
    if (sourceRange->leftIdx > sourceDimIdx) {
      // If some source dimensions had to be skipped in order to find a match,
      // they must be collapsed into the directly preceding dynamic dimension.
      if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
        return failure();
      reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
    }

    // Store the gathered information as required for the next iteration.
    prevTargetSize = targetSize;
    sourceDimIdx = sourceRange->rightIdx + 1;
    reassocRanges.push_back(*sourceRange);
  }
  // Fail if the source shape wasn't a full match for the target shape. We only
  // need to check the last recorded index - any other gaps should have been
  // mended by the main loop.
  if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
    return failure();
  return reassocRanges;
}

/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
/// the shapes right-to-left.
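/// When scanning right-to-left, the ranges are computed over the reversed
/// shapes and then mapped back; e.g. (illustrative), with a rank-4 source, a
/// range [0, 1] found over the reversed shape corresponds to [2, 3] in the
/// original indexing.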
static FailureOr<SmallVector<ReassociationIndexRange>>
findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
                                   ArrayRef<int64_t> targetShape,
                                   bool iterateRightToLeft) {
  if (!iterateRightToLeft)
    return findReassociationRangesForCollapse(sourceShape, targetShape);
  // NB: To iterate right-to-left, we currently reverse the shapes and then
  // reverse the result back. The reversed shapes must not be temporaries, as
  // we're passing them through as ArrayRefs.
  // FIXME: It would be preferable to avoid the expensive copies. At the
  // moment, this approach is chosen for readability of the main
  // implementation.
  std::vector<int64_t> sourceToReverse = sourceShape.vec(),
                       targetToReverse = targetShape.vec();
  std::reverse(sourceToReverse.begin(), sourceToReverse.end());
  std::reverse(targetToReverse.begin(), targetToReverse.end());
  auto invertedRanges =
      findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
  if (failed(invertedRanges))
    return failure();
  SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
  unsigned numSourceDims = sourceShape.size();
  // We have received the ranges for inverted shapes. Now we have to invert
  // the ranges back to correspond with the original source shape.
  for (auto &range : rangesToInvert) {
    int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
    range.leftIdx = numSourceDims - 1 - invRightIdx;
    range.rightIdx = numSourceDims - 1 - invLeftIdx;
  }
  // Also invert the ordering of the ranges to correspond with the original
  // target shape.
  std::reverse(rangesToInvert.begin(), rangesToInvert.end());
  return rangesToInvert;
}

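// Example usage (illustrative):
//   getReassociationIndicesForCollapse({4, 5, 6}, {20, 6})
//   returns [[0, 1], [2]], whereas
//   getReassociationIndicesForCollapse({ShapedType::kDynamic, 2,
//                                       ShapedType::kDynamic},
//                                      {ShapedType::kDynamic,
//                                       ShapedType::kDynamic})
//   returns std::nullopt: the static `2` could belong to either dynamic
//   target dimension, and only unit dimensions may be placed ambiguously.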
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                         ArrayRef<int64_t> targetShape) {
  unsigned numSourceDims = sourceShape.size(),
           numTargetDims = targetShape.size();
  // We're supposed to search for a collapsing reassociation. If the ranks
  // match, there's no actual collapsing taking place - it's either a no-op or
  // a `tensor.reshape`-style reassociation (that would be beyond the scope of
  // this utility).
  if (numSourceDims <= numTargetDims)
    return std::nullopt;
  // Early handling for scalar target types. We should report an invalid
  // reassociation for non-unit static dimensions - no chance to collapse these
  // into a scalar.
  if (numTargetDims == 0) {
    for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
         ++sourceDimIdx) {
      int64_t sourceSize = sourceShape[sourceDimIdx];
      if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
        return std::nullopt;
    }
    return SmallVector<ReassociationIndices>{};
  }

  // Collect source ranges by iterating over the target shape left-to-right.
  FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
      findReassociationRangesForCollapse(sourceShape, targetShape);
  if (failed(maybeForwardRanges))
    return std::nullopt;
  auto &ranges = *maybeForwardRanges;
  // Now do the same in reverse. We need to get another valid reassociation
  // through some other strategy, and then compare the results in order to
  // disambiguate mixed subshapes, such as:
  // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
  // This leads us to lose some of the reassociation opportunities that can
  // only be found by iterating in a certain direction, e.g. 2x2x? into 2x? -
  // without backtracking, the algorithm will fail right-to-left. However, this
  // is the best way to preserve correctness.
  FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
      findReassociationRangesForCollapse(sourceShape, targetShape,
                                         /*iterateRightToLeft=*/true);
  if (failed(maybeReverseRanges))
    return std::nullopt;
  auto &reverseRanges = *maybeReverseRanges;

  if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
    return std::nullopt;
  // Now we can check for ambiguity of each target dimension's reassociation.
  // If successful, we put the full indices into our result map for the target
  // shape.
  SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
  for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
       ++targetDimIdx) {
    ReassociationIndexRange &range = ranges[targetDimIdx];
    ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
    // Get non-overlapping indices between the ranges.
    ReassociationIndices nonMatchingIndices =
        range.getNonOverlappingIndicesWith(reverseRange);
    // Unit dimensions can be collapsed anywhere - this is the only ambiguity
    // that we allow.
    for (int64_t sourceDimIdx : nonMatchingIndices) {
      if (sourceShape[sourceDimIdx] != 1)
        return std::nullopt;
    }
    reassociationMap[targetDimIdx] = range.getFullIndices();
  }
  return reassociationMap;
}

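// Example (illustrative): composing a producer reassociation
// [[0, 1], [2], [3, 4]] (rank 5 -> rank 3) with a consumer reassociation
// [[0, 1], [2]] (rank 3 -> rank 2) yields [[0, 1, 2], [3, 4]], i.e. the
// overall rank-5 -> rank-2 collapse.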
std::optional<SmallVector<ReassociationIndices>>
mlir::composeReassociationIndices(
    ArrayRef<ReassociationIndices> producerReassociations,
    ArrayRef<ReassociationIndices> consumerReassociations,
    MLIRContext *context) {
  SmallVector<ReassociationIndices> composedIndices;
  // Make the producer the larger sized vector. If they are of the same size,
  // the resulting reshape is not a supported reshape op.
  if (producerReassociations.size() == consumerReassociations.size())
    return std::nullopt;
  if (producerReassociations.size() < consumerReassociations.size())
    std::swap(producerReassociations, consumerReassociations);

  // Handle the corner case of the result being a rank 0 shaped type. Return an
  // empty reassociation.
  if (consumerReassociations.empty())
    return composedIndices;

  size_t consumerDims =
      llvm::accumulate(consumerReassociations, size_t(0),
                       [](size_t all, ReassociationIndicesRef indices) {
                         return all + indices.size();
                       });
  if (producerReassociations.size() != consumerDims)
    return std::nullopt;

  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
    ReassociationIndices reassociations;
    for (int64_t consumerIndex : consumerIndices) {
      llvm::append_range(reassociations, producerReassociations[consumerIndex]);
    }
    composedIndices.push_back(std::move(reassociations));
  }
  return composedIndices;
}

SmallVector<SmallVector<AffineExpr, 2>, 2>
mlir::convertReassociationIndicesToExprs(
    MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
  for (const auto &indices : reassociationIndices) {
    SmallVector<AffineExpr, 2> reassociationMap;
    reassociationMap.reserve(indices.size());
    for (int64_t index : indices)
      reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
    reassociationMaps.push_back(std::move(reassociationMap));
  }
  return reassociationMaps;
}

template <typename AffineExprTy>
static unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) {
  unsigned pos = 0;
  for (const auto &exprs : exprArrays) {
    for (auto expr : exprs) {
      expr.walk([&pos](AffineExpr e) {
        if (auto d = dyn_cast<AffineExprTy>(e))
          pos = std::max(pos, d.getPosition());
      });
    }
  }
  return pos;
}

ArrayAttr mlir::getReassociationIndicesAttribute(
    Builder &b, ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<Attribute, 4> reassociationAttr = llvm::map_to_vector<4>(
      reassociation, [&](const ReassociationIndices &indices) -> Attribute {
        return cast<Attribute>(b.getI64ArrayAttr(indices));
      });
  return b.getArrayAttr(reassociationAttr);
}

SmallVector<ReassociationIndices, 2> mlir::convertReassociationMapsToIndices(
    ArrayRef<ReassociationExprs> reassociationExprs) {
  SmallVector<ReassociationIndices, 2> reassociationIndices;
  for (const auto &exprs : reassociationExprs) {
    ReassociationIndices indices;
    indices.reserve(exprs.size());
    for (const auto &expr : exprs)
      indices.push_back(cast<AffineDimExpr>(expr).getPosition());
    reassociationIndices.push_back(indices);
  }
  return reassociationIndices;
}

SmallVector<AffineMap, 4>
mlir::getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) {
  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
         "Expected symbol-less expressions");
  SmallVector<AffineMap, 4> maps;
  maps.reserve(reassociation.size());
  for (const auto &exprs : reassociation) {
    assert(!exprs.empty());
    maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
  }
  return maps;
}

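// For illustration, over three dimensions the reassociation
// [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>] is
// valid, while [affine_map<(d0, d1, d2) -> (d0)>,
// affine_map<(d0, d1, d2) -> (d2)>] is not, since d1 is skipped.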
bool mlir::isReassociationValid(ArrayRef<AffineMap> reassociation,
                                int *invalidIndex) {
  if (reassociation.empty())
    return true;
  unsigned nDims = reassociation[0].getNumDims();
  unsigned nextExpectedDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    auto m = it.value();
    if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
      if (invalidIndex)
        *invalidIndex = it.index();
      return false;
    }
    for (auto e : m.getResults()) {
      auto d = dyn_cast<AffineDimExpr>(e);
      if (!d || d.getPosition() != nextExpectedDim++) {
        if (invalidIndex)
          *invalidIndex = it.index();
        return false;
      }
    }
  }
  if (nextExpectedDim != nDims) {
    if (invalidIndex)
      *invalidIndex = reassociation.size() - 1;
    return false;
  }
  return true;
}

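// For illustration, collapsedShape = {6, ?} is compatible with expandedShape =
// {2, 3, ?, 4} under the reassociation [[0, 1], [2, 3]]: 2 * 3 == 6, and the
// dynamic collapsed dimension corresponds to a group that contains a dynamic
// expanded dimension.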
LogicalResult mlir::reshapeLikeShapesAreCompatible(
    function_ref<LogicalResult(const Twine &)> emitError,
    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
  unsigned expandedDimStart = 0;
  for (const auto &map : llvm::enumerate(reassociationMaps)) {
    bool foundDynamicShape = false;
    int64_t linearizedStaticShape = 1;

    for (const auto &dim : llvm::enumerate(
             expandedShape.slice(expandedDimStart, map.value().size()))) {
      if (ShapedType::isDynamic(dim.value()))
        foundDynamicShape = true;
      else
        linearizedStaticShape *= dim.value();
    }
    if (foundDynamicShape) {
      if (ShapedType::isStatic(collapsedShape[map.index()])) {
        return emitError(
            "expected dimension " + Twine(map.index()) +
            " of collapsed type to be dynamic since one or more of the "
            "corresponding dimensions in the expanded type is dynamic");
      }
    } else {
      if (collapsedShape[map.index()] != linearizedStaticShape) {
        return emitError("expected dimension " + Twine(map.index()) +
                         " of collapsed type to be static value of " +
                         Twine(linearizedStaticShape));
      }
    }
    expandedDimStart += map.value().size();
  }
  return success();
}

bool mlir::hasNonIdentityLayout(Type type) {
  if (auto memrefType = dyn_cast<MemRefType>(type))
    return !memrefType.getLayout().isIdentity();
  return false;
}

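// For illustration, with sliceInputShape = {10, 20} and sliceParams =
// {(offset 0, size 10, stride 1), (offset 2, size 5, stride 1)}, only the
// second bit is set: the first dimension is taken in full with offset 0 and
// stride 1, while the second is genuinely sliced.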
llvm::SmallBitVector
mlir::getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
                          ArrayRef<Range> sliceParams) {
  assert(sliceParams.size() == sliceInputShape.size() &&
         "only supports non rank-reducing case");
  llvm::SmallBitVector mask(sliceInputShape.size());
  unsigned idx = 0;
  for (const auto &[offset, size, stride] : sliceParams) {
    std::optional<int64_t> offsetConst = getConstantIntValue(offset);
    std::optional<int64_t> strideConst = getConstantIntValue(stride);
    mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) ||
                (!strideConst || *strideConst != 1) ||
                (!offsetConst || *offsetConst != 0);
    idx++;
  }
  return mask;
}

llvm::SmallBitVector mlir::getLinearizedDimensions(
    ArrayRef<ReassociationIndices> reassociationIndices) {
  llvm::SmallBitVector result(reassociationIndices.size());
  for (const auto &it : llvm::enumerate(reassociationIndices))
    result[it.index()] = it.value().size() > 1;
  return result;
}

SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
    MLIRContext *ctx, ArrayRef<ValueRange> multiIndices) {
  unsigned loopIdx = 0;
  auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
  auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
  SmallVector<Range> offsetsSizesAndStrides;
  offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
  for (const auto &it : llvm::enumerate(reassociationIndices)) {
    // Case 1: Linearized dimensions that have also been sliced. These
    // are size of 1 because we are iterating over these dimensions. The
    // offsets are exactly the de-linearized multi-indices.
    if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
      llvm::append_range(
          offsetsSizesAndStrides,
          llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range {
            return Range{getAsOpFoldResult(v), oneAttr, oneAttr};
          }));
      continue;
    }

    // Case 2: One or possibly multiple combined input dimensions, but we
    // have proven that these are not sliced. In this case we just take
    // the full extent of each dimension in the reassociation list.
    if (linearizedDimensions[it.index()]) {
      llvm::append_range(offsetsSizesAndStrides,
                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
                           return {zeroAttr, collapseShapeInputShape[idx],
                                   oneAttr};
                         }));
      continue;
    }

    // Case 3: A single index, but it may be sliced.
    offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
  }
  return offsetsSizesAndStrides;
}

SmallVector<Range>
SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx,
                                              ValueRange tileIndices) {
  auto one = IntegerAttr::get(IndexType::get(ctx), 1);
  auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
  SmallVector<Range> insertParams;
  insertParams.reserve(linearizedDimensions.size());
  unsigned loopIdx = 0;
  for (unsigned i = 0; i < linearizedDimensions.size(); i++) {
    if (linearizedDimensions[i] && slicedDimensions[i]) {
      insertParams.push_back(Range{tileIndices[loopIdx++], one, one});
      continue;
    }
    insertParams.push_back(Range{zero, sliceParams[i].size, one});
  }
  return insertParams;
}

/// Returns the index of the only non-unit dimension among `indices` of
/// `shape`, if such a dimension exists and `indices` has more than one
/// element. Otherwise, returns std::nullopt.
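/// For example (illustrative), getUniqueNonUnitDim({0, 1, 2}, {1, 4, 1})
/// returns 1, whereas getUniqueNonUnitDim({0, 1, 2}, {2, 4, 1}) returns
/// std::nullopt because two of the referenced dimensions are non-unit.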
static std::optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices,
                                                  ArrayRef<int64_t> shape) {
  // Return std::nullopt if more than one of the dimensions in this group is
  // not 1.
  std::optional<int64_t> dimIndex;
  if (indices.size() < 2)
    return std::nullopt;
  for (int64_t idx : indices) {
    if (shape[idx] != 1) {
      if (dimIndex != std::nullopt)
        return std::nullopt;
      dimIndex = idx;
    }
  }
  return dimIndex;
}

// For each segment in the reassociation indices, check whether we can
// simplify that segment with a rank-reducing extract slice. We can do this if
// all but (exactly) one of the corresponding source dims is 1.
static SmallVector<std::optional<int64_t>> getCollapseShapeTrivialSegments(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<std::optional<int64_t>> trivialSegments;
  for (const auto &indices : reassociationIndices)
    trivialSegments.push_back(
        getUniqueNonUnitDim(indices, sourceType.getShape()));
  return trivialSegments;
}

/// Returns the trivial segments of the reassociation indices for a collapsing
/// reshape if any of them can be simplified using a rank-reducing slice;
/// otherwise returns failure.
static FailureOr<SmallVector<std::optional<int64_t>>>
canCollapseShapeBeSimplifiedByRankReducingSlice(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  SmallVector<std::optional<int64_t>> trivialSegments =
      getCollapseShapeTrivialSegments(sourceType, reassociationIndices);
  if (!llvm::any_of(trivialSegments, [](const std::optional<int64_t> &idx) {
        return idx.has_value();
      }))
    return failure();
  return trivialSegments;
}

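// Worked example (illustrative): for a source tensor<1x5x3x4xf32> collapsed
// with reassociation [[0, 1], [2, 3]], the first segment has the unique
// non-unit dimension 1 while the second has two non-unit dimensions, so only
// the first segment is trivial. The rank-reducing slice then produces
// tensor<5x3x4xf32>, and the remaining collapse uses the new reassociation
// [[0], [1, 2]].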
FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
    RankedTensorType sourceType,
    ArrayRef<ReassociationIndices> reassociationIndices) {
  FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
      canCollapseShapeBeSimplifiedByRankReducingSlice(sourceType,
                                                      reassociationIndices);
  if (failed(trivialSegments))
    return failure();

  // Create the expected result shape of the rank-reducing slice.
  SmallVector<int64_t> sliceShape;
  for (const auto &[nonUnitDim, indices] :
       llvm::zip(*trivialSegments, reassociationIndices)) {
    if (nonUnitDim) {
      sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
      continue;
    }
    llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
                         return sourceType.getDimSize(idx);
                       }));
  }
  auto sliceType =
      RankedTensorType::get(sliceShape, sourceType.getElementType());

  // If the rank-reducing slice simplified every segment, then we are done.
  if (sliceShape.size() == reassociationIndices.size())
    return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
                                                            std::nullopt};

  // Otherwise, we need to create a new collapse_shape op for the segments that
  // weren't covered by the slice. By design, the new reassociation indices
  // have the same number of groups as the old reassociation indices.
  SmallVector<ReassociationIndices> newReassociationIndices;
  SmallVector<int64_t, 2> reassociation;
  int64_t groupIdx = 0;
  for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
    reassociation.push_back(dimIdx);
    if ((*trivialSegments)[groupIdx] ||
        reassociation.size() == reassociationIndices[groupIdx].size()) {
      newReassociationIndices.push_back(reassociation);
      reassociation.clear();
      groupIdx++;
    }
  }

  return CollapseShapeRankReducingSliceSimplificationInfo{
      sliceType, newReassociationIndices};
}

PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
                                             ArrayRef<int64_t> innerDimPos) {
  PackingMetadata res;
  res.insertPositions.reserve(innerDimPos.size());
  // The pack insert position is the position + the number of previously
  // inserted positions + offset.
  // The offset controls whether the packing dimension is the first or last.
  //
  // Example
  // =======
  // Consider packing from a hypothetical ABCD layout to ABCDba whose
  // pack.inner_dims is [1, 0]. The first step consists in undoing the
  // permutation and producing AaBbCD. This is achieved purely by computing the
  // insert positions of `b` and `a` into `ABCD`, starting from [1, 0]. One
  // possibility is to produce insert positions [2, 0]; this would result in an
  // aAbBCD layout (i.e. offset 0). The other possibility is to produce insert
  // positions [3, 1]; this would result in an AaBbCD layout (i.e. offset 1).
  // The latter is what we expect from packing.
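  //
  // Continuing the example (illustrative): with packedRank = 6 and
  // innerDimPos = [1, 0], the loop below computes insertPositions = [3, 1],
  // which yields outerPositions = [0, 2, 4, 5] and reassociations
  // [[0, 1], [2, 3], [4], [5]], i.e. AaBbCD collapses back to ABCD.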
  int64_t offset = 1;
  for (int64_t pos : innerDimPos) {
    int64_t numInsertedBefore = llvm::count_if(
        innerDimPos, [&pos](int64_t pos2) { return pos > pos2; });
    res.insertPositions.push_back(pos + numInsertedBefore + offset);
  }

  DenseSet<int64_t> posSet(res.insertPositions.begin(),
                           res.insertPositions.end());
  res.reassociations.reserve(packedRank);
  for (int64_t i = 1; i <= packedRank; ++i) {
    res.outerPositions.push_back(i - 1);
    if (!posSet.contains(i)) {
      res.reassociations.push_back(ReassociationIndices{i - 1});
      continue;
    }
    res.reassociations.push_back(ReassociationIndices{i - 1, i});
    ++i;
  }
  return res;
}

OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
                                         TensorType result,
                                         std::optional<Attribute> cst) {
  if (source && source.isSplat() && result.hasStaticShape() &&
      (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
    return source.resizeSplat(result);

  return {};
}