14#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
22#include "llvm/ADT/StringRef.h"
47 ArrayRef<ReassociationIndices> producerReassociations,
48 ArrayRef<ReassociationIndices> consumerReassociations,
49 MLIRContext *context);
53 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
56SmallVector<AffineMap, 4>
62 ArrayRef<ReassociationIndices> reassociation);
66 ArrayRef<ReassociationExprs> reassociationExprs);
71std::optional<SmallVector<ReassociationIndices>>
76std::optional<SmallVector<ReassociationIndices>>
78 ArrayRef<int64_t> targetShape);
84 int *invalidIndex =
nullptr);
86template <
typename ReshapeOpTy,
typename InverseReshapeOpTy>
90 if (reshapeOp.getSrcType() == reshapeOp.getType())
91 return reshapeOp.getSrc();
96 if (
auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front())) {
97 auto resultType = cast<ShapedType>(reshapeOp.getResult().getType());
98 if (resultType.hasStaticShape())
99 return elements.reshape(resultType);
105 reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
108 auto srcType = reshapeSrcOp.getSrcType();
109 auto resultType = reshapeOp.getResultType();
110 if (srcType != resultType)
113 if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
114 return reshapeSrcOp.getSrc();
123 auto reassociations = reshapeOp.getReassociationIndices();
124 if (reassociations != reshapeSrcOp.getReassociationIndices())
128 if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
129 return reshapeSrcOp.getSrc();
130 if (llvm::all_of(reassociations, [&](
auto reInd) {
132 srcType.getShape().slice(reInd.front(), reInd.size());
133 return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
135 return reshapeSrcOp.getSrc();
142template <
typename Op,
typename T>
143static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
144 T collapsedType,
bool isExpansion) {
146 unsigned expandedRank = expandedType.getRank();
147 unsigned collapsedRank = collapsedType.getRank();
148 if (expandedRank < collapsedRank)
149 return op.emitOpError(
"expected the expanded type, ")
150 << expandedType <<
" to have a higher (or same) rank "
151 <<
"than the collapsed type, " << collapsedType <<
'.';
153 if (collapsedRank != op.getReassociation().size())
154 return op.emitOpError(
"expected collapsed rank (")
155 << collapsedRank <<
") to equal the number of reassociation maps ("
156 << op.getReassociation().size() <<
").";
158 auto maps = op.getReassociationMaps();
159 for (
auto it : llvm::enumerate(maps))
160 if (it.value().getNumDims() != expandedRank)
161 return op.emitOpError(
"expected reassociation map #")
162 << it.index() <<
" to have size equal to the expanded rank ("
163 << expandedRank <<
"), but it is " << it.value().getNumDims()
168 return op.emitOpError(
"expected reassociation map #")
169 << invalidIdx <<
" to be valid and contiguous.";
171 return reshapeLikeShapesAreCompatible(
172 [&](
const Twine &msg) {
return op->emitOpError(msg); },
173 collapsedType.getShape(), expandedType.getShape(),
174 op.getReassociationIndices(), isExpansion);
182LogicalResult reshapeLikeShapesAreCompatible(
/// Presumably reports whether `type` has a non-identity layout (inferred from
/// the name only; TODO confirm against the out-of-line definition).
/// The reshape-composition patterns in this header query it on the source and
/// result types of the ops they compose before rewriting them.
188bool hasNonIdentityLayout(
Type type);
/// Compile-time selector for ComposeReassociativeReshapeOps.
/// - kExpand: the rewrite also gathers the consumer op's output shape
///   (getOutputShape) when rebuilding the composed op.
/// - kCollapse: the composed op is rebuilt from the producer's source and the
///   composed reassociation indices only.
190enum class ReshapeOpKind { kExpand, kCollapse };
194template <
typename ReshapeOpTy, ReshapeOpKind opKind>
195struct ComposeReassociativeReshapeOps :
public OpRewritePattern<ReshapeOpTy> {
196 using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
197 LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
198 PatternRewriter &rewriter)
const override {
200 reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
204 ShapedType resultType = reshapeOp.getResultType();
206 if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
207 hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
208 hasNonIdentityLayout(reshapeOp.getResult().getType()))
211 std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
213 reshapeOp.getReassociationIndices(),
214 rewriter.getContext());
215 if (!reassociationIndices)
218 if constexpr (opKind == ReshapeOpKind::kExpand) {
219 SmallVector<OpFoldResult> outputShape(
221 reshapeOp.getOutputShape(), rewriter));
222 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
223 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
226 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
227 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
261template <
typename CollapseOpTy,
typename ExpandOpTy,
typename CastOpTy,
262 typename DimOpTy,
typename TensorTy>
267 auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
271 ShapedType srcType = expandOp.getSrcType();
272 ShapedType resultType = collapseOp.getResultType();
274 if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
275 hasNonIdentityLayout(expandOp.getSrc().getType()) ||
276 hasNonIdentityLayout(expandOp.getResult().getType()))
279 int64_t srcRank = srcType.getRank();
280 int64_t resultRank = resultType.getRank();
281 if (srcType == resultType)
285 lowerRankReassociation;
287 if (srcRank > resultRank) {
288 higherRankReassociation = expandOp.getReassociationIndices();
289 lowerRankReassociation = collapseOp.getReassociationIndices();
291 higherRankReassociation = collapseOp.getReassociationIndices();
292 lowerRankReassociation = expandOp.getReassociationIndices();
295 size_t higherRankIndicesID = 0;
297 for (
const auto &lowerRankIndices : lowerRankReassociation) {
299 while (higherRankIndicesID < higherRankReassociation.size()) {
300 auto rightmostIndex =
301 higherRankReassociation[higherRankIndicesID].back();
302 if (rightmostIndex > lowerRankIndices.back())
304 composedIndices.push_back(higherRankIndicesID++);
305 if (rightmostIndex == lowerRankIndices.back())
308 composedReassociation.push_back(composedIndices);
310 if (srcRank > resultRank) {
312 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
313 }
else if (srcRank < resultRank) {
317 expandOp.getMixedOutputShape();
320 collapseOp.getReassociationIndices()) {
326 numStaticElems *= maybeCst.value();
329 dynamicSizes.push_back(cast<Value>(size));
331 if (dynamicSizes.empty()) {
332 newOutputShape.push_back(rewriter.
getIndexAttr(numStaticElems));
339 for (
Value v : llvm::drop_begin(dynamicSizes))
340 result = arith::MulIOp::create(rewriter, loc,
result, v,
341 arith::IntegerOverflowFlags::nsw);
342 if (numStaticElems != 1) {
343 result = arith::MulIOp::create(
346 arith::IntegerOverflowFlags::nsw);
348 newOutputShape.push_back(
result);
351 collapseOp, resultType, expandOp.getSrc(), composedReassociation,
356 assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
357 "expected same shape");
365template <
typename ExpandOpTy,
typename CollapseOpTy,
typename CastOpTy>
370 auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
374 ShapedType srcType = collapseOp.getSrcType();
375 ShapedType resultType = expandOp.getResultType();
377 if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
378 hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
379 hasNonIdentityLayout(collapseOp.getResult().getType())) {
380 if (srcType.hasStaticShape() &&
381 CastOpTy::areCastCompatible(srcType, resultType)) {
383 collapseOp.getSrc());
389 int64_t srcRank = srcType.getRank();
390 int64_t resultRank = resultType.getRank();
391 if (srcRank == resultRank)
394 auto srcReassociation = collapseOp.getReassociationIndices();
395 auto resultReassociation = expandOp.getReassociationIndices();
396 if (srcRank > resultRank) {
397 auto composedReassociation = findCollapsingReassociation(
398 srcReassociation, resultReassociation, srcType.getShape(),
399 resultType.getShape());
400 if (!composedReassociation)
404 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
407 auto composedReassociation =
408 findCollapsingReassociation(resultReassociation, srcReassociation,
409 resultType.getShape(), srcType.getShape());
410 if (!composedReassociation)
414 expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
416 expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
424 std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
430 if (srcReassociation.empty())
433 for (
auto item : llvm::zip(srcReassociation, resultReassociation)) {
434 auto &srcIndices = std::get<0>(item);
435 auto &resultIndices = std::get<1>(item);
436 auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
437 auto resultSubShape =
438 resultShape.slice(resultIndices.front(), resultIndices.size());
440 if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 &&
441 llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2)
444 if (srcSubShape.size() == resultSubShape.size()) {
445 if (srcSubShape != resultSubShape)
448 for (
auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
449 composedReassociation.emplace_back(1, srcIndices.front() + index);
455 auto subShapeReassociation =
457 if (!subShapeReassociation)
461 for (
auto &subshapeIndices : *subShapeReassociation) {
463 for (int64_t index : subshapeIndices)
464 shapeIndices.push_back(srcIndices.front() + index);
465 composedReassociation.push_back(shapeIndices);
468 return {std::move(composedReassociation)};
515class SliceFromCollapseHelper {
517 SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
518 ArrayRef<OpFoldResult> collapseShapeInputShape,
519 ArrayRef<OpFoldResult> collapseShapeOutputShape,
520 ArrayRef<Range> extractSliceParams)
521 : reassociationIndices(reassociationIndices),
522 collapseShapeInputShape(collapseShapeInputShape),
523 collapseShapeOutputShape(collapseShapeOutputShape),
524 sliceParams(extractSliceParams),
527 extractSliceParams)) {}
539 SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
540 ArrayRef<ValueRange> multiIndices);
547 SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
551 SmallVector<ReassociationIndices> reassociationIndices;
552 SmallVector<OpFoldResult> collapseShapeInputShape;
553 SmallVector<OpFoldResult> collapseShapeOutputShape;
554 SmallVector<Range> sliceParams;
555 llvm::SmallBitVector linearizedDimensions;
556 llvm::SmallBitVector slicedDimensions;
561struct CollapseShapeRankReducingSliceSimplificationInfo {
566 std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
605FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
606getSimplifyCollapseShapeWithRankReducingSliceInfo(
607 RankedTensorType sourceType,
610struct PackingMetadata {
611 SmallVector<int64_t> insertPositions;
612 SmallVector<int64_t> outerPositions;
613 SmallVector<ReassociationIndices> reassociations;
622PackingMetadata computePackingMetadata(int64_t packedRank,
630 std::optional<Attribute> cst = std::nullopt);
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
ArrayRef< int64_t > ReassociationIndicesRef
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
SmallVector< AffineExpr, 2 > ReassociationExprs
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in a pair of reshape ops where one is a producer and the other is ...
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
SmallVector< int64_t, 2 > ReassociationIndices
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Common verifier for reshape-like types.
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})