14#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
22#include "llvm/ADT/StringRef.h"
47 ArrayRef<ReassociationIndices> producerReassociations,
48 ArrayRef<ReassociationIndices> consumerReassociations,
49 MLIRContext *context);
53 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
56SmallVector<AffineMap, 4>
62 ArrayRef<ReassociationIndices> reassociation);
66 ArrayRef<ReassociationExprs> reassociationExprs);
71std::optional<SmallVector<ReassociationIndices>>
76std::optional<SmallVector<ReassociationIndices>>
78 ArrayRef<int64_t> targetShape);
84 int *invalidIndex =
nullptr);
86template <
typename ReshapeOpTy,
typename InverseReshapeOpTy>
90 if (reshapeOp.getSrcType() == reshapeOp.getType())
91 return reshapeOp.getSrc();
94 if (
auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
95 return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
100 reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
103 auto srcType = reshapeSrcOp.getSrcType();
104 auto resultType = reshapeOp.getResultType();
105 if (srcType != resultType)
108 if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
109 return reshapeSrcOp.getSrc();
118 auto reassociations = reshapeOp.getReassociationIndices();
119 if (reassociations != reshapeSrcOp.getReassociationIndices())
123 if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
124 return reshapeSrcOp.getSrc();
125 if (llvm::all_of(reassociations, [&](
auto reInd) {
127 srcType.getShape().slice(reInd.front(), reInd.size());
128 return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
130 return reshapeSrcOp.getSrc();
137template <
typename Op,
typename T>
138static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
139 T collapsedType,
bool isExpansion) {
141 unsigned expandedRank = expandedType.getRank();
142 unsigned collapsedRank = collapsedType.getRank();
143 if (expandedRank < collapsedRank)
144 return op.emitOpError(
"expected the expanded type, ")
145 << expandedType <<
" to have a higher (or same) rank "
146 <<
"than the collapsed type, " << collapsedType <<
'.';
148 if (collapsedRank != op.getReassociation().size())
149 return op.emitOpError(
"expected collapsed rank (")
150 << collapsedRank <<
") to equal the number of reassociation maps ("
151 << op.getReassociation().size() <<
").";
153 auto maps = op.getReassociationMaps();
154 for (
auto it : llvm::enumerate(maps))
155 if (it.value().getNumDims() != expandedRank)
156 return op.emitOpError(
"expected reassociation map #")
157 << it.index() <<
" to have size equal to the expanded rank ("
158 << expandedRank <<
"), but it is " << it.value().getNumDims()
163 return op.emitOpError(
"expected reassociation map #")
164 << invalidIdx <<
" to be valid and contiguous.";
166 return reshapeLikeShapesAreCompatible(
167 [&](
const Twine &msg) {
return op->emitOpError(msg); },
168 collapsedType.getShape(), expandedType.getShape(),
169 op.getReassociationIndices(), isExpansion);
177LogicalResult reshapeLikeShapesAreCompatible(
183bool hasNonIdentityLayout(
Type type);
185enum class ReshapeOpKind { kExpand, kCollapse };
189template <
typename ReshapeOpTy, ReshapeOpKind opKind>
190struct ComposeReassociativeReshapeOps :
public OpRewritePattern<ReshapeOpTy> {
191 using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
192 LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
193 PatternRewriter &rewriter)
const override {
195 reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
199 ShapedType resultType = reshapeOp.getResultType();
201 if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
202 hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
203 hasNonIdentityLayout(reshapeOp.getResult().getType()))
206 std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
208 reshapeOp.getReassociationIndices(),
209 rewriter.getContext());
210 if (!reassociationIndices)
213 if constexpr (opKind == ReshapeOpKind::kExpand) {
214 SmallVector<OpFoldResult> outputShape(
216 reshapeOp.getOutputShape(), rewriter));
217 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
218 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
221 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
222 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
256template <
typename CollapseOpTy,
typename ExpandOpTy,
typename CastOpTy,
257 typename DimOpTy,
typename TensorTy>
262 auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
266 ShapedType srcType = expandOp.getSrcType();
267 ShapedType resultType = collapseOp.getResultType();
269 if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
270 hasNonIdentityLayout(expandOp.getSrc().getType()) ||
271 hasNonIdentityLayout(expandOp.getResult().getType()))
274 int64_t srcRank = srcType.getRank();
275 int64_t resultRank = resultType.getRank();
276 if (srcType == resultType)
280 lowerRankReassociation;
282 if (srcRank > resultRank) {
283 higherRankReassociation = expandOp.getReassociationIndices();
284 lowerRankReassociation = collapseOp.getReassociationIndices();
286 higherRankReassociation = collapseOp.getReassociationIndices();
287 lowerRankReassociation = expandOp.getReassociationIndices();
290 size_t higherRankIndicesID = 0;
292 for (
const auto &lowerRankIndices : lowerRankReassociation) {
294 while (higherRankIndicesID < higherRankReassociation.size()) {
295 auto rightmostIndex =
296 higherRankReassociation[higherRankIndicesID].back();
297 if (rightmostIndex > lowerRankIndices.back())
299 composedIndices.push_back(higherRankIndicesID++);
300 if (rightmostIndex == lowerRankIndices.back())
303 composedReassociation.push_back(composedIndices);
305 if (srcRank > resultRank) {
307 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
308 }
else if (srcRank < resultRank) {
312 expandOp.getMixedOutputShape();
315 collapseOp.getReassociationIndices()) {
321 numStaticElems *= maybeCst.value();
324 dynamicSizes.push_back(cast<Value>(size));
326 if (dynamicSizes.empty()) {
327 newOutputShape.push_back(rewriter.
getIndexAttr(numStaticElems));
334 for (
Value v : llvm::drop_begin(dynamicSizes))
335 result = arith::MulIOp::create(rewriter, loc,
result, v,
336 arith::IntegerOverflowFlags::nsw);
337 if (numStaticElems != 1) {
338 result = arith::MulIOp::create(
341 arith::IntegerOverflowFlags::nsw);
343 newOutputShape.push_back(
result);
346 collapseOp, resultType, expandOp.getSrc(), composedReassociation,
351 assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
352 "expected same shape");
360template <
typename ExpandOpTy,
typename CollapseOpTy,
typename CastOpTy>
365 auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
369 ShapedType srcType = collapseOp.getSrcType();
370 ShapedType resultType = expandOp.getResultType();
372 if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
373 hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
374 hasNonIdentityLayout(collapseOp.getResult().getType())) {
375 if (srcType.hasStaticShape() &&
376 CastOpTy::areCastCompatible(srcType, resultType)) {
378 collapseOp.getSrc());
384 int64_t srcRank = srcType.getRank();
385 int64_t resultRank = resultType.getRank();
386 if (srcRank == resultRank)
389 auto srcReassociation = collapseOp.getReassociationIndices();
390 auto resultReassociation = expandOp.getReassociationIndices();
391 if (srcRank > resultRank) {
392 auto composedReassociation = findCollapsingReassociation(
393 srcReassociation, resultReassociation, srcType.getShape(),
394 resultType.getShape());
395 if (!composedReassociation)
399 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
402 auto composedReassociation =
403 findCollapsingReassociation(resultReassociation, srcReassociation,
404 resultType.getShape(), srcType.getShape());
405 if (!composedReassociation)
409 expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
411 expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
419 std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
425 if (srcReassociation.empty())
428 for (
auto item : llvm::zip(srcReassociation, resultReassociation)) {
429 auto &srcIndices = std::get<0>(item);
430 auto &resultIndices = std::get<1>(item);
431 auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
432 auto resultSubShape =
433 resultShape.slice(resultIndices.front(), resultIndices.size());
435 if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 &&
436 llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2)
439 if (srcSubShape.size() == resultSubShape.size()) {
440 if (srcSubShape != resultSubShape)
443 for (
auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
444 composedReassociation.emplace_back(1, srcIndices.front() + index);
450 auto subShapeReassociation =
452 if (!subShapeReassociation)
456 for (
auto &subshapeIndices : *subShapeReassociation) {
458 for (int64_t index : subshapeIndices)
459 shapeIndices.push_back(srcIndices.front() + index);
460 composedReassociation.push_back(shapeIndices);
463 return {std::move(composedReassociation)};
510class SliceFromCollapseHelper {
512 SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
513 ArrayRef<OpFoldResult> collapseShapeInputShape,
514 ArrayRef<OpFoldResult> collapseShapeOutputShape,
515 ArrayRef<Range> extractSliceParams)
516 : reassociationIndices(reassociationIndices),
517 collapseShapeInputShape(collapseShapeInputShape),
518 collapseShapeOutputShape(collapseShapeOutputShape),
519 sliceParams(extractSliceParams),
522 extractSliceParams)) {}
534 SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
535 ArrayRef<ValueRange> multiIndices);
542 SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
546 SmallVector<ReassociationIndices> reassociationIndices;
547 SmallVector<OpFoldResult> collapseShapeInputShape;
548 SmallVector<OpFoldResult> collapseShapeOutputShape;
549 SmallVector<Range> sliceParams;
550 llvm::SmallBitVector linearizedDimensions;
551 llvm::SmallBitVector slicedDimensions;
556struct CollapseShapeRankReducingSliceSimplificationInfo {
561 std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
600FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
601getSimplifyCollapseShapeWithRankReducingSliceInfo(
602 RankedTensorType sourceType,
605struct PackingMetadata {
606 SmallVector<int64_t> insertPositions;
607 SmallVector<int64_t> outerPositions;
608 SmallVector<ReassociationIndices> reassociations;
617PackingMetadata computePackingMetadata(int64_t packedRank,
625 std::optional<Attribute> cst = std::nullopt);
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
ArrayRef< int64_t > ReassociationIndicesRef
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
SmallVector< AffineExpr, 2 > ReassociationExprs
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
SmallVector< int64_t, 2 > ReassociationIndices
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Common verifier for reshape-like types.
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})