MLIR  15.0.0git
ReshapeOpsUtils.cpp
Go to the documentation of this file.
1 //===- ReshapeOpsUtils.cpp - Utilities used by structured ops -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 #include "mlir/IR/AffineMap.h"
12 #include "mlir/IR/Builders.h"
13 
14 #include <numeric>
15 
16 using namespace mlir;
17 
18 Optional<SmallVector<ReassociationIndices>>
20  ShapedType targetType) {
21  if (sourceType.getRank() > targetType.getRank())
22  return getReassociationIndicesForCollapse(sourceType.getShape(),
23  targetType.getShape());
24  if (sourceType.getRank() < targetType.getRank())
25  return getReassociationIndicesForCollapse(targetType.getShape(),
26  sourceType.getShape());
27  return llvm::None;
28 }
29 
32  ArrayRef<int64_t> targetShape) {
33  if (sourceShape.size() <= targetShape.size())
34  return llvm::None;
35  unsigned sourceDim = 0;
36  SmallVector<ReassociationIndices> reassociationMap;
37  reassociationMap.reserve(targetShape.size());
38 
39  ReassociationIndices currIndices;
40  int64_t prodOfCollapsedDims = 1;
41  while (sourceDim < sourceShape.size()) {
42  unsigned targetDim = reassociationMap.size();
43  // If we have mapped all the target dimensions stop and handle the remaining
44  // tail of size-1 dimensions explictly.
45  if (targetDim == targetShape.size())
46  break;
47 
48  int64_t currTargetShape = targetShape[targetDim];
49  while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
50  prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
51  sourceDim < sourceShape.size()) {
52  prodOfCollapsedDims *= sourceShape[sourceDim];
53  currIndices.push_back(sourceDim++);
54  }
55 
56  // If the current expanded dimension is dynamic, then the collapsed
57  // dimensions should also be dynamic and product of all previous unprocessed
58  // dimensions of the expanded shape should be 1.
59  if (sourceShape[sourceDim] == ShapedType::kDynamicSize &&
60  (currTargetShape != ShapedType::kDynamicSize ||
61  prodOfCollapsedDims != 1))
62  return llvm::None;
63 
64  // If the collapsed dim is dynamic, the current expanded dim should also
65  // be dynamic.
66  if (currTargetShape == ShapedType::kDynamicSize &&
67  sourceShape[sourceDim] != ShapedType::kDynamicSize)
68  return llvm::None;
69 
70  // For static shapes, if the product of dimensions of the expanded shape
71  // should match the collapsed dimension shape.
72  if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
73  return llvm::None;
74 
75  currIndices.push_back(sourceDim++);
76  reassociationMap.emplace_back(ReassociationIndices{});
77  std::swap(reassociationMap.back(), currIndices);
78  prodOfCollapsedDims = 1;
79  }
80  // All the dimensions in the target must have been processed.
81  if (reassociationMap.size() != targetShape.size())
82  return llvm::None;
83  // Process any remaining entries in the source shape. They all need to be
84  // 1 or dynamic.
85  for (; sourceDim < sourceShape.size(); sourceDim++) {
86  if (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
87  sourceShape[sourceDim] != 1)
88  return llvm::None;
89  // The map is empty when the target type is a scalar.
90  if (!reassociationMap.empty())
91  reassociationMap.back().push_back(sourceDim);
92  }
93  return reassociationMap;
94 }
95 
97  ArrayRef<ReassociationIndices> producerReassociations,
98  ArrayRef<ReassociationIndices> consumerReassociations,
99  MLIRContext *context) {
100  SmallVector<ReassociationIndices> composedIndices;
101  // Make the producer the larger sized vector. If they are of same size, the
102  // resulting reshape is not a supported reshape op.
103  if (producerReassociations.size() == consumerReassociations.size())
104  return llvm::None;
105  if (producerReassociations.size() < consumerReassociations.size())
106  std::swap(producerReassociations, consumerReassociations);
107 
108  // Handle the corner case of the result being a rank 0 shaped type. Return an
109  // empty reassociation.
110  if (consumerReassociations.empty())
111  return composedIndices;
112 
113  size_t consumerDims = std::accumulate(
114  consumerReassociations.begin(), consumerReassociations.end(), 0,
115  [](size_t all, ReassociationIndicesRef indices) {
116  return all + indices.size();
117  });
118  if (producerReassociations.size() != consumerDims)
119  return llvm::None;
120 
121  for (ReassociationIndicesRef consumerIndices : consumerReassociations) {
122  ReassociationIndices reassociations;
123  for (int64_t consumerIndex : consumerIndices) {
124  llvm::append_range(reassociations, producerReassociations[consumerIndex]);
125  }
126  composedIndices.push_back(std::move(reassociations));
127  }
128  return composedIndices;
129 }
130 
133  MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices) {
134  SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps;
135  for (const auto &indices : reassociationIndices) {
136  SmallVector<AffineExpr, 2> reassociationMap;
137  reassociationMap.reserve(indices.size());
138  for (int64_t index : indices)
139  reassociationMap.push_back(mlir::getAffineDimExpr(index, context));
140  reassociationMaps.push_back(std::move(reassociationMap));
141  }
142  return reassociationMaps;
143 }
144 
145 template <typename AffineExprTy>
147  unsigned pos = 0;
148  for (const auto &exprs : exprArrays) {
149  for (auto expr : exprs) {
150  expr.walk([&pos](AffineExpr e) {
151  if (auto d = e.dyn_cast<AffineExprTy>())
152  pos = std::max(pos, d.getPosition());
153  });
154  }
155  }
156  return pos;
157 }
158 
160  OpBuilder &b, ArrayRef<ReassociationIndices> reassociation) {
161  SmallVector<Attribute, 4> reassociationAttr =
162  llvm::to_vector<4>(llvm::map_range(
163  reassociation, [&](const ReassociationIndices &indices) -> Attribute {
164  return b.getI64ArrayAttr(indices).cast<Attribute>();
165  }));
166  return b.getArrayAttr(reassociationAttr);
167 }
168 
170  OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs) {
171  SmallVector<ReassociationIndices, 2> reassociationIndices;
172  for (const auto &exprs : reassociationExprs) {
173  ReassociationIndices indices;
174  indices.reserve(exprs.size());
175  for (const auto &expr : exprs)
176  indices.push_back(expr.cast<AffineDimExpr>().getPosition());
177  reassociationIndices.push_back(indices);
178  }
179  return reassociationIndices;
180 }
181 
184  unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation);
185  assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 &&
186  "Expected symbol-less expressions");
188  maps.reserve(reassociation.size());
189  for (const auto &exprs : reassociation) {
190  assert(!exprs.empty());
191  maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext()));
192  }
193  return maps;
194 }
195 
197  int *invalidIndex) {
198  if (reassociation.empty())
199  return true;
200  unsigned nDims = reassociation[0].getNumDims();
201  unsigned nextExpectedDim = 0;
202  for (const auto &it : llvm::enumerate(reassociation)) {
203  auto m = it.value();
204  if (m.getNumDims() != nDims || m.getNumSymbols() != 0) {
205  if (invalidIndex)
206  *invalidIndex = it.index();
207  return false;
208  }
209  for (auto e : m.getResults()) {
210  auto d = e.dyn_cast<AffineDimExpr>();
211  if (!d || d.getPosition() != nextExpectedDim++) {
212  if (invalidIndex)
213  *invalidIndex = it.index();
214  return false;
215  }
216  }
217  }
218  if (nextExpectedDim != nDims) {
219  if (invalidIndex)
220  *invalidIndex = reassociation.size() - 1;
221  return false;
222  }
223  return true;
224 }
225 
227  function_ref<LogicalResult(const Twine &)> emitError,
228  ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
229  ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
230  unsigned expandedDimStart = 0;
231  for (const auto &map : llvm::enumerate(reassociationMaps)) {
232  Optional<int64_t> dynamicShape;
233  int64_t linearizedStaticShape = 1;
234  for (const auto &dim : llvm::enumerate(
235  expandedShape.slice(expandedDimStart, map.value().size()))) {
236  if (ShapedType::isDynamic(dim.value())) {
237  if (isExpandingReshape && dynamicShape) {
238  return emitError("invalid to have a single dimension (" +
239  Twine(map.index()) +
240  ") expanded into multiple dynamic dims (" +
241  Twine(expandedDimStart + dynamicShape.getValue()) +
242  "," + Twine(expandedDimStart + dim.index()) + ")");
243  }
244  dynamicShape = dim.index();
245  } else {
246  linearizedStaticShape *= dim.value();
247  }
248  }
249  if (dynamicShape) {
250  if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
251  return emitError(
252  "expected dimension " + Twine(map.index()) +
253  " of collapsed type to be dynamic since one or more of the "
254  "corresponding dimensions in the expanded type is dynamic");
255  }
256  } else {
257  if (collapsedShape[map.index()] != linearizedStaticShape) {
258  return emitError("expected dimension " + Twine(map.index()) +
259  " of collapsed type to be static value of " +
260  Twine(linearizedStaticShape));
261  }
262  }
263  expandedDimStart += map.value().size();
264  }
265  return success();
266 }
267 
269  if (auto memrefType = type.dyn_cast<MemRefType>())
270  return !memrefType.getLayout().isIdentity();
271  return false;
272 }
Include the generated interface declarations.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
Optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rules 1) if a dimension in the collapsed typ...
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(OpBuilder &b, ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
Optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape between the given source type and target type, when possible.
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
unsigned getPosition() const
Definition: AffineExpr.cpp:312
unsigned getMaxPosOfType(ArrayRef< ReassociationExprs > exprArrays)
U dyn_cast() const
Definition: AffineExpr.h:281
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
U dyn_cast() const
Definition: Types.h:244
Attributes are known-constant values of operations.
Definition: Attributes.h:24
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:234
Optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in a pair of reshape ops, where one is a producer and the other is the consumer.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:489
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
This class helps build Operations.
Definition: Builders.h:177
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:205
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)