18 Optional<SmallVector<ReassociationIndices>>
20 ShapedType targetType) {
21 if (sourceType.getRank() > targetType.getRank())
23 targetType.getShape());
24 if (sourceType.getRank() < targetType.getRank())
26 sourceType.getShape());
33 if (sourceShape.size() <= targetShape.size())
35 unsigned sourceDim = 0;
37 reassociationMap.reserve(targetShape.size());
40 int64_t prodOfCollapsedDims = 1;
41 while (sourceDim < sourceShape.size()) {
42 unsigned targetDim = reassociationMap.size();
45 if (targetDim == targetShape.size())
48 int64_t currTargetShape = targetShape[targetDim];
49 while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
50 prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
51 sourceDim < sourceShape.size()) {
52 prodOfCollapsedDims *= sourceShape[sourceDim];
53 currIndices.push_back(sourceDim++);
59 if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
60 (currTargetShape != ShapedType::kDynamicSize ||
61 prodOfCollapsedDims != 1))
66 if (currTargetShape == ShapedType::kDynamicSize &&
67 sourceShape[sourceDim] != ShapedType::kDynamicSize)
72 if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
75 currIndices.push_back(sourceDim++);
77 std::swap(reassociationMap.back(), currIndices);
78 prodOfCollapsedDims = 1;
81 if (reassociationMap.size() != targetShape.size())
85 for (; sourceDim < sourceShape.size(); sourceDim++) {
86 if (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
87 sourceShape[sourceDim] != 1)
90 if (!reassociationMap.empty())
91 reassociationMap.back().push_back(sourceDim);
93 return reassociationMap;
103 if (producerReassociations.size() == consumerReassociations.size())
105 if (producerReassociations.size() < consumerReassociations.size())
106 std::swap(producerReassociations, consumerReassociations);
110 if (consumerReassociations.empty())
111 return composedIndices;
113 size_t consumerDims = std::accumulate(
114 consumerReassociations.begin(), consumerReassociations.end(), 0,
116 return all + indices.size();
118 if (producerReassociations.size() != consumerDims)
123 for (int64_t consumerIndex : consumerIndices) {
124 llvm::append_range(reassociations, producerReassociations[consumerIndex]);
126 composedIndices.push_back(std::move(reassociations));
128 return composedIndices;
135 for (
const auto &indices : reassociationIndices) {
137 reassociationMap.reserve(indices.size());
138 for (int64_t index : indices)
140 reassociationMaps.push_back(std::move(reassociationMap));
142 return reassociationMaps;
145 template <
typename AffineExprTy>
148 for (
const auto &exprs : exprArrays) {
149 for (
auto expr : exprs) {
151 if (
auto d = e.
dyn_cast<AffineExprTy>())
152 pos =
std::max(pos, d.getPosition());
162 llvm::to_vector<4>(llvm::map_range(
172 for (
const auto &exprs : reassociationExprs) {
174 indices.reserve(exprs.size());
175 for (
const auto &expr : exprs)
177 reassociationIndices.push_back(indices);
179 return reassociationIndices;
184 unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
185 assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
186 "Expected symbol-less expressions");
188 maps.reserve(reassociation.size());
189 for (
const auto &exprs : reassociation) {
190 assert(!exprs.empty());
191 maps.push_back(
AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
198 if (reassociation.empty())
200 unsigned nDims = reassociation[0].getNumDims();
201 unsigned nextExpectedDim = 0;
204 if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
206 *invalidIndex = it.index();
209 for (
auto e : m.getResults()) {
211 if (!d || d.getPosition() != nextExpectedDim++) {
213 *invalidIndex = it.index();
218 if (nextExpectedDim != nDims) {
220 *invalidIndex = reassociation.size() - 1;
230 unsigned expandedDimStart = 0;
233 int64_t linearizedStaticShape = 1;
235 expandedShape.slice(expandedDimStart, map.value().size()))) {
236 if (ShapedType::isDynamic(dim.value())) {
237 if (isExpandingReshape && dynamicShape) {
238 return emitError(
"invalid to have a single dimension (" +
240 ") expanded into multiple dynamic dims (" +
241 Twine(expandedDimStart + dynamicShape.getValue()) +
242 "," + Twine(expandedDimStart + dim.index()) +
")");
244 dynamicShape = dim.index();
246 linearizedStaticShape *= dim.value();
250 if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
252 "expected dimension " + Twine(map.index()) +
253 " of collapsed type to be dynamic since one or more of the " 254 "corresponding dimensions in the expanded type is dynamic");
257 if (collapsedShape[map.index()] != linearizedStaticShape) {
258 return emitError(
"expected dimension " + Twine(map.index()) +
259 " of collapsed type to be static value of " +
260 Twine(linearizedStaticShape));
263 expandedDimStart += map.value().size();
269 if (
auto memrefType = type.
dyn_cast<MemRefType>())
270 return !memrefType.getLayout().isIdentity();
Include the generated interface declarations.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
Optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rules 1) if a dimension in the collapsed typ...
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(OpBuilder &b, ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
Optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape given the source type and the target type when possi...
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
unsigned getPosition() const
unsigned getMaxPosOfType(ArrayRef< ReassociationExprs > exprArrays)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
Attributes are known-constant values of operations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in a pair of reshape ops where one is a producer and the other is ...
Base type for affine expression.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
A dimensional identifier appearing in an affine expression.
MLIRContext is the top-level object for a collection of MLIR operations.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
This class helps build Operations.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)