/// Counts the entries of `shape` that are either dynamic (as reported by
/// ShapedType::isDynamic) or strictly greater than 1.
22static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
23 return llvm::count_if(
24 shape, [](int64_t v) {
return ShapedType::isDynamic(v) || v > 1; });
/// Reports a match failure on `op` (through `rewriter`) when more than one
/// entry of `srcShape`, or more than one entry of `innerPackTileSize`, is
/// counted by getNumGtOneDims.
/// NOTE(review): the fall-through return is not visible in this listing;
/// presumably it signals success - confirm against the full file.
32static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
33 ArrayRef<int64_t> srcShape,
34 ArrayRef<int64_t> innerPackTileSize) {
// First restriction: at most one counted dimension in the source shape.
35 if (getNumGtOneDims(srcShape) > 1) {
36 return rewriter.notifyMatchFailure(
37 op,
"expects non-packed domain to have at most one non-unit dims");
// Second restriction: at most one counted entry among the inner tile sizes.
41 if (getNumGtOneDims(innerPackTileSize) > 1) {
42 return rewriter.notifyMatchFailure(
43 op,
"expects at most one non-unit inner tiles");
/// Computes a permutation vector for `linalgOp`. A linalg::TransposeOp yields
/// its own permutation directly; any other op goes through a series of guards
/// on loop kinds, operand counts, indexing maps and body size before a
/// permutation is derived from the first and last indexing maps.
/// NOTE(review): the statements executed when a guard triggers are not
/// visible in this listing; presumably each returns failure - confirm.
50static FailureOr<SmallVector<int64_t>>
51getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
// Fast path: a named transpose op carries its permutation.
52 if (
auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
53 return SmallVector<int64_t>(transposeOp.getPermutation());
// Guard: triggers when not every loop of the op is a parallel loop.
54 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
// Guard: triggers unless there is exactly one DPS input and one DPS init.
57 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
59 auto mapRange = linalgOp.getIndexingMapsArray();
// Guard: triggers when the first or last indexing map is not a permutation,
// or when the two maps are equal.
60 if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
61 mapRange.front() == mapRange.back()) {
// Guard: triggers unless the op's block holds exactly one operation.
64 if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
66 AffineMap outMap = mapRange.back();
67 AffineMap inMap = mapRange.front();
// For every result expression of the last map, record the position of that
// expression among the results of the first map.
70 return llvm::map_to_vector(outMap.getResults(),
71 [&](AffineExpr expr) -> int64_t {
72 return *inMap.getResultPosition(expr);
/// Rewrite pattern on PackOp that replaces the op with the value returned by
/// insertExpand for the pack source (see the final replaceOp call).
/// NOTE(review): several lines of this struct are absent from this listing;
/// the comments below describe only the statements that are visible.
77struct SimplifyPackToExpandShape :
public OpRewritePattern<PackOp> {
78 using OpRewritePattern<PackOp>::OpRewritePattern;
// Helper: compares the operand type with `newOperandType` first; the visible
// return creates a tensor::ExpandShapeOp producing `newOperandType`.
81 insertExpand(RewriterBase &rewriter, Location loc, Value operand,
83 ArrayRef<ReassociationIndices> reassociation)
const {
84 if (operand.getType() == newOperandType)
86 return tensor::ExpandShapeOp::create(rewriter, loc, newOperandType, operand,
// Reports a match failure unless the pack has exactly one inner dimension
// position and that position is the last source dimension. An earlier
// failure message concerns outer_dims_perm (its condition is not visible).
92 LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
93 PackOp packOp)
const {
94 auto outerDimsPerm = packOp.getOuterDimsPerm();
96 return rewriter.notifyMatchFailure(
98 "expects outer_dims_perm is empty or an identity permutation");
101 int64_t srcRank = packOp.getSourceRank();
102 ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
// `dimsPos[0] + 1 != srcRank` means the packed dimension is not the last one.
103 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
104 return rewriter.notifyMatchFailure(
105 packOp,
"expects packing at the innermost dimension");
110 LogicalResult matchAndRewrite(PackOp packOp,
111 PatternRewriter &rewriter)
const override {
// Packs that carry a padding value are rejected.
112 if (packOp.getPaddingValue())
113 return rewriter.notifyMatchFailure(packOp,
"expects no padding value");
115 if (!packOp.hasPureTensorSemantics())
118 ShapedType sourceType = packOp.getSourceType();
// This condition holds only when all three alternatives fail: the innermost
// dimension check, the 1-D check, and isLikePad().
119 if (
failed(isPackOnInnerMostDim(rewriter, packOp)) &&
120 failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
121 packOp.getStaticTiles())) &&
122 !packOp.isLikePad()) {
126 ShapedType destType = packOp.getDestType();
// Expand the pack source towards the destination type.
131 FailureOr<Value> expanded =
132 insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
135 return rewriter.notifyMatchFailure(
136 packOp,
"unable to expand source of tensor.pack");
138 rewriter.replaceOp(packOp, *expanded);
/// Rewrite pattern on UnPackOp that replaces the op with the value returned by
/// insertCollapse for the unpack source (see the final replaceOp call).
/// NOTE(review): several lines of this struct are absent from this listing;
/// the comments below describe only the statements that are visible.
143struct SimplifyUnPackToCollapseShape :
public OpRewritePattern<UnPackOp> {
144 using OpRewritePattern<UnPackOp>::OpRewritePattern;
// Helper: compares the operand type with `newOperandType` first; the visible
// return creates a tensor::CollapseShapeOp using `reassociation`.
146 Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
147 Type newOperandType,
ArrayAttr reassociation)
const {
148 if (operand.getType() == newOperandType)
150 return tensor::CollapseShapeOp::create(rewriter, loc, newOperandType,
151 operand, reassociation);
// Reports a match failure unless source and destination shapes are static
// and the single inner dimension position is the last destination dimension.
// An earlier failure message concerns outer_dims_perm (condition not visible).
155 LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
156 UnPackOp unpackOp)
const {
157 auto outerDimsPerm = unpackOp.getOuterDimsPerm();
159 return rewriter.notifyMatchFailure(
161 "expects outer_dims_perm is empty or an identity permutation");
164 ShapedType sourceType = unpackOp.getSourceType();
165 ShapedType destType = unpackOp.getDestType();
166 if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
167 return rewriter.notifyMatchFailure(unpackOp,
"expects static shapes");
169 ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
// `dimsPos[0] + 1 != rank` means the unpacked dimension is not the last one.
170 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
171 return rewriter.notifyMatchFailure(
172 unpackOp,
"expects unpacking on the innermost dimension");
178 LogicalResult matchAndRewrite(UnPackOp unpackOp,
179 PatternRewriter &rewriter)
const override {
181 if (!unpackOp.hasPureTensorSemantics())
184 ShapedType destType = unpackOp.getDestType();
// This condition holds only when all three alternatives fail: the innermost
// dimension check, the 1-D check (applied to the destination shape), and
// isLikeUnPad().
185 if (
failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
186 failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
187 unpackOp.getStaticTiles())) &&
188 !unpackOp.isLikeUnPad()) {
192 ShapedType sourceType = unpackOp.getSourceType();
// Collapse the unpack source towards the destination type.
197 Value collapsed = insertCollapse(
198 rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
200 rewriter.replaceOp(unpackOp, collapsed);
/// Rewrite pattern on PackOp whose source is produced by a tensor::PadOp. When
/// it applies, the pack is recreated reading the pad's source directly and
/// using the pad's constant padding value.
/// NOTE(review): several lines of this struct are absent from this listing;
/// the comments below describe only the statements that are visible.
207struct FoldPadWithPackOp :
public OpRewritePattern<PackOp> {
// Stores the optional control callback alongside the base pattern.
210 : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
212 LogicalResult matchAndRewrite(PackOp packOp,
213 PatternRewriter &rewriter)
const override {
214 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
// Guard: no pad producer, a pad marked nofold, or a pad with non-zero low
// padding.
216 if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
// Guard: the control callback, when set, is asked about the pack's source
// operand.
220 if (controlFn && !controlFn(&packOp.getSourceMutable()))
// Guard: the pad must expose a constant padding value.
223 Value constantPaddingValue = padOp.getConstantPaddingValue();
224 if (!constantPaddingValue)
// Extra handling when the pack has its own padding value (body not visible).
227 if (
auto paddingValue = packOp.getPaddingValue())
236 ShapedType unpackedType = packOp.getSourceType();
237 SmallVector<int64_t> outerShapeWithoutTranspose =
// Walk inner dimension positions, static inner tile sizes and the pad's high
// padding amounts in lockstep.
239 for (
auto [pos, tileSize, high] :
240 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
241 padOp.getMixedHighPad())) {
// Guards on dynamic source dimension, dynamic outer size and dynamic tile.
242 if (unpackedType.isDynamicDim(pos))
244 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
246 if (ShapedType::isDynamic(tileSize))
// Difference between the tiled extent (outer size * tile size) and the
// source extent of this dimension.
251 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
252 unpackedType.getDimSize(pos);
// Guard: that difference plus the pad's constant high padding reaches or
// exceeds one full tile.
254 if (paddingSize + cstHigh.value() >= tileSize)
// Recreate the pack on the pad's source with the pad's padding value.
258 rewriter.replaceOpWithNewOp<PackOp>(
259 packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
260 packOp.getMixedTiles(), constantPaddingValue,
261 packOp.getOuterDimsPerm());
/// Rewrite pattern on tensor::ExtractSliceOp whose source is produced by an
/// UnPackOp. When it applies, the slice is replaced by a new UnPackOp that
/// writes into a tensor::EmptyOp sized by the slice's mixed sizes.
/// NOTE(review): several lines of this struct are absent from this listing;
/// the comments below describe only the statements that are visible.
271struct FoldUnpackWithExtractSliceOp
272 :
public OpRewritePattern<tensor::ExtractSliceOp> {
// Stores the optional control callback alongside the base pattern.
274 FoldUnpackWithExtractSliceOp(MLIRContext *context,
276 : OpRewritePattern<tensor::ExtractSliceOp>(context),
277 controlFn(std::move(controlFn)) {}
279 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
280 PatternRewriter &rewriter)
const override {
281 auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
286 if (!unpackOp.hasPureTensorSemantics())
// Guard: the control callback, when set, is asked about the slice's source
// operand.
290 if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
// Guard: the unpack op itself decides whether this slice can be folded.
293 if (!unpackOp.canFoldSliceOp(sliceOp))
// New destination: an empty tensor with the slice's sizes and the unpack
// destination's element type.
297 Type elementType = unpackOp.getDestType().getElementType();
298 Value output = tensor::EmptyOp::create(
299 rewriter, sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
300 rewriter.replaceOpWithNewOp<UnPackOp>(
301 sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
302 unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
/// Walks the first `rank` entries of `permutation` and appends one value per
/// entry to `resVec`, consulting `inVec` for the appended value.
/// NOTE(review): the return statements and any condition guarding the `inVec`
/// lookup are not visible in this listing - confirm against the full file.
317static bool checkAndPermute(ArrayRef<int64_t> permutation,
318 ArrayRef<int64_t> inVec,
319 SmallVectorImpl<int64_t> &resVec, int64_t rank) {
321 for (
unsigned int i = 0; i < rank; ++i) {
322 int64_t remappedPosition = permutation[i];
// Guard: the permutation entry refers to a position at or beyond `rank`.
323 if (remappedPosition >= rank)
// Map the position through `inVec`, then record it.
326 remappedPosition = inVec[remappedPosition];
327 resVec.push_back(remappedPosition);
/// Rewrite pattern on linalg ops whose operand 0 is produced by a PackOp. When
/// it applies, the linalg op is replaced by a new PackOp on the original pack
/// source with permuted outer dims, inner dim positions and inner tiles.
/// NOTE(review): several lines of this struct are absent from this listing;
/// the comments below describe only the statements that are visible.
335struct FoldProducerPackWithConsumerLinalgTransposeOp
336 :
public OpInterfaceRewritePattern<linalg::LinalgOp> {
// Stores the optional control callback alongside the base pattern.
339 FoldProducerPackWithConsumerLinalgTransposeOp(
341 : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
342 controlFn(std::move(controlFn)) {}
344 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
345 PatternRewriter &rewriter)
const override {
346 auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
352 if (!packOp.hasPureTensorSemantics())
// Guard: the control callback, when set, is asked about operand 0.
356 if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
// Permutation the linalg op applies, if it can be treated as a transpose.
359 FailureOr<SmallVector<int64_t>> maybePerm =
360 getTransposeOpPermutation(linalgOp);
364 auto innerDimsPos = packOp.getInnerDimsPos();
365 auto mixedInnerTiles = packOp.getMixedTiles();
366 auto outerDimsPerm = packOp.getOuterDimsPerm();
367 const auto &transposePerm = maybePerm.value();
368 SmallVector<int64_t> newOuterDimsPermVec;
369 SmallVector<int64_t> newInnerDimsPosVec;
370 SmallVector<OpFoldResult> newMixedInnerTilesVec;
371 int64_t srcRank = packOp.getSourceRank();
// Build the new outer permutation; failure is reported with the message
// below.
373 if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
375 return rewriter.notifyMatchFailure(
377 "Cannot fold in tensor.pack if a tile dimension was transposed "
378 "with a non-tile dimension in linalg.transpose.");
// Entries of the transpose permutation from index `srcRank` onwards select,
// after subtracting `srcRank`, which inner tile and inner dim position to use.
381 for (
unsigned int i = srcRank; i < transposePerm.size(); ++i) {
382 int64_t remappedPosition = transposePerm[i] - srcRank;
383 newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
384 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
// Destination tensor for the new pack, derived from the original pack source.
387 Value output = packOp.createDestinationTensor(
388 rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
389 newInnerDimsPosVec, newOuterDimsPermVec);
391 rewriter.replaceOpWithNewOp<PackOp>(
392 linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
393 newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
/// Rewrite pattern on PackOp whose source is produced by a linalg op. When it
/// applies, the pack is replaced by a new PackOp that reads the linalg op's
/// operand 0 with remapped inner dim positions and outer permutation.
/// NOTE(review): several lines of this struct are absent from this listing;
/// the comments below describe only the statements that are visible.
404struct FoldConsumerPackWithProducerLinalgTransposeOp
405 :
public OpRewritePattern<PackOp> {
// Stores the optional control callback alongside the base pattern.
408 FoldConsumerPackWithProducerLinalgTransposeOp(
410 : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
412 LogicalResult matchAndRewrite(PackOp packOp,
413 PatternRewriter &rewriter)
const override {
415 if (!packOp.hasPureTensorSemantics())
418 auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
// Guard: the control callback, when set, is asked about the pack's source
// operand.
423 if (controlFn && !controlFn(&packOp.getSourceMutable()))
// Permutation the producer applies, if it can be treated as a transpose.
426 FailureOr<SmallVector<int64_t>> maybePerm =
427 getTransposeOpPermutation(linalgOp);
431 auto transposePermutation = maybePerm.value();
432 auto outerDimsPerm = packOp.getOuterDimsPerm();
433 auto innerDimsPos = packOp.getInnerDimsPos();
434 SmallVector<int64_t> newInnerDimsPosVec;
// The new outer permutation starts as a copy of the transpose permutation.
435 SmallVector<int64_t> newOuterDimsPermVec =
436 llvm::to_vector(transposePermutation);
// Additional handling when the pack already has an outer permutation (body
// not visible in this listing).
438 if (!outerDimsPerm.empty())
// Each inner dim position is looked up in the transpose permutation.
443 for (
auto dim : innerDimsPos)
444 newInnerDimsPosVec.push_back(transposePermutation[dim]);
// Destination tensor for the new pack, derived from the producer's operand 0.
446 Value output = packOp.createDestinationTensor(
447 rewriter, packOp.getLoc(), linalgOp->getOperand(0),
448 packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
450 rewriter.replaceOpWithNewOp<PackOp>(
451 packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
452 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
/// Rewrite pattern on linalg ops whose operand 0 is produced by an UnPackOp.
/// When it applies, the linalg op is replaced by a new UnPackOp that reads the
/// original unpack source and writes into the linalg op's first DPS init.
/// NOTE(review): several lines of this struct are absent from this listing
/// (including the initialiser of newOuterDimsPermVec); the comments below
/// describe only the statements that are visible.
463struct FoldProducerUnPackWithConsumerLinalgTransposeOp
464 :
public OpInterfaceRewritePattern<linalg::LinalgOp> {
// Stores the optional control callback alongside the base pattern.
467 FoldProducerUnPackWithConsumerLinalgTransposeOp(
469 : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
470 controlFn(std::move(controlFn)) {}
472 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
473 PatternRewriter &rewriter)
const override {
474 auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
480 if (!unPackOp.hasPureTensorSemantics())
// Guard: the control callback, when set, is asked about operand 0.
484 if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
// Permutation the linalg op applies, if it can be treated as a transpose.
487 FailureOr<SmallVector<int64_t>> maybePerm =
488 getTransposeOpPermutation(linalgOp);
492 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
493 auto innerDimsPos = unPackOp.getInnerDimsPos();
494 SmallVector<int64_t> newInnerDimsPosVec;
495 SmallVector<int64_t> newOuterDimsPermVec =
// Each inner dim position is looked up in newOuterDimsPermVec.
500 for (
auto dim : innerDimsPos)
501 newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
// Additional handling when the unpack already has an outer permutation (body
// not visible in this listing).
503 if (!outerDimsPerm.empty())
507 rewriter.replaceOpWithNewOp<UnPackOp>(
508 linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
509 newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
/// Rewrite pattern on UnPackOp whose source is produced by a linalg op. When
/// it applies, the unpack is replaced by a new UnPackOp that reads the linalg
/// op's operand 0 and writes into a fresh tensor::EmptyOp.
/// NOTE(review): several lines of this struct are absent from this listing;
/// the comments below describe only the statements that are visible.
520struct FoldConsumerUnPackWithProducerLinalgTransposeOp
521 :
public OpRewritePattern<UnPackOp> {
// NOTE(review): the base constructors are inherited here in addition to the
// custom constructor below, so the pattern can also be constructed without a
// control callback - confirm this is intended.
522 using OpRewritePattern<UnPackOp>::OpRewritePattern;
// Stores the optional control callback alongside the base pattern.
525 FoldConsumerUnPackWithProducerLinalgTransposeOp(
527 : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
529 LogicalResult matchAndRewrite(UnPackOp unPackOp,
530 PatternRewriter &rewriter)
const override {
532 if (!unPackOp.hasPureTensorSemantics())
535 auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
// Guard: the control callback, when set, is asked about the unpack's source
// operand.
540 if (controlFn && !controlFn(&unPackOp.getSourceMutable()))
// Permutation the producer applies, if it can be treated as a transpose.
543 FailureOr<SmallVector<int64_t>> maybePerm =
544 getTransposeOpPermutation(linalgOp);
// Holds the result dimensions used to size the new destination below (the
// code that fills it is not visible in this listing).
548 SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
553 SmallVector<int64_t> inverseTransposePerm =
555 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
556 auto innerDimsPos = unPackOp.getInnerDimsPos();
// Rank of the unpacked domain: source rank minus the number of tiled dims.
557 int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
558 auto mixedInnerTilesVec = unPackOp.getMixedTiles();
559 SmallVector<int64_t> newOuterDimsPermVec;
560 SmallVector<int64_t> newInnerDimsPosVec;
561 SmallVector<OpFoldResult> newMixedInnerTilesVec;
// Build the new outer permutation; failure is reported with the message
// below.
562 if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
563 newOuterDimsPermVec, destRank))
564 return rewriter.notifyMatchFailure(
566 "Cannot fold in tensor.unpack if a tile dimension was transposed "
567 "with a non-tile dimension in linalg.transpose.");
// Entries of the inverse permutation from index `destRank` onwards select,
// after subtracting `destRank`, which inner tile and inner dim position to use.
570 for (
unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
571 int64_t remappedPosition = inverseTransposePerm[i] - destRank;
572 newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
573 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
577 cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
// New destination: an empty tensor sized by the first entry of
// unpackOpResultDims.
578 Value output = tensor::EmptyOp::create(rewriter, unPackOp->getLoc(),
579 unpackOpResultDims[0], elemType);
581 rewriter.replaceOpWithNewOp<UnPackOp>(
582 unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
583 newMixedInnerTilesVec, newOuterDimsPermVec);
/// Rewrite pattern on PackOp that looks for a tensor::EmptyOp producing the
/// pack source and, when the pack has no padding value, replaces the pack
/// with its destination operand.
/// NOTE(review): the check on `emptyOp` itself is not visible in this
/// listing; presumably the pattern bails out when it is null - confirm.
594struct FoldEmptyTensorWithPackOp :
public OpRewritePattern<PackOp> {
595 using OpRewritePattern<PackOp>::OpRewritePattern;
597 LogicalResult matchAndRewrite(PackOp packOp,
598 PatternRewriter &rewriter)
const override {
600 if (!packOp.hasPureTensorSemantics())
604 auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
// Packs that carry a padding value are rejected.
610 if (packOp.getPaddingValue())
611 return rewriter.notifyMatchFailure(packOp,
"expects no padding value");
// The pack result is replaced by the pack's destination.
614 rewriter.replaceOp(packOp, packOp.getDest());
/// Rewrite pattern on UnPackOp that looks for a tensor::EmptyOp producing the
/// unpack source and replaces the unpack with its destination operand.
/// NOTE(review): the check on `emptyOp` itself is not visible in this
/// listing; presumably the pattern bails out when it is null - confirm.
622struct FoldEmptyTensorWithUnPackOp :
public OpRewritePattern<UnPackOp> {
623 using OpRewritePattern<UnPackOp>::OpRewritePattern;
625 LogicalResult matchAndRewrite(UnPackOp unPackOp,
626 PatternRewriter &rewriter)
const override {
628 if (!unPackOp.hasPureTensorSemantics())
632 auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
// The unpack result is replaced by the unpack's destination.
637 rewriter.replaceOp(unPackOp, unPackOp.getDest());
647 patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
648 FoldProducerPackWithConsumerLinalgTransposeOp,
649 FoldConsumerPackWithProducerLinalgTransposeOp,
650 FoldConsumerUnPackWithProducerLinalgTransposeOp,
651 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
656 patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
662 patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that simplify tensor.pack and tensor.unpack operations.
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg.unpack into tensor.empty.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into tensor.pack and tensor.unpack operations.
std::function< bool(OpOperand *opOperand)> ControlFoldIntoPackUnpackFn
Function type which is used to control folding operations like tensor.pad and tensor.extract_slice into pack and unpack operations.
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociation maps to use to reshape given the source type and the target type, when possible.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
const FrozenRewritePatternSet & patterns
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.