14 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
15 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
22 #include "llvm/ADT/StringRef.h"
47 ArrayRef<ReassociationIndices> producerReassociations,
48 ArrayRef<ReassociationIndices> consumerReassociations,
49 MLIRContext *context);
53 MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
56 SmallVector<AffineMap, 4>
62 ArrayRef<ReassociationIndices> reassociation);
66 ArrayRef<ReassociationExprs> reassociationExprs);
71 std::optional<SmallVector<ReassociationIndices>>
76 std::optional<SmallVector<ReassociationIndices>>
78 ArrayRef<int64_t> targetShape);
84 int *invalidIndex =
nullptr);
86 template <
typename ReshapeOpTy,
typename InverseReshapeOpTy>
90 if (reshapeOp.getSrcType() == reshapeOp.getType())
91 return reshapeOp.getSrc();
94 if (
auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
95 return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
100 reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
103 auto srcType = reshapeSrcOp.getSrcType();
104 auto resultType = reshapeOp.getResultType();
105 if (srcType != resultType)
108 if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
109 return reshapeSrcOp.getSrc();
118 auto reassociations = reshapeOp.getReassociationIndices();
119 if (reassociations != reshapeSrcOp.getReassociationIndices())
123 if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
124 return reshapeSrcOp.getSrc();
125 if (llvm::all_of(reassociations, [&](
auto reInd) {
127 srcType.getShape().slice(reInd.front(), reInd.size());
128 return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
130 return reshapeSrcOp.getSrc();
137 template <
typename Op,
typename T>
139 T collapsedType,
bool isExpansion) {
141 unsigned expandedRank = expandedType.getRank();
142 unsigned collapsedRank = collapsedType.getRank();
143 if (expandedRank < collapsedRank)
144 return op.
emitOpError(
"expected the expanded type, ")
145 << expandedType <<
" to have a higher (or same) rank "
146 <<
"than the collapsed type, " << collapsedType <<
'.';
148 if (collapsedRank != op.getReassociation().size())
150 << collapsedRank <<
") to equal the number of reassociation maps ("
151 << op.getReassociation().size() <<
").";
153 auto maps = op.getReassociationMaps();
155 if (it.value().getNumDims() != expandedRank)
156 return op.
emitOpError(
"expected reassociation map #")
157 << it.index() <<
" to have size equal to the expanded rank ("
158 << expandedRank <<
"), but it is " << it.value().getNumDims()
163 return op.
emitOpError(
"expected reassociation map #")
164 << invalidIdx <<
" to be valid and contiguous.";
167 [&](
const Twine &msg) {
return op->
emitOpError(msg); },
168 collapsedType.getShape(), expandedType.getShape(),
169 op.getReassociationIndices(), isExpansion);
179 ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
180 ArrayRef<ReassociationIndices> reassociationMaps,
bool isExpandingReshape);
189 template <
typename ReshapeOpTy, ReshapeOpKind opKind>
195 reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
199 ShapedType resultType = reshapeOp.getResultType();
206 std::optional<SmallVector<ReassociationIndices>> reassociationIndices =
208 reshapeOp.getReassociationIndices(),
210 if (!reassociationIndices)
216 reshapeOp.getOutputShape(), rewriter));
218 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices,
222 reshapeOp, resultType, srcReshapeOp.getSrc(), *reassociationIndices);
256 template <
typename CollapseOpTy,
typename ExpandOpTy,
typename CastOpTy,
257 typename DimOpTy,
typename TensorTy>
262 auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
266 ShapedType srcType = expandOp.getSrcType();
267 ShapedType resultType = collapseOp.getResultType();
274 int64_t srcRank = srcType.getRank();
275 int64_t resultRank = resultType.getRank();
276 if (srcType == resultType)
280 lowerRankReassociation;
282 if (srcRank > resultRank) {
283 higherRankReassociation = expandOp.getReassociationIndices();
284 lowerRankReassociation = collapseOp.getReassociationIndices();
286 higherRankReassociation = collapseOp.getReassociationIndices();
287 lowerRankReassociation = expandOp.getReassociationIndices();
290 size_t higherRankIndicesID = 0;
292 for (
const auto &lowerRankIndices : lowerRankReassociation) {
294 while (higherRankIndicesID < higherRankReassociation.size()) {
295 auto rightmostIndex =
296 higherRankReassociation[higherRankIndicesID].back();
297 if (rightmostIndex > lowerRankIndices.back())
299 composedIndices.push_back(higherRankIndicesID++);
300 if (rightmostIndex == lowerRankIndices.back())
303 composedReassociation.push_back(composedIndices);
305 if (srcRank > resultRank) {
307 collapseOp, resultType, expandOp.getSrc(), composedReassociation);
308 }
else if (srcRank < resultRank) {
312 expandOp.getMixedOutputShape();
315 collapseOp.getReassociationIndices()) {
316 int64_t numStaticElems = 1;
318 for (int64_t idx : indices) {
321 numStaticElems *= maybeCst.value();
324 dynamicSizes.push_back(cast<Value>(size));
326 if (dynamicSizes.empty()) {
327 newOutputShape.push_back(rewriter.
getIndexAttr(numStaticElems));
333 Value result = dynamicSizes[0];
334 for (
Value v : llvm::drop_begin(dynamicSizes))
335 result = rewriter.
create<arith::MulIOp>(loc, result, v);
336 if (numStaticElems != 1) {
337 result = rewriter.
create<arith::MulIOp>(
341 newOutputShape.push_back(result);
344 collapseOp, resultType, expandOp.getSrc(), composedReassociation,
349 assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
350 "expected same shape");
358 template <
typename ExpandOpTy,
typename CollapseOpTy>
363 auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
367 ShapedType srcType = collapseOp.getSrcType();
368 ShapedType resultType = expandOp.getResultType();
375 int64_t srcRank = srcType.getRank();
376 int64_t resultRank = resultType.getRank();
377 if (srcRank == resultRank)
380 auto srcReassociation = collapseOp.getReassociationIndices();
381 auto resultReassociation = expandOp.getReassociationIndices();
382 if (srcRank > resultRank) {
383 auto composedReassociation = findCollapsingReassociation(
384 srcReassociation, resultReassociation, srcType.getShape(),
385 resultType.getShape());
386 if (!composedReassociation)
390 expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
393 auto composedReassociation =
394 findCollapsingReassociation(resultReassociation, srcReassociation,
395 resultType.getShape(), srcType.getShape());
396 if (!composedReassociation)
400 expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
402 expandOp, resultType, collapseOp.getSrc(), *composedReassociation,
410 std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
416 if (srcReassociation.empty())
419 for (
auto item : llvm::zip(srcReassociation, resultReassociation)) {
420 auto &srcIndices = std::get<0>(item);
421 auto &resultIndices = std::get<1>(item);
422 auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
423 auto resultSubShape =
424 resultShape.slice(resultIndices.front(), resultIndices.size());
426 if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 &&
427 llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2)
430 if (srcSubShape.size() == resultSubShape.size()) {
431 if (srcSubShape != resultSubShape)
434 for (
auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
435 composedReassociation.emplace_back(1, srcIndices.front() + index);
441 auto subShapeReassociation =
443 if (!subShapeReassociation)
447 for (
auto &subshapeIndices : *subShapeReassociation) {
449 for (int64_t index : subshapeIndices)
450 shapeIndices.push_back(srcIndices.front() + index);
451 composedReassociation.push_back(shapeIndices);
454 return {std::move(composedReassociation)};
464 ArrayRef<Range> sliceParams);
501 class SliceFromCollapseHelper {
503 SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
504 ArrayRef<OpFoldResult> collapseShapeInputShape,
505 ArrayRef<OpFoldResult> collapseShapeOutputShape,
506 ArrayRef<Range> extractSliceParams)
507 : reassociationIndices(reassociationIndices),
508 collapseShapeInputShape(collapseShapeInputShape),
509 collapseShapeOutputShape(collapseShapeOutputShape),
510 sliceParams(extractSliceParams),
513 extractSliceParams)) {}
525 SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
526 ArrayRef<ValueRange> multiIndices);
533 SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
534 ValueRange tileIndices);
537 SmallVector<ReassociationIndices> reassociationIndices;
538 SmallVector<OpFoldResult> collapseShapeInputShape;
539 SmallVector<OpFoldResult> collapseShapeOutputShape;
540 SmallVector<Range> sliceParams;
541 llvm::SmallBitVector linearizedDimensions;
542 llvm::SmallBitVector slicedDimensions;
547 struct CollapseShapeRankReducingSliceSimplificationInfo {
552 std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
591 FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
592 getSimplifyCollapseShapeWithRankReducingSliceInfo(
593 RankedTensorType sourceType,
594 ArrayRef<ReassociationIndices> reassociationIndices);
596 struct PackingMetadata {
597 SmallVector<int64_t> insertPositions;
598 SmallVector<int64_t> outerPositions;
599 SmallVector<ReassociationIndices> reassociations;
608 PackingMetadata computePackingMetadata(int64_t packedRank,
609 ArrayRef<int64_t> innerDimPos);
615 OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
616 std::optional<Attribute> cst = std::nullopt);
static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, ArrayRef< MeshAxis > meshAxes, int64_t sliceAxis)
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This provides public APIs that all operations should have.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Specialization of arith.constant op that returns an integer of index type.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
llvm::SmallBitVector getSlicedDimensions(ArrayRef< OpFoldResult > sliceInputShape, ArrayRef< Range > sliceParams)
The input parameters offsets, sizes, strides specify a rectangular non rank-reducing slice of the col...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
bool hasNonIdentityLayout(Type type)
Returns true iff the type is a MemRefType and has a non-identity layout.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector< ReassociationIndices, 2 > convertReassociationMapsToIndices(ArrayRef< ReassociationExprs > reassociationExprs)
Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult reshapeLikeShapesAreCompatible(function_ref< LogicalResult(const Twine &)> emitError, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociationMaps, bool isExpandingReshape)
Verify that shapes of the reshaped types using following rule: if a dimension in the collapsed type i...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForCollapse(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > targetShape)
Returns the reassociation maps to collapse sourceShape to targetShape if possible.
ArrayRef< int64_t > ReassociationIndicesRef
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
std::optional< SmallVector< ReassociationIndices > > composeReassociationIndices(ArrayRef< ReassociationIndices > producerReassociations, ArrayRef< ReassociationIndices > consumerReassociations, MLIRContext *context)
Compose reassociation maps that are used in pair of reshape ops where one is a producer and other is ...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
llvm::SmallBitVector getLinearizedDimensions(ArrayRef< ReassociationIndices > reassociationIndices)
Determine which dimensions are linearized by a tensor.collapse_shape op by inspecting its reassociati...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< int64_t, 2 > ReassociationIndices
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
LogicalResult matchAndRewrite(CollapseOpTy collapseOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandOpTy expandOp, PatternRewriter &rewriter) const override
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, PatternRewriter &rewriter) const override
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...