19 std::optional<SmallVector<ReassociationIndices>>
21 ShapedType targetType) {
22 if (sourceType.getRank() > targetType.getRank())
24 targetType.getShape());
25 if (sourceType.getRank() < targetType.getRank())
27 sourceType.getShape());
31 std::optional<SmallVector<ReassociationIndices>>
34 if (sourceShape.size() <= targetShape.size())
36 unsigned sourceDim = 0;
38 reassociationMap.reserve(targetShape.size());
41 int64_t prodOfCollapsedDims = 1;
42 while (sourceDim < sourceShape.size()) {
43 unsigned targetDim = reassociationMap.size();
46 if (targetDim == targetShape.size())
49 int64_t currTargetShape = targetShape[targetDim];
50 while (sourceDim < (sourceShape.size() - 1) &&
51 sourceShape[sourceDim] != ShapedType::kDynamic &&
52 prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
53 prodOfCollapsedDims *= sourceShape[sourceDim];
54 currIndices.push_back(sourceDim++);
60 if (sourceShape[sourceDim] == ShapedType::kDynamic &&
61 (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
66 if (currTargetShape == ShapedType::kDynamic &&
67 sourceShape[sourceDim] != ShapedType::kDynamic)
72 if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
75 currIndices.push_back(sourceDim++);
77 std::swap(reassociationMap.back(), currIndices);
78 prodOfCollapsedDims = 1;
81 if (reassociationMap.size() != targetShape.size())
85 for (; sourceDim < sourceShape.size(); sourceDim++) {
86 if (sourceShape[sourceDim] != ShapedType::kDynamic &&
87 sourceShape[sourceDim] != 1)
90 if (!reassociationMap.empty())
91 reassociationMap.back().push_back(sourceDim);
93 return reassociationMap;
96 std::optional<SmallVector<ReassociationIndices>>
104 if (producerReassociations.size() == consumerReassociations.size())
106 if (producerReassociations.size() < consumerReassociations.size())
107 std::swap(producerReassociations, consumerReassociations);
111 if (consumerReassociations.empty())
112 return composedIndices;
114 size_t consumerDims = std::accumulate(
115 consumerReassociations.begin(), consumerReassociations.end(), 0,
117 return all + indices.size();
119 if (producerReassociations.size() != consumerDims)
124 for (int64_t consumerIndex : consumerIndices) {
125 llvm::append_range(reassociations, producerReassociations[consumerIndex]);
127 composedIndices.push_back(std::move(reassociations));
129 return composedIndices;
136 for (
const auto &indices : reassociationIndices) {
138 reassociationMap.reserve(indices.size());
139 for (int64_t index : indices)
141 reassociationMaps.push_back(std::move(reassociationMap));
143 return reassociationMaps;
146 template <
typename AffineExprTy>
149 for (
const auto &exprs : exprArrays) {
150 for (
auto expr : exprs) {
152 if (
auto d = dyn_cast<AffineExprTy>(e))
153 pos =
std::max(pos, d.getPosition());
163 llvm::to_vector<4>(llvm::map_range(
173 for (
const auto &exprs : reassociationExprs) {
175 indices.reserve(exprs.size());
176 for (
const auto &expr : exprs)
177 indices.push_back(cast<AffineDimExpr>(expr).getPosition());
178 reassociationIndices.push_back(indices);
180 return reassociationIndices;
185 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
186 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
187 "Expected symbol-less expressions");
189 maps.reserve(reassociation.size());
190 for (
const auto &exprs : reassociation) {
191 assert(!exprs.empty());
199 if (reassociation.empty())
201 unsigned nDims = reassociation[0].getNumDims();
202 unsigned nextExpectedDim = 0;
205 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
207 *invalidIndex = it.index();
210 for (
auto e : m.getResults()) {
211 auto d = dyn_cast<AffineDimExpr>(e);
212 if (!d || d.getPosition() != nextExpectedDim++) {
214 *invalidIndex = it.index();
219 if (nextExpectedDim != nDims) {
221 *invalidIndex = reassociation.size() - 1;
231 unsigned expandedDimStart = 0;
233 bool foundDynamicShape =
false;
234 int64_t linearizedStaticShape = 1;
237 expandedShape.slice(expandedDimStart, map.value().size()))) {
238 if (ShapedType::isDynamic(dim.value()))
239 foundDynamicShape =
true;
241 linearizedStaticShape *= dim.value();
243 if (foundDynamicShape) {
244 if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
246 "expected dimension " + Twine(map.index()) +
247 " of collapsed type to be dynamic since one or more of the "
248 "corresponding dimensions in the expanded type is dynamic");
251 if (collapsedShape[map.index()] != linearizedStaticShape) {
252 return emitError(
"expected dimension " + Twine(map.index()) +
253 " of collapsed type to be static value of " +
254 Twine(linearizedStaticShape));
257 expandedDimStart += map.value().size();
263 if (
auto memrefType = dyn_cast<MemRefType>(type))
264 return !memrefType.getLayout().isIdentity();
271 assert(sliceParams.size() == sliceInputShape.size() &&
272 "only supports non rank-reducing case");
273 llvm::SmallBitVector mask(sliceInputShape.size());
275 for (
const auto &[offset, size, stride] : sliceParams) {
279 (!strideConst || *strideConst != 1) ||
280 (!offsetConst || *offsetConst != 0);
288 llvm::SmallBitVector result(reassociationIndices.size());
290 result[it.index()] = it.value().size() > 1;
296 unsigned loopIdx = 0;
300 offsetsSizesAndStrides.reserve(collapseShapeInputShape.size());
305 if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) {
307 offsetsSizesAndStrides,
308 llvm::map_range(multiIndices[loopIdx++], [&](
Value v) ->
Range {
317 if (linearizedDimensions[it.index()]) {
319 offsetsSizesAndStrides,
320 llvm::map_range(it.value(), [&](int64_t idx) ->
Range {
321 return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
327 offsetsSizesAndStrides.push_back(sliceParams[it.index()]);
329 return offsetsSizesAndStrides;
333 SliceFromCollapseHelper::getInsertSliceParams(
MLIRContext *ctx,
338 insertParams.reserve(linearizedDimensions.size());
339 unsigned loopIdx = 0;
340 for (
unsigned i = 0; i < linearizedDimensions.size(); i++) {
341 if (linearizedDimensions[i] && slicedDimensions[i]) {
342 insertParams.push_back(
Range{tileIndices[loopIdx++], one, one});
345 insertParams.push_back(
Range{zero, sliceParams[i].
size, one});
356 std::optional<int64_t> dimIndex;
357 if (indices.size() < 2)
359 for (int64_t idx : indices) {
360 if (shape[idx] != 1) {
361 if (dimIndex != std::nullopt)
373 RankedTensorType sourceType,
376 for (
const auto &indices : reassociationIndices)
377 trivialSegments.push_back(
379 return trivialSegments;
384 static FailureOr<SmallVector<std::optional<int64_t>>>
386 RankedTensorType sourceType,
390 if (!llvm::any_of(trivialSegments, [](
const std::optional<int64_t> &idx) {
391 return idx.has_value();
394 return trivialSegments;
397 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
398 mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
399 RankedTensorType sourceType,
401 FailureOr<SmallVector<std::optional<int64_t>>> trivialSegments =
403 reassociationIndices);
404 if (failed(trivialSegments))
409 for (
const auto &[nonUnitDim, indices] :
410 llvm::zip(*trivialSegments, reassociationIndices)) {
412 sliceShape.push_back(sourceType.getDimSize(*nonUnitDim));
415 llvm::append_range(sliceShape, llvm::map_range(indices, [&](int64_t idx) {
416 return sourceType.getDimSize(idx);
423 if (sliceShape.size() == reassociationIndices.size())
424 return CollapseShapeRankReducingSliceSimplificationInfo{sliceType,
432 int64_t groupIdx = 0;
433 for (int64_t dimIdx = 0; dimIdx < sliceType.getRank(); dimIdx++) {
434 reassociation.push_back(dimIdx);
435 if ((*trivialSegments)[groupIdx] ||
436 reassociation.size() == reassociationIndices[groupIdx].size()) {
437 newReassociationIndices.push_back(reassociation);
438 reassociation.clear();
443 return CollapseShapeRankReducingSliceSimplificationInfo{
444 sliceType, newReassociationIndices};
447 PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
450 res.insertPositions.reserve(innerDimPos.size());
466 for (int64_t pos : innerDimPos) {
467 int64_t numInsertedBefore = llvm::count_if(
468 innerDimPos, [&pos](int64_t pos2) {
return pos > pos2; });
469 res.insertPositions.push_back(pos + numInsertedBefore + offset);
473 res.insertPositions.end());
474 res.reassociations.reserve(packedRank);
475 for (int64_t i = 1; i <= packedRank; ++i) {
476 res.outerPositions.push_back(i - 1);
477 if (!posSet.contains(i)) {
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
unsigned getMaxPosOfType(ArrayRef< ReassociationExprs > exprArrays)
static SmallVector< std::optional< int64_t > > getCollapseShapeTrivialSegments(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
static std::optional< int64_t > getUniqueNonUnitDim(ArrayRef< int64_t > indices, ArrayRef< int64_t > shape)
Returns the index of the only non-unit dimension among `indices` of `shape`, if such a dimension exists and `indices` has more than one element.
static FailureOr< SmallVector< std::optional< int64_t > > > canCollapseShapeBeSimplifiedByRankReducingSlice(RankedTensorType sourceType, ArrayRef< ReassociationIndices > reassociationIndices)
Returns true if any of the segments of the reassociation indices for a collapsing reshape can be simplified using a rank-reducing slice.
Base type for affine expression.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
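Verify the shapes of the reshaped types using the following rule: if any of the expanded dimensions corresponding to a collapsed dimension is dynamic, the collapsed dimension must be dynamic; otherwise the collapsed dimension must equal the product of the corresponding static expanded dimensions.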
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape given the source type and the target type when possible.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayRef< int64_t > ReassociationIndicesRef
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in a pair of reshape ops where one is a producer and the other is a consumer.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociation indices.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.