14 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
21 #include "llvm/ADT/StringRef.h"
46 ArrayRef<ReassociationIndices> producerReassociations,
47 ArrayRef<ReassociationIndices> consumerReassociations,
48 MLIRContext *context);
52 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
55 SmallVector<AffineMap, 4>
61 ArrayRef<ReassociationIndices> reassociation);
65 OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
70 std::optional<SmallVector<ReassociationIndices>>
75 std::optional<SmallVector<ReassociationIndices>>
77 ArrayRef<int64_t> targetShape);
83 int *invalidIndex =
nullptr);
/// Fold logic for a reshape op of type `ReshapeOpTy`, where
/// `InverseReshapeOpTy` is the op kind that undoes it.
/// NOTE(review): the signature line and a few statements are missing from this
/// listing; the comments below describe only the statements that are visible.
85 template <
typename ReshapeOpTy,
typename InverseReshapeOpTy>
// Source and result types are identical: fold to the source value.
89 if (reshapeOp.getSrcType() == reshapeOp.getType())
90 return reshapeOp.getSrc();
// Producer of the source when it is an `InverseReshapeOpTy`; presumably the
// initializer of `reshapeSrcOp`, which is null-checked below.
95 reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
// The producer's source type equals this op's result type, so the two
// reshapes cancel: fold to the producer's source.
96 if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
97 return reshapeSrcOp.getSrc();
// First operand attribute is a dense-elements constant (dyn_cast_or_null also
// accepts a null attribute): fold by reshaping it to the result type.
100 if (
auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
101 return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
/// Verifier body shared by reshape-like ops: checks the expanded/collapsed
/// types of `op` against its reassociation and emits an op error on the first
/// violation found.
/// NOTE(review): several original lines are missing from this listing (the
/// function name line, a loop header, the validity check); comments describe
/// only what is visible.
108 template <
typename Op,
typename T>
110 T collapsedType,
bool isExpansion) {
112 unsigned expandedRank = expandedType.getRank();
113 unsigned collapsedRank = collapsedType.getRank();
// The expanded type must have at least the rank of the collapsed type.
114 if (expandedRank < collapsedRank)
115 return op.
emitOpError(
"expected the expanded type, ")
116 << expandedType <<
" to have a higher (or same) rank "
117 <<
"than the collapsed type, " << collapsedType <<
'.';
// The number of reassociation entries must equal the collapsed rank.
// (The start of the diagnostic text is on a line not visible here.)
119 if (collapsedRank != op.getReassociation().size())
121 << collapsedRank <<
") to equal the number of reassociation maps ("
122 << op.getReassociation().size() <<
").";
124 auto maps = op.getReassociationMaps();
// Every reassociation map must have exactly `expandedRank` dims. `it`
// exposes index()/value(); presumably an enumeration over `maps` whose loop
// header is not visible here.
126 if (it.value().getNumDims() != expandedRank)
127 return op.
emitOpError(
"expected reassociation map #")
128 << it.index() <<
" to have size equal to the expanded rank ("
129 << expandedRank <<
"), but it is " << it.value().getNumDims()
// Reports the map at `invalidIdx`; the check that sets `invalidIdx` is on
// lines not visible here.
134 return op.
emitOpError(
"expected reassociation map #")
135 << invalidIdx <<
" to be valid and contiguous.";
// Trailing arguments of the final shape-compatibility call (callee name not
// visible; presumably reshapeLikeShapesAreCompatible): an error callback that
// forwards the message to op->emitOpError, both shapes, the reassociation
// indices, and the expansion flag.
138 [&](
const Twine &msg) {
return op->
emitOpError(msg); },
139 collapsedType.getShape(), expandedType.getShape(),
140 op.getReassociationIndices(), isExpansion);
153 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
154 ArrayRef<ReassociationIndices> reassociationMaps,
bool isExpandingReshape);
/// Rewrite fragment for a `ReshapeOpTy` whose source is produced by another
/// `ReshapeOpTy`.
/// NOTE(review): the enclosing declaration and most statements are missing
/// from this listing; comments describe only the visible lines.
161 template <
typename ReshapeOpTy>
// Producer of the source, if it is the same kind of reshape op.
167 reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
171 ShapedType resultType = reshapeOp.getResultType();
// Combined reassociation; an empty optional means it could not be computed
// (the callee name is on a line not visible here).
178 std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
180 reshapeOp.getReassociationIndices(),
182 if (!reassociationIndices)
// Arguments of the replacement: the new op reads directly from the
// producer's source and uses the combined reassociation.
185 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
/// Rewrite fragment for a `CollapseOpTy` whose source is produced by an
/// `ExpandOpTy`: builds a single composed reassociation between the expand's
/// source type and the collapse's result type.
/// NOTE(review): the enclosing declaration and several statements (including
/// the bodies of some `if`s) are missing from this listing; comments describe
/// only the visible lines.
218 template <
typename CollapseOpTy,
typename ExpandOpTy,
typename CastOpTy>
223 auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
// End-to-end types of the pair: input of the expand, output of the collapse.
227 ShapedType srcType = expandOp.getSrcType();
228 ShapedType resultType = collapseOp.getResultType();
235 int64_t srcRank = srcType.getRank();
236 int64_t resultRank = resultType.getRank();
237 if (srcType == resultType)
241 lowerRankReassociation;
// Pick which op's reassociation plays the higher-rank role: the expand's when
// the source rank is larger, otherwise the collapse's.
243 if (srcRank > resultRank) {
244 higherRankReassociation = expandOp.getReassociationIndices();
245 lowerRankReassociation = collapseOp.getReassociationIndices();
247 higherRankReassociation = collapseOp.getReassociationIndices();
248 lowerRankReassociation = expandOp.getReassociationIndices();
// Cursor into higherRankReassociation; it only moves forward across
// iterations of the outer loop.
251 size_t higherRankIndicesID = 0;
253 for (
const auto &lowerRankIndices : lowerRankReassociation) {
// Consume higher-rank groups, comparing the last index of each group with
// the last index of the current lower-rank group.
255 while (higherRankIndicesID < higherRankReassociation.size()) {
256 auto rightmostIndex =
257 higherRankReassociation[higherRankIndicesID].back();
// Higher-rank group extends past the lower-rank group (the handling of this
// case is on a line not visible here).
258 if (rightmostIndex > lowerRankIndices.back())
260 composedIndices.push_back(higherRankIndicesID++);
// Group boundaries coincide (the statement taken here is not visible).
261 if (rightmostIndex == lowerRankIndices.back())
264 composedReassociation.push_back(composedIndices);
// Replacement depends on how the ranks compare; all visible branches read
// from the expand's source. The created op kinds are on lines not visible
// here.
266 if (srcRank > resultRank) {
268 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
269 }
else if (srcRank < resultRank) {
271 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
// Equal ranks: the shapes are asserted to match.
275 assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
276 "expected same shape");
/// Rewrite fragment for an `ExpandOpTy` whose source is produced by a
/// `CollapseOpTy`: tries to replace the pair with a single op using a composed
/// reassociation from findCollapsingReassociation.
/// NOTE(review): the enclosing declaration and several statements are missing
/// from this listing; comments describe only the visible lines.
284 template <
typename ExpandOpTy,
typename CollapseOpTy>
289 auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
// End-to-end types of the pair: input of the collapse, output of the expand.
293 ShapedType srcType = collapseOp.getSrcType();
294 ShapedType resultType = expandOp.getResultType();
301 int64_t srcRank = srcType.getRank();
302 int64_t resultRank = resultType.getRank();
303 if (srcType == resultType)
306 auto srcReassociation = collapseOp.getReassociationIndices();
307 auto resultReassociation = expandOp.getReassociationIndices();
// Source rank is larger: pass the source side first.
308 if (srcRank > resultRank) {
309 auto composedReassociation = findCollapsingReassociation(
310 srcReassociation, resultReassociation, srcType.getShape(),
311 resultType.getShape());
// Empty optional: no composed reassociation was found (handling not visible).
312 if (!composedReassociation)
316 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
// Otherwise: same query with the roles swapped (result side passed first).
319 auto composedReassociation =
320 findCollapsingReassociation(resultReassociation, srcReassociation,
321 resultType.getShape(), srcType.getShape());
322 if (!composedReassociation)
326 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
/// Builds a reassociation by walking `srcReassociation` and
/// `resultReassociation` pairwise and comparing the sub-shapes each pair of
/// groups covers. Returns an optional so that failure can be reported.
/// NOTE(review): the parameter list and several statements are missing from
/// this listing; comments describe only the visible lines.
333 std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
// Empty source reassociation is special-cased (the handling is not visible).
339 if (srcReassociation.empty())
342 for (
auto item : llvm::zip(srcReassociation, resultReassociation)) {
343 auto &srcIndices = std::get<0>(item);
344 auto &resultIndices = std::get<1>(item);
// Slice of each shape covered by the current group: starts at the group's
// first index and spans as many dims as the group has indices.
345 auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
346 auto resultSubShape =
347 resultShape.slice(resultIndices.front(), resultIndices.size());
// Same number of dims on both sides: identical sub-shapes keep the source
// group as is.
349 if (srcSubShape.size() == resultSubShape.size()) {
350 if (srcSubShape == resultSubShape)
351 composedReassociation.push_back(srcIndices);
// Reassociation between the two sub-shapes (the callee is on a line not
// visible here); an empty optional is checked next.
357 auto subShapeReassociation =
359 if (!subShapeReassociation)
// Sub-shape indices are relative to the slice; shift them by the group's
// first index before appending.
363 for (
auto &subshape_indices : *subShapeReassociation) {
365 for (int64_t index : subshape_indices)
366 shape_indices.push_back(srcIndices.front() + index);
367 composedReassociation.push_back(shape_indices);
370 return {std::move(composedReassociation)};
380 ArrayRef<Range> sliceParams);
/// Helper that bundles a collapse-shape reassociation, the shapes on both
/// sides of the collapse, and the parameters of a slice, and exposes methods
/// returning slice parameters as `SmallVector<Range>`.
/// NOTE(review): access specifiers, part of the constructor's initializer list
/// and the closing brace are missing from this listing.
417 class SliceFromCollapseHelper {
// The constructor copies each ArrayRef argument into an owning SmallVector
// member of the same name (`extractSliceParams` is stored as `sliceParams`).
419 SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
420 ArrayRef<OpFoldResult> collapseShapeInputShape,
421 ArrayRef<OpFoldResult> collapseShapeOutputShape,
422 ArrayRef<Range> extractSliceParams)
423 : reassociationIndices(reassociationIndices),
424 collapseShapeInputShape(collapseShapeInputShape),
425 collapseShapeOutputShape(collapseShapeOutputShape),
426 sliceParams(extractSliceParams),
// Tail of a further member initializer that takes `extractSliceParams`; its
// beginning is on lines not visible here.
429 extractSliceParams)) {}
// Declaration only: returns slice ranges computed from one ValueRange of
// indices per entry of `multiIndices`.
441 SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
442 ArrayRef<ValueRange> multiIndices);
// Declaration only: returns slice ranges computed from `tileIndices`.
449 SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
450 ValueRange tileIndices);
// Owned copies of the constructor inputs.
453 SmallVector<ReassociationIndices> reassociationIndices;
454 SmallVector<OpFoldResult> collapseShapeInputShape;
455 SmallVector<OpFoldResult> collapseShapeOutputShape;
456 SmallVector<Range> sliceParams;
// Per-dimension bit flags; presumably filled by the initializers not visible
// above (compare getLinearizedDimensions / getSlicedDimensions) — confirm.
457 llvm::SmallBitVector linearizedDimensions;
458 llvm::SmallBitVector slicedDimensions;
463 struct CollapseShapeRankReducingSliceSimplificationInfo {
468 std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
/// Declaration only (defined elsewhere). Takes the ranked tensor type of the
/// source and a list of reassociation indices; returns a
/// CollapseShapeRankReducingSliceSimplificationInfo on success, or failure.
507 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
508 getSimplifyCollapseShapeWithRankReducingSliceInfo(
509 RankedTensorType sourceType,
510 ArrayRef<ReassociationIndices> reassociationIndices);
/// Plain aggregate returned by computePackingMetadata.
/// NOTE(review): field meanings are not established by this listing; the notes
/// below are assumptions from the names and should be confirmed against the
/// definition of computePackingMetadata.
512 struct PackingMetadata {
// Presumably positions at which dimensions are inserted — TODO confirm.
513 SmallVector<int64_t> insertPositions;
// Presumably positions of the corresponding outer dimensions — TODO confirm.
514 SmallVector<int64_t> outerPositions;
// Groups of dimension indices, in the same form used by reshape ops.
515 SmallVector<ReassociationIndices> reassociations;
/// Declaration only (defined elsewhere). Computes a PackingMetadata, returned
/// by value, from the rank `packedRank` and the list of positions
/// `innerDimPos`.
524 PackingMetadata computePackingMetadata(int64_t packedRank,
525 ArrayRef<int64_t> innerDimPos);
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
MLIRContext * getContext() const
This class represents a single result from folding an operation.
This provides public APIs that all operations should have.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
llvm::function_ref< Fn > function_ref
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify the shapes of the reshaped types using the following rules: 1) if a dimension in the collapsed typ...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape, given the source type and the target type, when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayRef< int64_t > ReassociationIndicesRef
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in a pair of reshape ops where one is a producer and the other is ...
SmallVector< int64_t, 2 > ReassociationIndices
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(OpBuilder &b, ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, PatternRewriter &rewriter) const override
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...