14 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
21 #include "llvm/ADT/StringRef.h"
46 ArrayRef<ReassociationIndices> producerReassociations,
47 ArrayRef<ReassociationIndices> consumerReassociations,
48 MLIRContext *context);
52 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
55 SmallVector<AffineMap, 4>
61 ArrayRef<ReassociationIndices> reassociation);
65 OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
70 std::optional<SmallVector<ReassociationIndices>>
75 std::optional<SmallVector<ReassociationIndices>>
77 ArrayRef<int64_t> targetShape);
83 int *invalidIndex =
nullptr);
85 template <
typename ReshapeOpTy,
typename InverseReshapeOpTy>
91 reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
92 if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
93 return reshapeSrcOp.getSrc();
96 return elements.reshape(
97 reshapeOp.getResult().getType().template cast<ShapedType>());
104 template <
typename Op,
typename T>
106 T collapsedType,
bool isExpansion) {
107 unsigned expandedRank = expandedType.getRank();
108 unsigned collapsedRank = collapsedType.getRank();
109 if (expandedRank < collapsedRank)
112 <<
" to have higher rank than the type = " << collapsedType;
113 if (expandedRank == 0)
114 return op.
emitOpError(
"expected non-zero memref ranks");
115 if (expandedRank == collapsedRank)
116 return op.
emitOpError(
"expected to collapse or expand dims");
118 if (collapsedRank == 0) {
121 if (llvm::any_of(expandedType.getShape(),
122 [](int64_t dim) ->
bool { return dim != 1; }))
123 return op.
emitOpError(
"invalid to reshape tensor/memref with non-unit "
124 "extent dimensions to zero-rank tensor/memref");
127 if (collapsedRank != op.getReassociation().size())
128 return op.
emitOpError(
"expected rank of the collapsed type(")
129 << collapsedRank <<
") to be the number of reassociation maps("
130 << op.getReassociation().size() <<
")";
131 auto maps = op.getReassociationMaps();
133 if (it.value().getNumDims() != expandedRank)
134 return op.
emitOpError(
"expected reassociation map #")
135 << it.index() <<
" of same rank as expanded memref("
136 << expandedRank <<
"), but got " << it.value().getNumDims();
139 return op.
emitOpError(
"expected reassociation map #")
140 << invalidIdx <<
" to be valid and contiguous";
154 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
155 ArrayRef<ReassociationIndices> reassociationMaps,
bool isExpandingReshape);
157 template <
typename OpTy>
159 ShapedType expandedType,
160 bool isExpandingReshape) {
162 [&](
const Twine &msg) {
return op->emitOpError(msg); },
163 collapsedType.getShape(), expandedType.getShape(),
164 op.getReassociationIndices(), isExpandingReshape);
172 template <
typename ReshapeOpTy>
178 reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
182 ShapedType resultType = reshapeOp.getResultType();
189 std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
191 reshapeOp.getReassociationIndices(),
193 if (!reassociationIndices)
196 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
229 template <
typename CollapseOpTy,
typename ExpandOpTy,
typename CastOpTy>
234 auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
238 ShapedType srcType = expandOp.getSrcType();
239 ShapedType resultType = collapseOp.getResultType();
246 int64_t srcRank = srcType.getRank();
247 int64_t resultRank = resultType.getRank();
248 if (srcType == resultType)
252 lowerRankReassociation;
254 if (srcRank > resultRank) {
255 higherRankReassociation = expandOp.getReassociationIndices();
256 lowerRankReassociation = collapseOp.getReassociationIndices();
258 higherRankReassociation = collapseOp.getReassociationIndices();
259 lowerRankReassociation = expandOp.getReassociationIndices();
262 size_t higherRankIndicesID = 0;
264 for (
const auto &lowerRankIndices : lowerRankReassociation) {
266 while (higherRankIndicesID < higherRankReassociation.size()) {
267 auto rightmostIndex =
268 higherRankReassociation[higherRankIndicesID].back();
269 if (rightmostIndex > lowerRankIndices.back())
271 composedIndices.push_back(higherRankIndicesID++);
272 if (rightmostIndex == lowerRankIndices.back())
275 composedReassociation.push_back(composedIndices);
277 if (srcRank > resultRank) {
279 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
280 }
else if (srcRank < resultRank) {
282 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
286 assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
287 "expected same shape");
295 template <
typename ExpandOpTy,
typename CollapseOpTy>
300 auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
304 ShapedType srcType = collapseOp.getSrcType();
305 ShapedType resultType = expandOp.getResultType();
312 int64_t srcRank = srcType.getRank();
313 int64_t resultRank = resultType.getRank();
314 if (srcType == resultType)
317 auto srcReassociation = collapseOp.getReassociationIndices();
318 auto resultReassociation = expandOp.getReassociationIndices();
319 if (srcRank > resultRank) {
320 auto composedReassociation = findCollapsingReassociation(
321 srcReassociation, resultReassociation, srcType.getShape(),
322 resultType.getShape());
323 if (!composedReassociation)
327 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
330 auto composedReassociation =
331 findCollapsingReassociation(resultReassociation, srcReassociation,
332 resultType.getShape(), srcType.getShape());
333 if (!composedReassociation)
337 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
344 std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
350 if (srcReassociation.empty())
353 for (
auto item : llvm::zip(srcReassociation, resultReassociation)) {
354 auto &srcIndices = std::get<0>(item);
355 auto &resultIndices = std::get<1>(item);
356 auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
357 auto resultSubShape =
358 resultShape.slice(resultIndices.front(), resultIndices.size());
360 if (srcSubShape.size() == resultSubShape.size()) {
361 if (srcSubShape == resultSubShape)
362 composedReassociation.push_back(srcIndices);
368 auto subShapeReassociation =
370 if (!subShapeReassociation)
374 for (
auto &subshape_indices : *subShapeReassociation) {
376 for (int64_t index : subshape_indices)
377 shape_indices.push_back(srcIndices.front() + index);
378 composedReassociation.push_back(shape_indices);
381 return {std::move(composedReassociation)};
391 ArrayRef<Range> sliceParams);
428 class SliceFromCollapseHelper {
430 SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
431 ArrayRef<OpFoldResult> collapseShapeInputShape,
432 ArrayRef<OpFoldResult> collapseShapeOutputShape,
433 ArrayRef<Range> extractSliceParams)
434 : reassociationIndices(reassociationIndices),
435 collapseShapeInputShape(collapseShapeInputShape),
436 collapseShapeOutputShape(collapseShapeOutputShape),
437 sliceParams(extractSliceParams),
440 extractSliceParams)) {}
452 SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
453 ArrayRef<ValueRange> multiIndices);
460 SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
461 ValueRange tileIndices);
464 SmallVector<ReassociationIndices> reassociationIndices;
465 SmallVector<OpFoldResult> collapseShapeInputShape;
466 SmallVector<OpFoldResult> collapseShapeOutputShape;
467 SmallVector<Range> sliceParams;
468 llvm::SmallBitVector linearizedDimensions;
469 llvm::SmallBitVector slicedDimensions;
474 struct CollapseShapeRankReducingSliceSimplificationInfo {
476 RankedTensorType sliceResultType;
479 std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
/// Declaration only; neither the definition nor the original doc comment is
/// visible here.
///
/// Takes a `RankedTensorType` source type and a list of reassociation indices.
/// As the function name suggests, these are presumably the operand type and
/// reassociation of a collapse_shape op. It returns either failure or a
/// `CollapseShapeRankReducingSliceSimplificationInfo`. That struct carries at
/// least a `sliceResultType` (RankedTensorType) and an optional
/// `newReassociationIndices`.
///
/// NOTE(review): the name suggests this computes a rank-reducing slice result
/// type that lets the collapse be simplified or dropped. Confirm against the
/// implementation before relying on this.
518 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
519 getSimplifyCollapseShapeWithRankReducingSliceInfo(
520 RankedTensorType sourceType,
521 ArrayRef<ReassociationIndices> reassociationIndices);
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
This class represents a single result from folding an operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op ", which is convenient for verifiers.
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
llvm::function_ref< Fn > function_ref
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, and strides specify a rectangular, non-rank-reducing slice of the collapse_shape output.
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify the shapes of the reshaped types using the following rules: 1) if a dimension in the collapsed type...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape from the given source type to the target type, when possible.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose the reassociation maps used by a pair of reshape ops where one is the producer and the other is the consumer.
SmallVector< int64_t, 2 > ReassociationIndices
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociation indices.
static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType, ShapedType expandedType, bool isExpandingReshape)
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(OpBuilder &b, ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both expanding dimensions.
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class, as opposed to a raw Operation.