14 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
21 #include "llvm/ADT/StringRef.h"
46 ArrayRef<ReassociationIndices> producerReassociations,
47 ArrayRef<ReassociationIndices> consumerReassociations,
48 MLIRContext *context);
52 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
55 SmallVector<AffineMap, 4>
61 ArrayRef<ReassociationIndices> reassociation);
65 OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
70 std::optional<SmallVector<ReassociationIndices>>
75 std::optional<SmallVector<ReassociationIndices>>
77 ArrayRef<int64_t> targetShape);
83 int *invalidIndex =
nullptr);
85 template <
typename ReshapeOpTy,
typename InverseReshapeOpTy>
91 reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
92 if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
93 return reshapeSrcOp.getSrc();
95 if (
auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
96 return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
103 template <
typename Op,
typename T>
105 T collapsedType,
bool isExpansion) {
106 unsigned expandedRank = expandedType.getRank();
107 unsigned collapsedRank = collapsedType.getRank();
108 if (expandedRank < collapsedRank)
111 <<
" to have higher rank than the type = " << collapsedType;
112 if (expandedRank == 0)
113 return op.
emitOpError(
"expected non-zero memref ranks");
114 if (expandedRank == collapsedRank)
115 return op.
emitOpError(
"expected to collapse or expand dims");
117 if (collapsedRank == 0) {
120 if (llvm::any_of(expandedType.getShape(),
121 [](int64_t dim) ->
bool { return dim != 1; }))
122 return op.
emitOpError(
"invalid to reshape tensor/memref with non-unit "
123 "extent dimensions to zero-rank tensor/memref");
126 if (collapsedRank != op.getReassociation().size())
127 return op.
emitOpError(
"expected rank of the collapsed type(")
128 << collapsedRank <<
") to be the number of reassociation maps("
129 << op.getReassociation().size() <<
")";
130 auto maps = op.getReassociationMaps();
132 if (it.value().getNumDims() != expandedRank)
133 return op.
emitOpError(
"expected reassociation map #")
134 << it.index() <<
" of same rank as expanded memref("
135 << expandedRank <<
"), but got " << it.value().getNumDims();
138 return op.
emitOpError(
"expected reassociation map #")
139 << invalidIdx <<
" to be valid and contiguous";
153 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
154 ArrayRef<ReassociationIndices> reassociationMaps,
bool isExpandingReshape);
156 template <
typename OpTy>
158 ShapedType expandedType,
159 bool isExpandingReshape) {
161 [&](
const Twine &msg) {
return op->
emitOpError(msg); },
162 collapsedType.getShape(), expandedType.getShape(),
163 op.getReassociationIndices(), isExpandingReshape);
171 template <
typename ReshapeOpTy>
177 reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
181 ShapedType resultType = reshapeOp.getResultType();
188 std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
190 reshapeOp.getReassociationIndices(),
192 if (!reassociationIndices)
195 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
228 template <
typename CollapseOpTy,
typename ExpandOpTy,
typename CastOpTy>
233 auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
237 ShapedType srcType = expandOp.getSrcType();
238 ShapedType resultType = collapseOp.getResultType();
245 int64_t srcRank = srcType.getRank();
246 int64_t resultRank = resultType.getRank();
247 if (srcType == resultType)
251 lowerRankReassociation;
253 if (srcRank > resultRank) {
254 higherRankReassociation = expandOp.getReassociationIndices();
255 lowerRankReassociation = collapseOp.getReassociationIndices();
257 higherRankReassociation = collapseOp.getReassociationIndices();
258 lowerRankReassociation = expandOp.getReassociationIndices();
261 size_t higherRankIndicesID = 0;
263 for (
const auto &lowerRankIndices : lowerRankReassociation) {
265 while (higherRankIndicesID < higherRankReassociation.size()) {
266 auto rightmostIndex =
267 higherRankReassociation[higherRankIndicesID].back();
268 if (rightmostIndex > lowerRankIndices.back())
270 composedIndices.push_back(higherRankIndicesID++);
271 if (rightmostIndex == lowerRankIndices.back())
274 composedReassociation.push_back(composedIndices);
276 if (srcRank > resultRank) {
278 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
279 }
else if (srcRank < resultRank) {
281 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
285 assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
286 "expected same shape");
294 template <
typename ExpandOpTy,
typename CollapseOpTy>
299 auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
303 ShapedType srcType = collapseOp.getSrcType();
304 ShapedType resultType = expandOp.getResultType();
311 int64_t srcRank = srcType.getRank();
312 int64_t resultRank = resultType.getRank();
313 if (srcType == resultType)
316 auto srcReassociation = collapseOp.getReassociationIndices();
317 auto resultReassociation = expandOp.getReassociationIndices();
318 if (srcRank > resultRank) {
319 auto composedReassociation = findCollapsingReassociation(
320 srcReassociation, resultReassociation, srcType.getShape(),
321 resultType.getShape());
322 if (!composedReassociation)
326 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
329 auto composedReassociation =
330 findCollapsingReassociation(resultReassociation, srcReassociation,
331 resultType.getShape(), srcType.getShape());
332 if (!composedReassociation)
336 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
343 std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
349 if (srcReassociation.empty())
352 for (
auto item : llvm::zip(srcReassociation, resultReassociation)) {
353 auto &srcIndices = std::get<0>(item);
354 auto &resultIndices = std::get<1>(item);
355 auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
356 auto resultSubShape =
357 resultShape.slice(resultIndices.front(), resultIndices.size());
359 if (srcSubShape.size() == resultSubShape.size()) {
360 if (srcSubShape == resultSubShape)
361 composedReassociation.push_back(srcIndices);
367 auto subShapeReassociation =
369 if (!subShapeReassociation)
373 for (
auto &subshape_indices : *subShapeReassociation) {
375 for (int64_t index : subshape_indices)
376 shape_indices.push_back(srcIndices.front() + index);
377 composedReassociation.push_back(shape_indices);
380 return {std::move(composedReassociation)};
390 ArrayRef<Range> sliceParams);
427 class SliceFromCollapseHelper {
429 SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
430 ArrayRef<OpFoldResult> collapseShapeInputShape,
431 ArrayRef<OpFoldResult> collapseShapeOutputShape,
432 ArrayRef<Range> extractSliceParams)
433 : reassociationIndices(reassociationIndices),
434 collapseShapeInputShape(collapseShapeInputShape),
435 collapseShapeOutputShape(collapseShapeOutputShape),
436 sliceParams(extractSliceParams),
439 extractSliceParams)) {}
451 SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
452 ArrayRef<ValueRange> multiIndices);
459 SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
460 ValueRange tileIndices);
463 SmallVector<ReassociationIndices> reassociationIndices;
464 SmallVector<OpFoldResult> collapseShapeInputShape;
465 SmallVector<OpFoldResult> collapseShapeOutputShape;
466 SmallVector<Range> sliceParams;
467 llvm::SmallBitVector linearizedDimensions;
468 llvm::SmallBitVector slicedDimensions;
473 struct CollapseShapeRankReducingSliceSimplificationInfo {
475 RankedTensorType sliceResultType;
478 std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
/// Computes the information needed to simplify a collapse-shape whose source
/// is produced by a rank-reducing slice. The purpose is inferred from the
/// function and result-struct names; TODO confirm against the implementation.
///
/// \param sourceType            Ranked tensor type supplied as the source.
/// \param reassociationIndices  Reassociation groups of the collapse.
/// \returns On success, a CollapseShapeRankReducingSliceSimplificationInfo
///          carrying `sliceResultType` and an optional
///          `newReassociationIndices`; otherwise failure. The exact failure
///          conditions are not visible from this declaration.
517 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
518 getSimplifyCollapseShapeWithRankReducingSliceInfo(
519 RankedTensorType sourceType,
520 ArrayRef<ReassociationIndices> reassociationIndices);
522 struct PackingMetadata {
523 SmallVector<int64_t> insertPositions;
524 SmallVector<int64_t> outerPositions;
525 SmallVector<ReassociationIndices> reassociations;
/// Computes a PackingMetadata (`insertPositions`, `outerPositions`,
/// `reassociations`) from `packedRank` and `innerDimPos`.
///
/// \param packedRank   Presumably the rank of the packed shape; per the
///                     parameter name, to be confirmed.
/// \param innerDimPos  Presumably the positions of the inner (tiled)
///                     dimensions. NOTE(review): confirm against the
///                     implementation.
/// \returns The populated PackingMetadata, returned by value.
534 PackingMetadata computePackingMetadata(int64_t packedRank,
535 ArrayRef<int64_t> innerDimPos);
MLIRContext * getContext() const
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
llvm::function_ref< Fn > function_ref
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rules 1) if a dimension in the collapsed typ...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
SmallVector< int64_t, 2 > ReassociationIndices
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType, ShapedType expandedType, bool isExpandingReshape)
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(OpBuilder &b, ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...