14#include "llvm/ADT/ArrayRef.h"
15#include "llvm/ADT/SmallVector.h"
16#include "llvm/ADT/SmallVectorExtras.h"
23std::optional<SmallVector<ReassociationIndices>>
25 ShapedType targetType) {
26 if (sourceType.getRank() > targetType.getRank())
28 targetType.getShape());
29 if (sourceType.getRank() < targetType.getRank())
31 sourceType.getShape());
39struct ReassociationIndexRange {
43 int64_t leftIdx = 0, rightIdx = 0;
46 LogicalResult
verify()
const {
47 return leftIdx >= 0 && (leftIdx <= rightIdx) ?
success() : failure();
52 bool isInRange(
const ReassociationIndexRange &outerRange)
const {
53 return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
56 unsigned size()
const {
57 assert(succeeded(
verify()));
58 return rightIdx - leftIdx + 1;
60 bool containsSingleIndex()
const {
return size() == 1; }
64 getNonOverlappingIndicesWith(ReassociationIndexRange &
rhs)
const {
65 if (rightIdx <
rhs.leftIdx) {
67 auto jointFullIndices = getFullIndices();
68 jointFullIndices.append(
rhs.getFullIndices());
69 return jointFullIndices;
73 int64_t leftStart = std::min(leftIdx,
rhs.leftIdx);
74 int64_t leftEnd = std::max(leftIdx,
rhs.leftIdx);
75 llvm::append_range(
result, llvm::seq(leftStart, leftEnd));
78 int64_t rightStart = std::min(rightIdx,
rhs.rightIdx) + 1;
79 int64_t rightEnd = std::max(rightIdx,
rhs.rightIdx);
80 if (rightStart < rightEnd)
81 llvm::append_range(
result, llvm::seq_inclusive(rightStart, rightEnd));
88 for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
102static FailureOr<ReassociationIndexRange>
105 bool matchGreedily =
false) {
106 const unsigned numSourceDims = sourceShape.size();
107 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
108 std::optional<ReassociationIndexRange> resultRange = std::nullopt;
110 ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
111 for (; iterationRange.isInRange(sourceShapeAsRange);
112 iterationRange.rightIdx++) {
113 int64_t sourceSize = sourceShape[iterationRange.rightIdx];
114 if (sourceSize == ShapedType::kDynamic) {
115 resultRange = iterationRange;
122 resultRange->rightIdx = sourceShapeAsRange.rightIdx;
131static FailureOr<ReassociationIndexRange>
134 bool matchGreedily =
false) {
135 const unsigned numSourceDims = sourceShape.size();
136 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
137 std::optional<ReassociationIndexRange> resultRange = std::nullopt;
139 ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
140 int64_t prodOfCollapsedDims = 1;
141 while (iterationRange.isInRange(sourceShapeAsRange)) {
142 int64_t sourceSize = sourceShape[iterationRange.rightIdx];
143 if (sourceSize == ShapedType::kDynamic) {
147 prodOfCollapsedDims = 1;
148 iterationRange = {iterationRange.rightIdx + 1,
149 iterationRange.rightIdx + 1};
152 prodOfCollapsedDims *= sourceSize;
156 while (prodOfCollapsedDims > targetSize &&
157 !iterationRange.containsSingleIndex()) {
158 int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
159 prodOfCollapsedDims /= frontSourceSize;
161 iterationRange.leftIdx++;
165 if (prodOfCollapsedDims == targetSize) {
166 resultRange = iterationRange;
170 iterationRange.rightIdx++;
178 iterationRange.rightIdx++;
179 while (iterationRange.isInRange(sourceShapeAsRange) &&
180 sourceShape[iterationRange.rightIdx] == 1) {
181 resultRange = iterationRange;
182 iterationRange.rightIdx++;
201static FailureOr<SmallVector<ReassociationIndexRange>>
204 unsigned numSourceDims = sourceShape.size(),
205 numTargetDims = targetShape.size();
206 assert(numSourceDims > numTargetDims);
207 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
210 reassocRanges.reserve(numTargetDims);
214 std::optional<int64_t> prevTargetSize = std::nullopt;
215 for (
unsigned targetDimIdx = 0, sourceDimIdx = 0;
216 targetDimIdx < numTargetDims; ++targetDimIdx) {
217 int64_t targetSize = targetShape[targetDimIdx];
220 bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
221 FailureOr<ReassociationIndexRange> sourceRange;
222 if (targetSize == ShapedType::kDynamic) {
224 sourceShape, sourceDimIdx, shouldMatchGreedily);
227 sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
231 if (failed(sourceRange) || failed(sourceRange->verify()) ||
232 !sourceRange->isInRange(sourceShapeAsRange))
234 if (sourceRange->leftIdx > sourceDimIdx) {
237 if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
239 reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
243 prevTargetSize = targetSize;
244 sourceDimIdx = sourceRange->rightIdx + 1;
245 reassocRanges.push_back(*sourceRange);
250 if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
252 return reassocRanges;
257static FailureOr<SmallVector<ReassociationIndexRange>>
260 bool iterateRightToLeft) {
261 if (!iterateRightToLeft)
268 std::vector<int64_t> sourceToReverse = sourceShape.vec(),
269 targetToReverse = targetShape.vec();
270 std::reverse(sourceToReverse.begin(), sourceToReverse.end());
271 std::reverse(targetToReverse.begin(), targetToReverse.end());
272 auto invertedRanges =
274 if (failed(invertedRanges))
277 unsigned numSourceDims = sourceShape.size();
280 for (
auto &range : rangesToInvert) {
281 int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
282 range.leftIdx = numSourceDims - 1 - invRightIdx;
283 range.rightIdx = numSourceDims - 1 - invLeftIdx;
287 std::reverse(rangesToInvert.begin(), rangesToInvert.end());
288 return rangesToInvert;
291std::optional<SmallVector<ReassociationIndices>>
294 unsigned numSourceDims = sourceShape.size(),
295 numTargetDims = targetShape.size();
300 if (numSourceDims <= numTargetDims)
305 if (numTargetDims == 0) {
306 for (
unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
308 int64_t sourceSize = sourceShape[sourceDimIdx];
309 if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
316 FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
318 if (failed(maybeForwardRanges))
320 auto &ranges = *maybeForwardRanges;
329 FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
332 if (failed(maybeReverseRanges))
334 auto &reverseRanges = *maybeReverseRanges;
336 if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
342 for (
unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
344 ReassociationIndexRange &range = ranges[targetDimIdx];
345 ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
348 range.getNonOverlappingIndicesWith(reverseRange);
351 for (
int64_t sourceDimIdx : nonMatchingIndices) {
352 if (sourceShape[sourceDimIdx] != 1)
355 reassociationMap[targetDimIdx] = range.getFullIndices();
357 return reassociationMap;
360std::optional<SmallVector<ReassociationIndices>>
368 if (producerReassociations.size() == consumerReassociations.size())
370 if (producerReassociations.size() < consumerReassociations.size())
371 std::swap(producerReassociations, consumerReassociations);
375 if (consumerReassociations.empty())
376 return composedIndices;
378 size_t consumerDims =
379 llvm::accumulate(consumerReassociations,
size_t(0),
383 if (producerReassociations.size() != consumerDims)
388 for (
int64_t consumerIndex : consumerIndices) {
389 llvm::append_range(reassociations, producerReassociations[consumerIndex]);
391 composedIndices.push_back(std::move(reassociations));
393 return composedIndices;
400 for (
const auto &
indices : reassociationIndices) {
402 reassociationMap.reserve(
indices.size());
405 reassociationMaps.push_back(std::move(reassociationMap));
407 return reassociationMaps;
410template <
typename AffineExprTy>
413 for (
const auto &exprs : exprArrays) {
414 for (
auto expr : exprs) {
416 if (
auto d = dyn_cast<AffineExprTy>(e))
417 pos = std::max(pos, d.getPosition());
428 return cast<Attribute>(
b.getI64ArrayAttr(
indices));
430 return b.getArrayAttr(reassociationAttr);
436 for (
const auto &exprs : reassociationExprs) {
439 for (
const auto &expr : exprs)
440 indices.push_back(cast<AffineDimExpr>(expr).getPosition());
441 reassociationIndices.push_back(
indices);
443 return reassociationIndices;
450 "Expected symbol-less expressions");
452 maps.reserve(reassociation.size());
453 for (
const auto &exprs : reassociation) {
454 assert(!exprs.empty());
462 if (reassociation.empty())
464 unsigned nDims = reassociation[0].getNumDims();
465 unsigned nextExpectedDim = 0;
466 for (
const auto &it : llvm::enumerate(reassociation)) {
468 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
470 *invalidIndex = it.index();
473 for (
auto e : m.getResults()) {
474 auto d = dyn_cast<AffineDimExpr>(e);
475 if (!d || d.getPosition() != nextExpectedDim++) {
477 *invalidIndex = it.index();
482 if (nextExpectedDim != nDims) {
484 *invalidIndex = reassociation.size() - 1;
490LogicalResult mlir::reshapeLikeShapesAreCompatible(
494 unsigned expandedDimStart = 0;
495 for (
const auto &map : llvm::enumerate(reassociationMaps)) {
496 bool foundDynamicShape =
false;
497 int64_t linearizedStaticShape = 1;
499 for (
const auto &dim : llvm::enumerate(
500 expandedShape.slice(expandedDimStart, map.value().size()))) {
501 if (ShapedType::isDynamic(dim.value()))
502 foundDynamicShape =
true;
504 linearizedStaticShape *= dim.value();
506 if (foundDynamicShape) {
507 if (ShapedType::isStatic(collapsedShape[map.index()])) {
509 "expected dimension " + Twine(map.index()) +
510 " of collapsed type to be dynamic since one or more of the "
511 "corresponding dimensions in the expanded type is dynamic");
514 if (collapsedShape[map.index()] != linearizedStaticShape) {
515 return emitError(
"expected dimension " + Twine(map.index()) +
516 " of collapsed type to be static value of " +
517 Twine(linearizedStaticShape));
520 expandedDimStart += map.value().size();
525bool mlir::hasNonIdentityLayout(
Type type) {
526 if (
auto memrefType = dyn_cast<MemRefType>(type))
527 return !memrefType.getLayout().isIdentity();
534 assert(sliceParams.size() == sliceInputShape.size() &&
535 "only supports non rank-reducing case");
536 llvm::SmallBitVector mask(sliceInputShape.size());
538 for (
const auto &[offset, size, stride] : sliceParams) {
542 (!strideConst || *strideConst != 1) ||
543 (!offsetConst || *offsetConst != 0);
551 llvm::SmallBitVector
result(reassociationIndices.size());
552 for (
const auto &it : llvm::enumerate(reassociationIndices))
553 result[it.index()] = it.value().size() > 1;
559 unsigned loopIdx = 0;
560 auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1);
561 auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0);
563 offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
564 for (
const auto &it : llvm::enumerate(reassociationIndices)) {
568 if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
570 offsetsSizesAndStrides,
571 llvm::map_range(multiIndices[loopIdx++], [&](
Value v) ->
Range {
580 if (linearizedDimensions[it.index()]) {
581 llvm::append_range(offsetsSizesAndStrides,
583 return {zeroAttr, collapseShapeInputShape[idx],
590 offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
592 return offsetsSizesAndStrides;
596SliceFromCollapseHelper::getInsertSliceParams(
MLIRContext *ctx,
598 auto one = IntegerAttr::get(IndexType::get(ctx), 1);
599 auto zero = IntegerAttr::get(IndexType::get(ctx), 0);
601 insertParams.reserve(linearizedDimensions.size());
602 unsigned loopIdx = 0;
603 for (
unsigned i = 0; i < linearizedDimensions.size(); i++) {
604 if (linearizedDimensions[i] && slicedDimensions[i]) {
605 insertParams.push_back(
Range{tileIndices[loopIdx++], one, one});
608 insertParams.push_back(
Range{zero, sliceParams[i].size, one});
619 std::optional<int64_t> dimIndex;
623 if (
shape[idx] != 1) {
624 if (dimIndex != std::nullopt)
636 RankedTensorType sourceType,
639 for (
const auto &
indices : reassociationIndices)
640 trivialSegments.push_back(
642 return trivialSegments;
647static FailureOr<SmallVector<std::optional<int64_t>>>
649 RankedTensorType sourceType,
653 if (!llvm::any_of(trivialSegments, [](
const std::optional<int64_t> &idx) {
654 return idx.has_value();
657 return trivialSegments;
660FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
661mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
662 RankedTensorType sourceType,
664 FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
666 reassociationIndices);
667 if (
failed(trivialSegments))
672 for (
const auto &[nonUnitDim,
indices] :
673 llvm::zip(*trivialSegments, reassociationIndices)) {
675 sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
678 llvm::append_range(sliceShape, llvm::map_range(
indices, [&](
int64_t idx) {
679 return sourceType.getDimSize(idx);
683 RankedTensorType::get(sliceShape, sourceType.getElementType());
686 if (sliceShape.size() == reassociationIndices.size())
687 return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
696 for (
int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
697 reassociation.push_back(dimIdx);
698 if ((*trivialSegments)[groupIdx] ||
699 reassociation.size() == reassociationIndices[groupIdx].size()) {
700 newReassociationIndices.push_back(reassociation);
701 reassociation.clear();
706 return CollapseShapeRankReducingSliceSimplificationInfo{
707 sliceType, newReassociationIndices};
710PackingMetadata mlir::computePackingMetadata(
int64_t packedRank,
713 res.insertPositions.reserve(innerDimPos.size());
729 for (
int64_t pos : innerDimPos) {
730 int64_t numInsertedBefore = llvm::count_if(
731 innerDimPos, [&pos](
int64_t pos2) {
return pos > pos2; });
732 res.insertPositions.push_back(pos + numInsertedBefore + offset);
736 res.insertPositions.end());
737 res.reassociations.reserve(packedRank);
738 for (
int64_t i = 1; i <= packedRank; ++i) {
739 res.outerPositions.push_back(i - 1);
740 if (!posSet.contains(i)) {
752 std::optional<Attribute> cst) {
static FailureOr< ReassociationIndexRange > findReassociationRangeForDynamicDim(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, bool matchGreedily=false)
Starting from sourceStartIdx, searches sourceShape for the first sequence that can be collapsed into ...
static SmallVector< std::optional< int64_t > > getCollapseShapeTrivialSegments(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
static std::optional< int64_t > getUniqueNonUnitDim(ArrayRef< int64_t > indices, ArrayRef< int64_t > shape)
Returns the index of the only non-unit dimension among indices of shape, if such a dimension exists a...
static unsigned getMaxPosOfType(ArrayRef< ReassociationExprs > exprArrays)
static FailureOr< SmallVector< std::optional< int64_t > > > canCollapseShapeBeSimplifiedByRankReducingSlice(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
Returns true if any of the segments of the reassociation indices for a collapsing reshape can be simp...
static FailureOr< ReassociationIndexRange > findReassociationRangeForSize(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, int64_t targetSize, bool matchGreedily=false)
Starting from sourceStartIdx, searches sourceShape for the first sequence of static dimensions such t...
static FailureOr< SmallVector< ReassociationIndexRange > > findReassociationRangesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Attempts to find a valid collapsing reassociation of sourceShape into targetShape through a simple tr...
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
An attribute that represents a reference to a dense vector or tensor object.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
ArrayRef< int64_t > ReassociationIndicesRef
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
SmallVector< int64_t, 2 > ReassociationIndices
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...