14#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
22#include "llvm/ADT/StringRef.h"
47 ArrayRef<ReassociationIndices> producerReassociations,
48 ArrayRef<ReassociationIndices> consumerReassociations,
49 MLIRContext *context);
53 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
62 ArrayRef<ReassociationIndices> reassociation);
66 ArrayRef<ReassociationExprs> reassociationExprs);
71std::optional<SmallVector<ReassociationIndices>>
76std::optional<SmallVector<ReassociationIndices>>
78 ArrayRef<int64_t> targetShape);
84 int *invalidIndex =
nullptr);
86template <
typename ReshapeOpTy,
typename InverseReshapeOpTy>
90 if (reshapeOp.getSrcType() == reshapeOp.getType())
91 return reshapeOp.getSrc();
94 if (
auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
95 return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
100 reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
103 auto srcType = reshapeSrcOp.getSrcType();
104 auto resultType = reshapeOp.getResultType();
105 if (srcType != resultType)
108 if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
109 return reshapeSrcOp.getSrc();
118 auto reassociations = reshapeOp.getReassociationIndices();
119 if (reassociations != reshapeSrcOp.getReassociationIndices())
123 if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
124 return reshapeSrcOp.getSrc();
125 if (llvm::all_of(reassociations, [&](
auto reInd) {
127 srcType.getShape().slice(reInd.front(), reInd.size());
128 return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
130 return reshapeSrcOp.getSrc();
137template <
typename Op,
typename T>
138static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
139 T collapsedType,
bool isExpansion) {
141 unsigned expandedRank = expandedType.getRank();
142 unsigned collapsedRank = collapsedType.getRank();
143 if (expandedRank < collapsedRank)
144 return op.emitOpError(
"expected the expanded type, ")
145 << expandedType <<
" to have a higher (or same) rank "
146 <<
"than the collapsed type, " << collapsedType <<
'.';
148 if (collapsedRank != op.getReassociation().size())
149 return op.emitOpError(
"expected collapsed rank (")
150 << collapsedRank <<
") to equal the number of reassociation maps ("
151 << op.getReassociation().size() <<
").";
153 auto maps = op.getReassociationMaps();
154 for (
auto it : llvm::enumerate(maps))
155 if (it.value().getNumDims() != expandedRank)
156 return op.emitOpError(
"expected reassociation map #")
157 << it.index() <<
" to have size equal to the expanded rank ("
158 << expandedRank <<
"), but it is " << it.value().getNumDims()
163 return op.emitOpError(
"expected reassociation map #")
164 << invalidIdx <<
" to be valid and contiguous.";
166 return reshapeLikeShapesAreCompatible(
167 [&](
const Twine &msg) {
return op->emitOpError(msg); },
168 collapsedType.getShape(), expandedType.getShape(),
169 op.getReassociationIndices(), isExpansion);
177LogicalResult reshapeLikeShapesAreCompatible(
183bool hasNonIdentityLayout(
Type type);
185enum class ReshapeOpKind { kExpand, kCollapse };
189template <
typename ReshapeOpTy, ReshapeOpKind opKind>
190struct ComposeReassociativeReshapeOps :
public OpRewritePattern<ReshapeOpTy> {
191 using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
192 LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
193 PatternRewriter &rewriter)
const override {
195 reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
199 ShapedType resultType = reshapeOp.getResultType();
201 if (hasNonIdentityLayout(srcReshapeOp.getSrc().getType()) ||
202 hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
203 hasNonIdentityLayout(reshapeOp.getResult().getType()))
206 std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
208 reshapeOp.getReassociationIndices(),
209 rewriter.getContext());
210 if (!reassociationIndices)
213 if constexpr (opKind == ReshapeOpKind::kExpand) {
214 SmallVector<OpFoldResult> outputShape(
216 reshapeOp.getOutputShape(), rewriter));
217 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
218 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
221 rewriter.replaceOpWithNewOp<ReshapeOpTy>(
222 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
256template <
typename CollapseOpTy,
typename ExpandOpTy,
typename CastOpTy,
257 typename DimOpTy,
typename TensorTy>
262 auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
266 ShapedType srcType = expandOp.getSrcType();
267 ShapedType resultType = collapseOp.getResultType();
269 if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
270 hasNonIdentityLayout(expandOp.getSrc().getType()) ||
271 hasNonIdentityLayout(expandOp.getResult().getType()))
274 int64_t srcRank = srcType.getRank();
275 int64_t resultRank = resultType.getRank();
276 if (srcType == resultType)
280 lowerRankReassociation;
282 if (srcRank > resultRank) {
283 higherRankReassociation = expandOp.getReassociationIndices();
284 lowerRankReassociation = collapseOp.getReassociationIndices();
286 higherRankReassociation = collapseOp.getReassociationIndices();
287 lowerRankReassociation = expandOp.getReassociationIndices();
290 size_t higherRankIndicesID = 0;
292 for (
const auto &lowerRankIndices : lowerRankReassociation) {
294 while (higherRankIndicesID < higherRankReassociation.size()) {
295 auto rightmostIndex =
296 higherRankReassociation[higherRankIndicesID].back();
297 if (rightmostIndex > lowerRankIndices.back())
299 composedIndices.push_back(higherRankIndicesID++);
300 if (rightmostIndex == lowerRankIndices.back())
303 composedReassociation.push_back(composedIndices);
305 if (srcRank > resultRank) {
307 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
308 }
else if (srcRank < resultRank) {
312 expandOp.getMixedOutputShape();
315 collapseOp.getReassociationIndices()) {
321 numStaticElems *= maybeCst.value();
324 dynamicSizes.push_back(cast<Value>(size));
326 if (dynamicSizes.empty()) {
327 newOutputShape.push_back(rewriter.
getIndexAttr(numStaticElems));
334 for (
Value v : llvm::drop_begin(dynamicSizes))
335 result = arith::MulIOp::create(rewriter, loc,
result, v);
336 if (numStaticElems != 1) {
337 result = arith::MulIOp::create(
341 newOutputShape.push_back(
result);
344 collapseOp, resultType, expandOp.getSrc(), composedReassociation,
349 assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
350 "expected same shape");
358template <
typename ExpandOpTy,
typename CollapseOpTy>
363 auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
367 ShapedType srcType = collapseOp.getSrcType();
368 ShapedType resultType = expandOp.getResultType();
370 if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
371 hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
372 hasNonIdentityLayout(collapseOp.getResult().getType()))
375 int64_t srcRank = srcType.getRank();
376 int64_t resultRank = resultType.getRank();
377 if (srcRank == resultRank)
380 auto srcReassociation = collapseOp.getReassociationIndices();
381 auto resultReassociation = expandOp.getReassociationIndices();
382 if (srcRank > resultRank) {
383 auto composedReassociation = findCollapsingReassociation(
384 srcReassociation, resultReassociation, srcType.getShape(),
385 resultType.getShape());
386 if (!composedReassociation)
390 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
393 auto composedReassociation =
394 findCollapsingReassociation(resultReassociation, srcReassociation,
395 resultType.getShape(), srcType.getShape());
396 if (!composedReassociation)
400 expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
402 expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
410 std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
416 if (srcReassociation.empty())
419 for (
auto item : llvm::zip(srcReassociation, resultReassociation)) {
420 auto &srcIndices = std::get<0>(item);
421 auto &resultIndices = std::get<1>(item);
422 auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
423 auto resultSubShape =
424 resultShape.slice(resultIndices.front(), resultIndices.size());
426 if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 &&
427 llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2)
430 if (srcSubShape.size() == resultSubShape.size()) {
431 if (srcSubShape != resultSubShape)
434 for (
auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
435 composedReassociation.emplace_back(1, srcIndices.front() + index);
441 auto subShapeReassociation =
443 if (!subShapeReassociation)
447 for (
auto &subshapeIndices : *subShapeReassociation) {
449 for (int64_t index : subshapeIndices)
450 shapeIndices.push_back(srcIndices.front() + index);
451 composedReassociation.push_back(shapeIndices);
454 return {std::move(composedReassociation)};
501class SliceFromCollapseHelper {
503 SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
504 ArrayRef<OpFoldResult> collapseShapeInputShape,
505 ArrayRef<OpFoldResult> collapseShapeOutputShape,
506 ArrayRef<Range> extractSliceParams)
507 : reassociationIndices(reassociationIndices),
508 collapseShapeInputShape(collapseShapeInputShape),
509 collapseShapeOutputShape(collapseShapeOutputShape),
510 sliceParams(extractSliceParams),
513 extractSliceParams)) {}
525 SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
526 ArrayRef<ValueRange> multiIndices);
533 SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
537 SmallVector<ReassociationIndices> reassociationIndices;
538 SmallVector<OpFoldResult> collapseShapeInputShape;
539 SmallVector<OpFoldResult> collapseShapeOutputShape;
540 SmallVector<Range> sliceParams;
541 llvm::SmallBitVector linearizedDimensions;
542 llvm::SmallBitVector slicedDimensions;
547struct CollapseShapeRankReducingSliceSimplificationInfo {
552 std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
591FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
592getSimplifyCollapseShapeWithRankReducingSliceInfo(
593 RankedTensorType sourceType,
596struct PackingMetadata {
597 SmallVector<int64_t> insertPositions;
598 SmallVector<int64_t> outerPositions;
599 SmallVector<ReassociationIndices> reassociations;
608PackingMetadata computePackingMetadata(int64_t packedRank,
616 std::optional<Attribute> cst = std::nullopt);
static RankedTensorType sliceResultType(Type operandType, GridOp grid, ArrayRef< GridAxis > gridAxes, int64_t sliceAxis)
IntegerAttr getIndexAttr(int64_t value)
An attribute that represents a reference to a dense vector or tensor object.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
ArrayRef< int64_t > ReassociationIndicesRef
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape the given source type into the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
SmallVector< AffineExpr, 2 > ReassociationExprs
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in a pair of reshape ops where one is a producer and the other is ...
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
SmallVector< int64_t, 2 > ReassociationIndices
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Common verifier for reshape-like types.
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})