19 std::optional<SmallVector<ReassociationIndices>>
21 ShapedType targetType) {
22 if (sourceType.getRank() > targetType.getRank())
24 targetType.getShape());
25 if (sourceType.getRank() < targetType.getRank())
27 sourceType.getShape());
31 std::optional<SmallVector<ReassociationIndices>>
34 if (sourceShape.size() <= targetShape.size())
36 unsigned sourceDim = 0;
38 reassociationMap.reserve(targetShape.size());
41 int64_t prodOfCollapsedDims = 1;
42 while (sourceDim < sourceShape.size()) {
43 unsigned targetDim = reassociationMap.size();
46 if (targetDim == targetShape.size())
49 int64_t currTargetShape = targetShape[targetDim];
50 while (sourceDim < sourceShape.size() &&
51 sourceShape[sourceDim] != ShapedType::kDynamic &&
52 prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
53 prodOfCollapsedDims *= sourceShape[sourceDim];
54 currIndices.push_back(sourceDim++);
60 if (sourceShape[sourceDim] == ShapedType::kDynamic &&
61 (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
66 if (currTargetShape == ShapedType::kDynamic &&
67 sourceShape[sourceDim] != ShapedType::kDynamic)
72 if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
75 currIndices.push_back(sourceDim++);
77 std::swap(reassociationMap.back(), currIndices);
78 prodOfCollapsedDims = 1;
81 if (reassociationMap.size() != targetShape.size())
85 for (; sourceDim < sourceShape.size(); sourceDim++) {
86 if (sourceShape[sourceDim] != ShapedType::kDynamic &&
87 sourceShape[sourceDim] != 1)
90 if (!reassociationMap.empty())
91 reassociationMap.back().push_back(sourceDim);
93 return reassociationMap;
96 std::optional<SmallVector<ReassociationIndices>>
104 if (producerReassociations.size() == consumerReassociations.size())
106 if (producerReassociations.size() < consumerReassociations.size())
107 std::swap(producerReassociations, consumerReassociations);
111 if (consumerReassociations.empty())
112 return composedIndices;
114 size_t consumerDims = std::accumulate(
115 consumerReassociations.begin(), consumerReassociations.end(), 0,
117 return all + indices.size();
119 if (producerReassociations.size() != consumerDims)
124 for (int64_t consumerIndex : consumerIndices) {
125 llvm::append_range(reassociations, producerReassociations[consumerIndex]);
127 composedIndices.push_back(std::move(reassociations));
129 return composedIndices;
136 for (
const auto &indices : reassociationIndices) {
138 reassociationMap.reserve(indices.size());
139 for (int64_t index : indices)
141 reassociationMaps.push_back(std::move(reassociationMap));
143 return reassociationMaps;
146 template <
typename AffineExprTy>
149 for (
const auto &exprs : exprArrays) {
150 for (
auto expr : exprs) {
152 if (
auto d = dyn_cast<AffineExprTy>(e))
153 pos =
std::max(pos, d.getPosition());
163 llvm::to_vector<4>(llvm::map_range(
173 for (
const auto &exprs : reassociationExprs) {
175 indices.reserve(exprs.size());
176 for (
const auto &expr : exprs)
177 indices.push_back(cast<AffineDimExpr>(expr).getPosition());
178 reassociationIndices.push_back(indices);
180 return reassociationIndices;
185 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
186 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
187 "Expected symbol-less expressions");
189 maps.reserve(reassociation.size());
190 for (
const auto &exprs : reassociation) {
191 assert(!exprs.empty());
199 if (reassociation.empty())
201 unsigned nDims = reassociation[0].getNumDims();
202 unsigned nextExpectedDim = 0;
205 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
207 *invalidIndex = it.index();
210 for (
auto e : m.getResults()) {
211 auto d = dyn_cast<AffineDimExpr>(e);
212 if (!d || d.getPosition() != nextExpectedDim++) {
214 *invalidIndex = it.index();
219 if (nextExpectedDim != nDims) {
221 *invalidIndex = reassociation.size() - 1;
231 unsigned expandedDimStart = 0;
233 std::optional<int64_t> dynamicShape;
234 int64_t linearizedStaticShape = 1;
236 expandedShape.slice(expandedDimStart, map.value().size()))) {
237 if (ShapedType::isDynamic(dim.value())) {
238 if (isExpandingReshape && dynamicShape) {
239 return emitError(
"invalid to have a single dimension (" +
241 ") expanded into multiple dynamic dims (" +
242 Twine(expandedDimStart + dynamicShape.value()) +
243 "," + Twine(expandedDimStart + dim.index()) +
")");
245 dynamicShape = dim.index();
247 linearizedStaticShape *= dim.value();
251 if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
253 "expected dimension " + Twine(map.index()) +
254 " of collapsed type to be dynamic since one or more of the "
255 "corresponding dimensions in the expanded type is dynamic");
258 if (collapsedShape[map.index()] != linearizedStaticShape) {
259 return emitError(
"expected dimension " + Twine(map.index()) +
260 " of collapsed type to be static value of " +
261 Twine(linearizedStaticShape));
264 expandedDimStart += map.value().size();
270 if (
auto memrefType = dyn_cast<MemRefType>(type))
271 return !memrefType.getLayout().isIdentity();
278 assert(sliceParams.size() == sliceInputShape.size() &&
279 "only supports non rank-reducing case");
280 llvm::SmallBitVector mask(sliceInputShape.size());
282 for (
const auto &[offset, size, stride] : sliceParams) {
286 (!strideConst || *strideConst != 1) ||
287 (!offsetConst || *offsetConst != 0);
295 llvm::SmallBitVector result(reassociationIndices.size());
297 result[it.index()] = it.value().size() > 1;
303 unsigned loopIdx = 0;
307 offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
312 if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
314 offsetsSizesAndStrides,
315 llvm::map_range(multiIndices[loopIdx++], [&](
Value v) ->
Range {
324 if (linearizedDimensions[it.index()]) {
326 offsetsSizesAndStrides,
327 llvm::map_range(it.value(), [&](int64_t idx) ->
Range {
328 return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
334 offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
336 return offsetsSizesAndStrides;
340 SliceFromCollapseHelper::getInsertSliceParams(
MLIRContext *ctx,
345 insertParams.reserve(linearizedDimensions.size());
346 unsigned loopIdx = 0;
347 for (
unsigned i = 0; i < linearizedDimensions.size(); i++) {
348 if (linearizedDimensions[i] && slicedDimensions[i]) {
349 insertParams.push_back(
Range{tileIndices[loopIdx++], one, one});
352 insertParams.push_back(
Range{zero, sliceParams[i].
size, one});
363 std::optional<int64_t> dimIndex;
364 if (indices.size() < 2)
366 for (int64_t idx : indices) {
367 if (shape[idx] != 1) {
368 if (dimIndex != std::nullopt)
380 RankedTensorType sourceType,
383 for (
const auto &indices : reassociationIndices)
384 trivialSegments.push_back(
386 return trivialSegments;
393 RankedTensorType sourceType,
397 if (!llvm::any_of(trivialSegments, [](
const std::optional<int64_t> &idx) {
398 return idx.has_value();
401 return trivialSegments;
405 mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
406 RankedTensorType sourceType,
410 reassociationIndices);
411 if (
failed(trivialSegments))
416 for (
const auto &[nonUnitDim, indices] :
417 llvm::zip(*trivialSegments, reassociationIndices)) {
419 sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
422 llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
423 return sourceType.getDimSize(idx);
430 if (sliceShape.size() == reassociationIndices.size())
431 return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
439 int64_t groupIdx = 0;
440 for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
441 reassociation.push_back(dimIdx);
442 if ((*trivialSegments)[groupIdx] ||
443 reassociation.size() == reassociationIndices[groupIdx].size()) {
444 newReassociationIndices.push_back(reassociation);
445 reassociation.clear();
450 return CollapseShapeRankReducingSliceSimplificationInfo{
451 sliceType, newReassociationIndices};
454 PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
457 res.insertPositions.reserve(innerDimPos.size());
473 for (int64_t pos : innerDimPos) {
474 int64_t numInsertedBefore = llvm::count_if(
475 innerDimPos, [&pos](int64_t pos2) {
return pos > pos2; });
476 res.insertPositions.push_back(pos + numInsertedBefore + offset);
480 res.insertPositions.end());
481 res.reassociations.reserve(packedRank);
482 for (int64_t i = 1; i <= packedRank; ++i) {
483 res.outerPositions.push_back(i - 1);
484 if (!posSet.contains(i)) {
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
unsigned getMaxPosOfType(ArrayRef< ReassociationExprs > exprArrays)
static SmallVector< std::optional< int64_t > > getCollapseShapeTrivialSegments(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
static std::optional< int64_t > getUniqueNonUnitDim(ArrayRef< int64_t > indices, ArrayRef< int64_t > shape)
Returns the index of the only non-unit dimension among indices of shape, if such a dimension exists a...
static FailureOr< SmallVector< std::optional< int64_t > > > canCollapseShapeBeSimplifiedByRankReducingSlice(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
Returns true if any of the segments of the reassociation indices for a collapsing reshape can be simp...
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class provides support for representing a failure result, or a valid value of type T.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rules 1) if a dimension in the collapsed typ...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(OpBuilder &b, ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
This class represents an efficient way to signal success or failure.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...