14#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
22#include "llvm/ADT/StringRef.h"
47 ArrayRef<ReassociationIndices> producerReassociations,
48 ArrayRef<ReassociationIndices> consumerReassociations,
49 MLIRContext *context);
53 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
62 ArrayRef<ReassociationIndices> reassociation);
66 ArrayRef<ReassociationExprs> reassociationExprs);
71std::optional<SmallVector<ReassociationIndices>>
76std::optional<SmallVector<ReassociationIndices>>
78 ArrayRef<int64_t> targetShape);
84 int *invalidIndex =
nullptr);
86template <
typename ReshapeOpTy,
typename InverseReshapeOpTy>
90 if (reshapeOp.getSrcType() == reshapeOp.getType())
91 return reshapeOp.getSrc();
94 if (
auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
95 return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
100 reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
103 auto srcType = reshapeSrcOp.getSrcType();
104 auto resultType = reshapeOp.getResultType();
105 if (srcType != resultType)
108 if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
109 return reshapeSrcOp.getSrc();
118 auto reassociations = reshapeOp.getReassociationIndices();
119 if (reassociations != reshapeSrcOp.getReassociationIndices())
123 if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
124 return reshapeSrcOp.getSrc();
125 if (llvm::all_of(reassociations, [&](
auto reInd) {
127 srcType.getShape().slice(reInd.front(), reInd.size());
128 return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
130 return reshapeSrcOp.getSrc();
137template <
typename Op,
typename T>
138static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
139 T collapsedType,
bool isExpansion) {
141 unsigned expandedRank = expandedType.getRank();
142 unsigned collapsedRank = collapsedType.getRank();
143 if (expandedRank < collapsedRank)
144 return op.emitOpError(
"expected the expanded type, ")
145 << expandedType <<
" to have a higher (or same) rank "
146 <<
"than the collapsed type, " << collapsedType <<
'.';
148 if (collapsedRank != op.getReassociation().size())
149 return op.emitOpError(
"expected collapsed rank (")
150 << collapsedRank <<
") to equal the number of reassociation maps ("
151 << op.getReassociation().size() <<
").";
153 auto maps = op.getReassociationMaps();
154 for (
auto it : llvm::enumerate(maps))
155 if (it.value().getNumDims() != expandedRank)
156 return op.emitOpError(
"expected reassociation map #")
157 << it.index() <<
" to have size equal to the expanded rank ("
158 << expandedRank <<
"), but it is " << it.value().getNumDims()
163 return op.emitOpError(
"expected reassociation map #")
164 << invalidIdx <<
" to be valid and contiguous.";
166 return reshapeLikeShapesAreCompatible(
167 [&](
const Twine &msg) {
return op->emitOpError(msg); },
168 collapsedType.getShape(), expandedType.getShape(),
169 op.getReassociationIndices(), isExpansion);
177LogicalResult reshapeLikeShapesAreCompatible(
183bool hasNonIdentityLayout(
Type type);
185enum class ReshapeOpKind { kExpand, kCollapse };
189template <
typename ReshapeOpTy, ReshapeOpKind opKind>
190struct ComposeReassociativeReshapeOps :
public OpRewritePattern<ReshapeOpTy> {
191 using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
192 LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
193 PatternRewriter &rewriter)
const override {
195 reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
199 ShapedType resultType = reshapeOp.getResultType();
201 if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
202 hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
203 hasNonIdentityLayout(reshapeOp.getResult().getType()))
206 std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
208 reshapeOp.getReassociationIndices(),
209 rewriter.getContext());
210 if (!reassociationIndices)
213 if constexpr (opKind == ReshapeOpKind::kExpand) {
214 SmallVector<OpFoldResult> outputShape(
216 reshapeOp.getOutputShape(), rewriter));
217 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
218 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
221 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
222 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
256template <
typename CollapseOpTy,
typename ExpandOpTy,
typename CastOpTy,
257 typename DimOpTy,
typename TensorTy>
262 auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
266 ShapedType srcType = expandOp.getSrcType();
267 ShapedType resultType = collapseOp.getResultType();
269 if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
270 hasNonIdentityLayout(expandOp.getSrc().getType()) ||
271 hasNonIdentityLayout(expandOp.getResult().getType()))
274 int64_t srcRank = srcType.getRank();
275 int64_t resultRank = resultType.getRank();
276 if (srcType == resultType)
280 lowerRankReassociation;
282 if (srcRank > resultRank) {
283 higherRankReassociation = expandOp.getReassociationIndices();
284 lowerRankReassociation = collapseOp.getReassociationIndices();
286 higherRankReassociation = collapseOp.getReassociationIndices();
287 lowerRankReassociation = expandOp.getReassociationIndices();
290 size_t higherRankIndicesID = 0;
292 for (
const auto &lowerRankIndices : lowerRankReassociation) {
294 while (higherRankIndicesID < higherRankReassociation.size()) {
295 auto rightmostIndex =
296 higherRankReassociation[higherRankIndicesID].back();
297 if (rightmostIndex > lowerRankIndices.back())
299 composedIndices.push_back(higherRankIndicesID++);
300 if (rightmostIndex == lowerRankIndices.back())
303 composedReassociation.push_back(composedIndices);
305 if (srcRank > resultRank) {
307 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
308 }
else if (srcRank < resultRank) {
312 expandOp.getMixedOutputShape();
315 collapseOp.getReassociationIndices()) {
321 numStaticElems *= maybeCst.value();
324 dynamicSizes.push_back(cast<Value>(size));
326 if (dynamicSizes.empty()) {
327 newOutputShape.push_back(rewriter.
getIndexAttr(numStaticElems));
334 for (
Value v : llvm::drop_begin(dynamicSizes))
335 result = arith::MulIOp::create(rewriter, loc,
result, v);
336 if (numStaticElems != 1) {
337 result = arith::MulIOp::create(
341 newOutputShape.push_back(
result);
344 collapseOp, resultType, expandOp.getSrc(), composedReassociation,
349 assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
350 "expected same shape");
358template <
typename ExpandOpTy,
typename CollapseOpTy,
typename CastOpTy>
363 auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
367 ShapedType srcType = collapseOp.getSrcType();
368 ShapedType resultType = expandOp.getResultType();
370 if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
371 hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
372 hasNonIdentityLayout(collapseOp.getResult().getType())) {
373 if (CastOpTy::areCastCompatible(srcType, resultType)) {
375 collapseOp.getSrc());
381 int64_t srcRank = srcType.getRank();
382 int64_t resultRank = resultType.getRank();
383 if (srcRank == resultRank)
386 auto srcReassociation = collapseOp.getReassociationIndices();
387 auto resultReassociation = expandOp.getReassociationIndices();
388 if (srcRank > resultRank) {
389 auto composedReassociation = findCollapsingReassociation(
390 srcReassociation, resultReassociation, srcType.getShape(),
391 resultType.getShape());
392 if (!composedReassociation)
396 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
399 auto composedReassociation =
400 findCollapsingReassociation(resultReassociation, srcReassociation,
401 resultType.getShape(), srcType.getShape());
402 if (!composedReassociation)
406 expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
408 expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
416 std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
422 if (srcReassociation.empty())
425 for (
auto item : llvm::zip(srcReassociation, resultReassociation)) {
426 auto &srcIndices = std::get<0>(item);
427 auto &resultIndices = std::get<1>(item);
428 auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
429 auto resultSubShape =
430 resultShape.slice(resultIndices.front(), resultIndices.size());
432 if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 &&
433 llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2)
436 if (srcSubShape.size() == resultSubShape.size()) {
437 if (srcSubShape != resultSubShape)
440 for (
auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
441 composedReassociation.emplace_back(1, srcIndices.front() + index);
447 auto subShapeReassociation =
449 if (!subShapeReassociation)
453 for (
auto &subshapeIndices : *subShapeReassociation) {
455 for (int64_t index : subshapeIndices)
456 shapeIndices.push_back(srcIndices.front() + index);
457 composedReassociation.push_back(shapeIndices);
460 return {std::move(composedReassociation)};
507class SliceFromCollapseHelper {
509 SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
510 ArrayRef<OpFoldResult> collapseShapeInputShape,
511 ArrayRef<OpFoldResult> collapseShapeOutputShape,
512 ArrayRef<Range> extractSliceParams)
513 : reassociationIndices(reassociationIndices),
514 collapseShapeInputShape(collapseShapeInputShape),
515 collapseShapeOutputShape(collapseShapeOutputShape),
516 sliceParams(extractSliceParams),
519 extractSliceParams)) {}
531 SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
532 ArrayRef<ValueRange> multiIndices);
539 SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
543 SmallVector<ReassociationIndices> reassociationIndices;
544 SmallVector<OpFoldResult> collapseShapeInputShape;
545 SmallVector<OpFoldResult> collapseShapeOutputShape;
546 SmallVector<Range> sliceParams;
547 llvm::SmallBitVector linearizedDimensions;
548 llvm::SmallBitVector slicedDimensions;
553struct CollapseShapeRankReducingSliceSimplificationInfo {
558 std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
597FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
598getSimplifyCollapseShapeWithRankReducingSliceInfo(
599 RankedTensorType sourceType,
602struct PackingMetadata {
603 SmallVector<int64_t> insertPositions;
604 SmallVector<int64_t> outerPositions;
605 SmallVector<ReassociationIndices> reassociations;
614PackingMetadata computePackingMetadata(int64_t packedRank,
622 std::optional<Attribute> cst = std::nullopt);
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
ArrayRef< int64_t > ReassociationIndicesRef
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
SmallVector< AffineExpr, 2 > ReassociationExprs
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
SmallVector< int64_t, 2 > ReassociationIndices
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Common verifier for reshape-like types.
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})