14 #include "llvm/ADT/ArrayRef.h"
15 #include "llvm/ADT/SmallVector.h"
22 std::optional<SmallVector<ReassociationIndices>>
24 ShapedType targetType) {
25 if (sourceType.getRank() > targetType.getRank())
27 targetType.getShape());
28 if (sourceType.getRank() < targetType.getRank())
30 sourceType.getShape());
38 struct ReassociationIndexRange {
42 int64_t leftIdx = 0, rightIdx = 0;
45 LogicalResult
verify()
const {
46 return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
51 bool isInRange(
const ReassociationIndexRange &outerRange)
const {
52 return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
55 unsigned size()
const {
56 assert(succeeded(
verify()));
57 return rightIdx - leftIdx + 1;
59 bool containsSingleIndex()
const {
return size() == 1; }
63 getNonOverlappingIndicesWith(ReassociationIndexRange &rhs)
const {
64 if (rightIdx < rhs.leftIdx) {
66 auto jointFullIndices = getFullIndices();
67 jointFullIndices.append(rhs.getFullIndices());
68 return jointFullIndices;
72 int64_t leftStart =
std::min(leftIdx, rhs.leftIdx);
73 int64_t leftEnd =
std::max(leftIdx, rhs.leftIdx);
74 llvm::append_range(result, llvm::seq(leftStart, leftEnd));
77 int64_t rightStart =
std::min(rightIdx, rhs.rightIdx) + 1;
78 int64_t rightEnd =
std::max(rightIdx, rhs.rightIdx);
79 if (rightStart < rightEnd)
80 llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
87 for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
88 result.push_back(idx);
101 static FailureOr<ReassociationIndexRange>
103 int64_t sourceStartIdx,
104 bool matchGreedily =
false) {
105 const unsigned numSourceDims = sourceShape.size();
106 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
107 std::optional<ReassociationIndexRange> resultRange = std::nullopt;
109 ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
110 for (; iterationRange.isInRange(sourceShapeAsRange);
111 iterationRange.rightIdx++) {
112 int64_t sourceSize = sourceShape[iterationRange.rightIdx];
113 if (sourceSize == ShapedType::kDynamic) {
114 resultRange = iterationRange;
121 resultRange->rightIdx = sourceShapeAsRange.rightIdx;
130 static FailureOr<ReassociationIndexRange>
132 int64_t sourceStartIdx, int64_t targetSize,
133 bool matchGreedily =
false) {
134 const unsigned numSourceDims = sourceShape.size();
135 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
136 std::optional<ReassociationIndexRange> resultRange = std::nullopt;
138 ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
139 int64_t prodOfCollapsedDims = 1;
140 while (iterationRange.isInRange(sourceShapeAsRange)) {
141 int64_t sourceSize = sourceShape[iterationRange.rightIdx];
142 if (sourceSize == ShapedType::kDynamic) {
146 prodOfCollapsedDims = 1;
147 iterationRange = {iterationRange.rightIdx + 1,
148 iterationRange.rightIdx + 1};
151 prodOfCollapsedDims *= sourceSize;
155 while (prodOfCollapsedDims > targetSize &&
156 !iterationRange.containsSingleIndex()) {
157 int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
158 prodOfCollapsedDims /= frontSourceSize;
160 iterationRange.leftIdx++;
164 if (prodOfCollapsedDims == targetSize) {
165 resultRange = iterationRange;
169 iterationRange.rightIdx++;
177 iterationRange.rightIdx++;
178 while (iterationRange.isInRange(sourceShapeAsRange) &&
179 sourceShape[iterationRange.rightIdx] == 1) {
180 resultRange = iterationRange;
181 iterationRange.rightIdx++;
200 static FailureOr<SmallVector<ReassociationIndexRange>>
203 unsigned numSourceDims = sourceShape.size(),
204 numTargetDims = targetShape.size();
205 assert(numSourceDims > numTargetDims);
206 ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
209 reassocRanges.reserve(numTargetDims);
213 std::optional<int64_t> prevTargetSize = std::nullopt;
214 for (
unsigned targetDimIdx = 0, sourceDimIdx = 0;
215 targetDimIdx < numTargetDims; ++targetDimIdx) {
216 int64_t targetSize = targetShape[targetDimIdx];
219 bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
220 FailureOr<ReassociationIndexRange> sourceRange;
221 if (targetSize == ShapedType::kDynamic) {
223 sourceShape, sourceDimIdx, shouldMatchGreedily);
226 sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
230 if (failed(sourceRange) || failed(sourceRange->verify()) ||
231 !sourceRange->isInRange(sourceShapeAsRange))
233 if (sourceRange->leftIdx > sourceDimIdx) {
236 if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
238 reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
242 prevTargetSize = targetSize;
243 sourceDimIdx = sourceRange->rightIdx + 1;
244 reassocRanges.push_back(*sourceRange);
249 if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
251 return reassocRanges;
256 static FailureOr<SmallVector<ReassociationIndexRange>>
259 bool iterateRightToLeft) {
260 if (!iterateRightToLeft)
267 std::vector<int64_t> sourceToReverse = sourceShape.vec(),
268 targetToReverse = targetShape.vec();
269 std::reverse(sourceToReverse.begin(), sourceToReverse.end());
270 std::reverse(targetToReverse.begin(), targetToReverse.end());
271 auto invertedRanges =
273 if (failed(invertedRanges))
276 unsigned numSourceDims = sourceShape.size();
279 for (
auto &range : rangesToInvert) {
280 int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
281 range.leftIdx = numSourceDims - 1 - invRightIdx;
282 range.rightIdx = numSourceDims - 1 - invLeftIdx;
286 std::reverse(rangesToInvert.begin(), rangesToInvert.end());
287 return rangesToInvert;
290 std::optional<SmallVector<ReassociationIndices>>
293 unsigned numSourceDims = sourceShape.size(),
294 numTargetDims = targetShape.size();
299 if (numSourceDims <= numTargetDims)
304 if (numTargetDims == 0) {
305 for (
unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
307 int64_t sourceSize = sourceShape[sourceDimIdx];
308 if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
315 FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
317 if (failed(maybeForwardRanges))
319 auto &ranges = *maybeForwardRanges;
328 FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
331 if (failed(maybeReverseRanges))
333 auto &reverseRanges = *maybeReverseRanges;
335 if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
341 for (
unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
343 ReassociationIndexRange &range = ranges[targetDimIdx];
344 ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
347 range.getNonOverlappingIndicesWith(reverseRange);
350 for (int64_t sourceDimIdx : nonMatchingIndices) {
351 if (sourceShape[sourceDimIdx] != 1)
354 reassociationMap[targetDimIdx] = range.getFullIndices();
356 return reassociationMap;
359 std::optional<SmallVector<ReassociationIndices>>
367 if (producerReassociations.size() == consumerReassociations.size())
369 if (producerReassociations.size() < consumerReassociations.size())
370 std::swap(producerReassociations, consumerReassociations);
374 if (consumerReassociations.empty())
375 return composedIndices;
377 size_t consumerDims = std::accumulate(
378 consumerReassociations.begin(), consumerReassociations.end(), 0,
380 return all + indices.size();
382 if (producerReassociations.size() != consumerDims)
387 for (int64_t consumerIndex : consumerIndices) {
388 llvm::append_range(reassociations, producerReassociations[consumerIndex]);
390 composedIndices.push_back(std::move(reassociations));
392 return composedIndices;
399 for (
const auto &indices : reassociationIndices) {
401 reassociationMap.reserve(indices.size());
402 for (int64_t index : indices)
404 reassociationMaps.push_back(std::move(reassociationMap));
406 return reassociationMaps;
409 template <
typename AffineExprTy>
412 for (
const auto &exprs : exprArrays) {
413 for (
auto expr : exprs) {
415 if (
auto d = dyn_cast<AffineExprTy>(e))
416 pos =
std::max(pos, d.getPosition());
426 llvm::to_vector<4>(llvm::map_range(
436 for (
const auto &exprs : reassociationExprs) {
438 indices.reserve(exprs.size());
439 for (
const auto &expr : exprs)
440 indices.push_back(cast<AffineDimExpr>(expr).getPosition());
441 reassociationIndices.push_back(indices);
443 return reassociationIndices;
448 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
449 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
450 "Expected symbol-less expressions");
452 maps.reserve(reassociation.size());
453 for (
const auto &exprs : reassociation) {
454 assert(!exprs.empty());
462 if (reassociation.empty())
464 unsigned nDims = reassociation[0].getNumDims();
465 unsigned nextExpectedDim = 0;
468 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
470 *invalidIndex = it.index();
473 for (
auto e : m.getResults()) {
474 auto d = dyn_cast<AffineDimExpr>(e);
475 if (!d || d.getPosition() != nextExpectedDim++) {
477 *invalidIndex = it.index();
482 if (nextExpectedDim != nDims) {
484 *invalidIndex = reassociation.size() - 1;
494 unsigned expandedDimStart = 0;
496 bool foundDynamicShape =
false;
497 int64_t linearizedStaticShape = 1;
500 expandedShape.slice(expandedDimStart, map.value().size()))) {
501 if (ShapedType::isDynamic(dim.value()))
502 foundDynamicShape =
true;
504 linearizedStaticShape *= dim.value();
506 if (foundDynamicShape) {
507 if (ShapedType::isStatic(collapsedShape[map.index()])) {
509 "expected dimension " + Twine(map.index()) +
510 " of collapsed type to be dynamic since one or more of the "
511 "corresponding dimensions in the expanded type is dynamic");
514 if (collapsedShape[map.index()] != linearizedStaticShape) {
515 return emitError(
"expected dimension " + Twine(map.index()) +
516 " of collapsed type to be static value of " +
517 Twine(linearizedStaticShape));
520 expandedDimStart += map.value().size();
526 if (
auto memrefType = dyn_cast<MemRefType>(type))
527 return !memrefType.getLayout().isIdentity();
534 assert(sliceParams.size() == sliceInputShape.size() &&
535 "only supports non rank-reducing case");
536 llvm::SmallBitVector mask(sliceInputShape.size());
538 for (
const auto &[offset, size, stride] : sliceParams) {
542 (!strideConst || *strideConst != 1) ||
543 (!offsetConst || *offsetConst != 0);
551 llvm::SmallBitVector result(reassociationIndices.size());
553 result[it.index()] = it.value().size() > 1;
559 unsigned loopIdx = 0;
563 offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
568 if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
570 offsetsSizesAndStrides,
571 llvm::map_range(multiIndices[loopIdx++], [&](
Value v) ->
Range {
580 if (linearizedDimensions[it.index()]) {
581 llvm::append_range(offsetsSizesAndStrides,
582 llvm::map_range(it.value(), [&](int64_t idx) ->
Range {
583 return {zeroAttr, collapseShapeInputShape[idx],
590 offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
592 return offsetsSizesAndStrides;
596 SliceFromCollapseHelper::getInsertSliceParams(
MLIRContext *ctx,
601 insertParams.reserve(linearizedDimensions.size());
602 unsigned loopIdx = 0;
603 for (
unsigned i = 0; i < linearizedDimensions.size(); i++) {
604 if (linearizedDimensions[i] && slicedDimensions[i]) {
605 insertParams.push_back(
Range{tileIndices[loopIdx++], one, one});
608 insertParams.push_back(
Range{zero, sliceParams[i].
size, one});
619 std::optional<int64_t> dimIndex;
620 if (indices.size() < 2)
622 for (int64_t idx : indices) {
623 if (shape[idx] != 1) {
624 if (dimIndex != std::nullopt)
636 RankedTensorType sourceType,
639 for (
const auto &indices : reassociationIndices)
640 trivialSegments.push_back(
642 return trivialSegments;
647 static FailureOr<SmallVector<std::optional<int64_t>>>
649 RankedTensorType sourceType,
653 if (!llvm::any_of(trivialSegments, [](
const std::optional<int64_t> &idx) {
654 return idx.has_value();
657 return trivialSegments;
660 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
661 mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
662 RankedTensorType sourceType,
664 FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
666 reassociationIndices);
667 if (failed(trivialSegments))
672 for (
const auto &[nonUnitDim, indices] :
673 llvm::zip(*trivialSegments, reassociationIndices)) {
675 sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
678 llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
679 return sourceType.getDimSize(idx);
686 if (sliceShape.size() == reassociationIndices.size())
687 return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
695 int64_t groupIdx = 0;
696 for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
697 reassociation.push_back(dimIdx);
698 if ((*trivialSegments)[groupIdx] ||
699 reassociation.size() == reassociationIndices[groupIdx].size()) {
700 newReassociationIndices.push_back(reassociation);
701 reassociation.clear();
706 return CollapseShapeRankReducingSliceSimplificationInfo{
707 sliceType, newReassociationIndices};
710 PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
713 res.insertPositions.reserve(innerDimPos.size());
729 for (int64_t pos : innerDimPos) {
730 int64_t numInsertedBefore = llvm::count_if(
731 innerDimPos, [&pos](int64_t pos2) {
return pos > pos2; });
732 res.insertPositions.push_back(pos + numInsertedBefore + offset);
736 res.insertPositions.end());
737 res.reassociations.reserve(packedRank);
738 for (int64_t i = 1; i <= packedRank; ++i) {
739 res.outerPositions.push_back(i - 1);
740 if (!posSet.contains(i)) {
752 std::optional<Attribute> cst) {
753 if (source && source.
isSplat() && result.hasStaticShape() &&
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
unsigned getMaxPosOfType(ArrayRef< ReassociationExprs > exprArrays)
static FailureOr< ReassociationIndexRange > findReassociationRangeForSize(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, int64_t targetSize, bool matchGreedily=false)
Starting from sourceStartIdx, searches sourceShape for the first sequence of static dimensions such t...
static SmallVector< std::optional< int64_t > > getCollapseShapeTrivialSegments(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
static FailureOr< ReassociationIndexRange > findReassociationRangeForDynamicDim(ArrayRef< int64_t > sourceShape, int64_t sourceStartIdx, bool matchGreedily=false)
Starting from sourceStartIdx, searches sourceShape for the first sequence that can be collapsed into ...
static std::optional< int64_t > getUniqueNonUnitDim(ArrayRef< int64_t > indices, ArrayRef< int64_t > shape)
Returns the index of the only non-unit dimension among indices of shape, if such a dimension exists a...
static FailureOr< SmallVector< ReassociationIndexRange > > findReassociationRangesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Attempts to find a valid collapsing reassociation of sourceShape into targetShape through a simple tr...
static FailureOr< SmallVector< std::optional< int64_t > > > canCollapseShapeBeSimplifiedByRankReducingSlice(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
Returns true if any of the segments of the reassociation indices for a collapsing reshape can be simp...
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
MLIRContext is the top-level object for a collection of MLIR operations.
This class represents a single result from folding an operation.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayRef< int64_t > ReassociationIndicesRef
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...