14 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
21 #include "llvm/ADT/StringRef.h"
46 ArrayRef<ReassociationIndices> producerReassociations,
47 ArrayRef<ReassociationIndices> consumerReassociations,
48 MLIRContext *context);
52 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
55 SmallVector<AffineMap, 4>
61 ArrayRef<ReassociationIndices> reassociation);
65 ArrayRef<ReassociationExprs> reassociationExprs);
70 std::optional<SmallVector<ReassociationIndices>>
75 std::optional<SmallVector<ReassociationIndices>>
77 ArrayRef<int64_t> targetShape);
83 int *invalidIndex =
nullptr);
85 template <
typename ReshapeOpTy,
typename InverseReshapeOpTy>
89 if (reshapeOp.getSrcType() == reshapeOp.getType())
90 return reshapeOp.getSrc();
93 if (
auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
94 return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
99 reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
102 auto srcType = reshapeSrcOp.getSrcType();
103 auto resultType = reshapeOp.getResultType();
104 if (srcType != resultType)
107 if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
108 return reshapeSrcOp.getSrc();
117 auto reassociations = reshapeOp.getReassociationIndices();
118 if (reassociations != reshapeSrcOp.getReassociationIndices())
122 if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
123 return reshapeSrcOp.getSrc();
124 if (llvm::all_of(reassociations, [&](
auto reInd) {
126 srcType.getShape().slice(reInd.front(), reInd.size());
127 return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
129 return reshapeSrcOp.getSrc();
136 template <
typename Op,
typename T>
138 T collapsedType,
bool isExpansion) {
140 unsigned expandedRank = expandedType.getRank();
141 unsigned collapsedRank = collapsedType.getRank();
142 if (expandedRank < collapsedRank)
143 return op.
emitOpError(
"expected the expanded type, ")
144 << expandedType <<
" to have a higher (or same) rank "
145 <<
"than the collapsed type, " << collapsedType <<
'.';
147 if (collapsedRank != op.getReassociation().size())
149 << collapsedRank <<
") to equal the number of reassociation maps ("
150 << op.getReassociation().size() <<
").";
152 auto maps = op.getReassociationMaps();
154 if (it.value().getNumDims() != expandedRank)
155 return op.
emitOpError(
"expected reassociation map #")
156 << it.index() <<
" to have size equal to the expanded rank ("
157 << expandedRank <<
"), but it is " << it.value().getNumDims()
162 return op.
emitOpError(
"expected reassociation map #")
163 << invalidIdx <<
" to be valid and contiguous.";
166 [&](
const Twine &msg) {
return op->
emitOpError(msg); },
167 collapsedType.getShape(), expandedType.getShape(),
168 op.getReassociationIndices(), isExpansion);
178 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
179 ArrayRef<ReassociationIndices> reassociationMaps,
bool isExpandingReshape);
188 template <
typename ReshapeOpTy, ReshapeOpKind opKind>
194 reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
198 ShapedType resultType = reshapeOp.getResultType();
205 std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
207 reshapeOp.getReassociationIndices(),
209 if (!reassociationIndices)
215 reshapeOp.getOutputShape(), rewriter));
217 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
221 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
255 template <
typename CollapseOpTy,
typename ExpandOpTy,
typename CastOpTy,
256 typename DimOpTy,
typename TensorTy>
261 auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
265 ShapedType srcType = expandOp.getSrcType();
266 ShapedType resultType = collapseOp.getResultType();
273 int64_t srcRank = srcType.getRank();
274 int64_t resultRank = resultType.getRank();
275 if (srcType == resultType)
279 lowerRankReassociation;
281 if (srcRank > resultRank) {
282 higherRankReassociation = expandOp.getReassociationIndices();
283 lowerRankReassociation = collapseOp.getReassociationIndices();
285 higherRankReassociation = collapseOp.getReassociationIndices();
286 lowerRankReassociation = expandOp.getReassociationIndices();
289 size_t higherRankIndicesID = 0;
291 for (
const auto &lowerRankIndices : lowerRankReassociation) {
293 while (higherRankIndicesID < higherRankReassociation.size()) {
294 auto rightmostIndex =
295 higherRankReassociation[higherRankIndicesID].back();
296 if (rightmostIndex > lowerRankIndices.back())
298 composedIndices.push_back(higherRankIndicesID++);
299 if (rightmostIndex == lowerRankIndices.back())
302 composedReassociation.push_back(composedIndices);
304 if (srcRank > resultRank) {
306 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
307 }
else if (srcRank < resultRank) {
309 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
313 assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
314 "expected same shape");
322 template <
typename ExpandOpTy,
typename CollapseOpTy>
327 auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
331 ShapedType srcType = collapseOp.getSrcType();
332 ShapedType resultType = expandOp.getResultType();
339 int64_t srcRank = srcType.getRank();
340 int64_t resultRank = resultType.getRank();
341 if (srcRank == resultRank)
344 auto srcReassociation = collapseOp.getReassociationIndices();
345 auto resultReassociation = expandOp.getReassociationIndices();
346 if (srcRank > resultRank) {
347 auto composedReassociation = findCollapsingReassociation(
348 srcReassociation, resultReassociation, srcType.getShape(),
349 resultType.getShape());
350 if (!composedReassociation)
354 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
357 auto composedReassociation =
358 findCollapsingReassociation(resultReassociation, srcReassociation,
359 resultType.getShape(), srcType.getShape());
360 if (!composedReassociation)
364 expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
366 expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
374 std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
380 if (srcReassociation.empty())
383 for (
auto item : llvm::zip(srcReassociation, resultReassociation)) {
384 auto &srcIndices = std::get<0>(item);
385 auto &resultIndices = std::get<1>(item);
386 auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
387 auto resultSubShape =
388 resultShape.slice(resultIndices.front(), resultIndices.size());
390 if (srcSubShape.size() == resultSubShape.size()) {
391 if (srcSubShape != resultSubShape ||
392 llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
395 for (
auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
396 composedReassociation.emplace_back(1, srcIndices.front() + index);
402 auto subShapeReassociation =
404 if (!subShapeReassociation)
408 for (
auto &subshapeIndices : *subShapeReassociation) {
410 for (int64_t index : subshapeIndices)
411 shapeIndices.push_back(srcIndices.front() + index);
412 composedReassociation.push_back(shapeIndices);
415 return {std::move(composedReassociation)};
425 ArrayRef<Range> sliceParams);
462 class SliceFromCollapseHelper {
464 SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
465 ArrayRef<OpFoldResult> collapseShapeInputShape,
466 ArrayRef<OpFoldResult> collapseShapeOutputShape,
467 ArrayRef<Range> extractSliceParams)
468 : reassociationIndices(reassociationIndices),
469 collapseShapeInputShape(collapseShapeInputShape),
470 collapseShapeOutputShape(collapseShapeOutputShape),
471 sliceParams(extractSliceParams),
474 extractSliceParams)) {}
486 SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
487 ArrayRef<ValueRange> multiIndices);
494 SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
495 ValueRange tileIndices);
498 SmallVector<ReassociationIndices> reassociationIndices;
499 SmallVector<OpFoldResult> collapseShapeInputShape;
500 SmallVector<OpFoldResult> collapseShapeOutputShape;
501 SmallVector<Range> sliceParams;
502 llvm::SmallBitVector linearizedDimensions;
503 llvm::SmallBitVector slicedDimensions;
508 struct CollapseShapeRankReducingSliceSimplificationInfo {
513 std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
552 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
553 getSimplifyCollapseShapeWithRankReducingSliceInfo(
554 RankedTensorType sourceType,
555 ArrayRef<ReassociationIndices> reassociationIndices);
557 struct PackingMetadata {
558 SmallVector<int64_t> insertPositions;
559 SmallVector<int64_t> outerPositions;
560 SmallVector<ReassociationIndices> reassociations;
569 PackingMetadata computePackingMetadata(int64_t packedRank,
570 ArrayRef<int64_t> innerDimPos);
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
MLIRContext * getContext() const
This class represents a single result from folding an operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayRef< int64_t > ReassociationIndicesRef
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
SmallVector< int64_t, 2 > ReassociationIndices
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...