// Returns the number of dimensions in `shape` that are not statically known
// to be unit: a dimension counts if it is dynamic or its static extent is > 1.
// NOTE(review): the listing elides the function's closing brace.
21 static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
22 return llvm::count_if(
23 shape, [](int64_t v) {
return ShapedType::isDynamic(v) || v > 1; });
// Succeeds only when the pack/unpack effectively operates on a 1-D domain:
// both the unpacked source shape and the inner tile sizes must have at most
// one non-unit (or dynamic) dimension, per getNumGtOneDims above.
// NOTE(review): the closing braces and the success return between the guards
// are elided in this listing — presumably `return success();` at the end.
31 static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
32 ArrayRef<int64_t> srcShape,
33 ArrayRef<int64_t> innerPackTileSize) {
// More than one non-unit dim in the non-packed domain -> not 1-D.
34 if (getNumGtOneDims(srcShape) > 1) {
35 return rewriter.notifyMatchFailure(
36 op,
"expects non-packed domain to have at most one non-unit dims");
// More than one non-unit inner tile -> not 1-D either.
40 if (getNumGtOneDims(innerPackTileSize) > 1) {
41 return rewriter.notifyMatchFailure(
42 op,
"expects at most one non-unit inner tiles");
// Returns the permutation vector of `linalgOp` when it is a transpose: either
// an explicit linalg.transpose, or a single-input/single-output elementwise
// op whose input/output indexing maps are distinct permutations.
// NOTE(review): the `return failure();` statements after the guard conditions
// are elided in this listing — TODO confirm against the full file.
49 static FailureOr<SmallVector<int64_t>>
50 getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
// Fast path: a real linalg.transpose carries its permutation directly.
51 if (
auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
52 return SmallVector<int64_t>(transposeOp.getPermutation());
// A transpose-like generic must be fully parallel, 1-in/1-out, with
// permutation indexing maps that differ (identity maps are not a transpose).
53 if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
56 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
58 auto mapRange = linalgOp.getIndexingMapsArray();
59 if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
60 mapRange.front() == mapRange.back()) {
// The body must be just the terminator (pure data movement, no compute).
63 if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
65 AffineMap outMap = mapRange.back();
66 AffineMap inMap = mapRange.front();
// Map each output dimension back to its position in the input map to
// recover the permutation vector.
69 return llvm::map_to_vector(outMap.getResults(),
70 [&](AffineExpr expr) -> int64_t {
71 return *inMap.getResultPosition(expr);
// Rewrites a tensor/linalg pack op into a plain tensor.expand_shape when the
// pack does no real data movement (no padding, trivial outer permutation, and
// packing on the innermost dim, a 1-D domain, or a pad-like pack).
76 struct SimplifyPackToExpandShape :
public OpRewritePattern<PackOp> {
// Creates the tensor.expand_shape (or forwards `operand` unchanged when the
// type already matches).
// NOTE(review): the declared result type and part of the signature are
// elided in this listing.
80 insertExpand(RewriterBase &rewriter, Location loc, Value operand,
82 ArrayRef<ReassociationIndices> reassociation)
const {
83 if (operand.getType() == newOperandType)
85 return tensor::ExpandShapeOp::create(rewriter, loc, newOperandType, operand,
// Succeeds iff the pack only tiles the innermost source dimension and does
// not permute outer dims.
91 LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
92 PackOp packOp)
const {
95 return rewriter.notifyMatchFailure(
97 "expects outer_dims_perm is empty or an identity permutation");
// Exactly one inner tile, and it must sit on the last source dim.
100 int64_t srcRank = packOp.getSourceRank();
101 ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
102 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
103 return rewriter.notifyMatchFailure(
104 packOp,
"expects packing at the innermost dimension");
109 LogicalResult matchAndRewrite(PackOp packOp,
110 PatternRewriter &rewriter)
const override {
// Padding introduces new values; an expand_shape cannot express that.
111 if (packOp.getPaddingValue())
112 return rewriter.notifyMatchFailure(packOp,
"expects no padding value");
114 RankedTensorType sourceType = packOp.getSourceType();
// Applicable when ANY of the three shape-preserving conditions holds.
115 if (
failed(isPackOnInnerMostDim(rewriter, packOp)) &&
116 failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
117 packOp.getStaticTiles())) &&
118 !packOp.isLikePad()) {
122 RankedTensorType destType = packOp.getDestType();
// NOTE(review): computation of the reassociation indices is elided here.
127 FailureOr<Value> expanded =
128 insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
131 return rewriter.notifyMatchFailure(
132 packOp,
"unable to expand source of tensor.pack");
134 rewriter.replaceOp(packOp, *expanded);
// Mirror of SimplifyPackToExpandShape: rewrites an unpack op into a
// tensor.collapse_shape when the unpack does no real data movement.
139 struct SimplifyUnPackToCollapseShape :
public OpRewritePattern<UnPackOp> {
// Creates the tensor.collapse_shape (or forwards `operand` unchanged when
// the type already matches).
142 Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
143 Type newOperandType, ArrayAttr reassociation)
const {
144 if (operand.getType() == newOperandType)
146 return tensor::CollapseShapeOp::create(rewriter, loc, newOperandType,
147 operand, reassociation);
// Succeeds iff the unpack only untiles the innermost destination dimension,
// with static shapes and no outer-dim permutation.
151 LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
152 UnPackOp unpackOp)
const {
155 return rewriter.notifyMatchFailure(
157 "expects outer_dims_perm is empty or an identity permutation");
160 RankedTensorType sourceType = unpackOp.getSourceType();
161 RankedTensorType destType = unpackOp.getDestType();
162 if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
163 return rewriter.notifyMatchFailure(unpackOp,
"expects static shapes");
// Exactly one inner tile, positioned on the last destination dim.
165 ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
166 if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
167 return rewriter.notifyMatchFailure(
168 unpackOp,
"expects unpacking on the innermost dimension");
174 LogicalResult matchAndRewrite(UnPackOp unpackOp,
175 PatternRewriter &rewriter)
const override {
176 RankedTensorType destType = unpackOp.getDestType();
// Applicable when ANY of the three shape-preserving conditions holds.
177 if (
failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
178 failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
179 unpackOp.getStaticTiles())) &&
180 !unpackOp.isLikeUnPad()) {
184 RankedTensorType sourceType = unpackOp.getSourceType();
// NOTE(review): computation of the reassociation attribute is elided here.
189 Value collapsed = insertCollapse(
190 rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
192 rewriter.replaceOp(unpackOp, collapsed);
// Folds a producer tensor.pad into a consumer pack op by moving the pad's
// constant padding value onto the pack, when the pack's own implicit padding
// covers at least the pad's high padding on every tiled dimension.
199 struct FoldPadWithPackOp :
public OpRewritePattern<PackOp> {
// NOTE(review): the constructor's parameter list is elided in this listing;
// `controlFn` lets callers veto folding for specific operands.
202 : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
204 LogicalResult matchAndRewrite(PackOp packOp,
205 PatternRewriter &rewriter)
const override {
206 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
// Only nofold-free pads with zero low padding are foldable.
208 if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
212 if (controlFn && !controlFn(&packOp.getSourceMutable()))
// The pad value must be a constant so it can become the pack's pad value.
215 Value constantPaddingValue = padOp.getConstantPaddingValue();
216 if (!constantPaddingValue)
// NOTE(review): the body checking an existing pack padding value against
// the pad's constant is elided here.
219 if (
auto paddingValue = packOp.getPaddingValue())
228 RankedTensorType unpackedType = packOp.getSourceType();
229 SmallVector<int64_t> outerShapeWithoutTranspose =
// Verify, per tiled dimension, that the pack's implicit padding
// (outerDim * tileSize - srcDim) can absorb the pad's high padding.
231 for (
auto [pos, tileSize, high] :
232 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
233 padOp.getMixedHighPad())) {
// Dynamic quantities make the padding size unverifiable -> bail out
// (the elided lines presumably return a match failure).
234 if (unpackedType.isDynamicDim(pos))
236 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
238 if (ShapedType::isDynamic(tileSize))
243 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
244 unpackedType.getDimSize(pos);
246 if (paddingSize + cstHigh.value() >= tileSize)
// Re-create the pack directly on the pad's source, carrying the constant
// padding value; the tensor.pad becomes dead.
250 rewriter.replaceOpWithNewOp<PackOp>(
251 packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
252 packOp.getMixedTiles(), constantPaddingValue,
253 packOp.getOuterDimsPerm());
// Folds a consumer tensor.extract_slice into a producer unpack op: the slice
// is absorbed by unpacking directly into a tensor.empty of the slice's shape.
263 struct FoldUnpackWithExtractSliceOp
264 :
public OpRewritePattern<tensor::ExtractSliceOp> {
266 FoldUnpackWithExtractSliceOp(MLIRContext *context,
268 : OpRewritePattern<tensor::ExtractSliceOp>(context),
269 controlFn(std::move(controlFn)) {}
271 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
272 PatternRewriter &rewriter)
const override {
273 auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
// Optional caller-supplied veto on the fold.
278 if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
// The unpack op itself decides whether this slice is absorbable.
281 if (!unpackOp.canFoldSliceOp(sliceOp))
// Build a destination of the slice's shape and unpack straight into it.
285 Type elementType = unpackOp.getDestType().getElementType();
286 Value output = tensor::EmptyOp::create(
287 rewriter, sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
288 rewriter.replaceOpWithNewOp<UnPackOp>(
289 sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
290 unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
// Applies the first `rank` entries of `permutation` to `inVec`, appending the
// remapped values to `resVec`. Returns false if any of those entries points
// past `rank` (i.e. a tile dimension was mixed with a non-tile dimension).
// NOTE(review): the failure `return false;` and final `return true;` lines
// are elided in this listing — TODO confirm against the full file.
305 static bool checkAndPermute(ArrayRef<int64_t> permutation,
306 ArrayRef<int64_t> inVec,
307 SmallVectorImpl<int64_t> &resVec, int64_t rank) {
309 for (
unsigned int i = 0; i < rank; ++i) {
310 int64_t remappedPosition = permutation[i];
// Positions >= rank belong to the tile dims; mixing is not permitted.
311 if (remappedPosition >= rank)
314 remappedPosition = inVec[remappedPosition];
315 resVec.push_back(remappedPosition);
// Folds a producer pack op into a consumer transpose (linalg.transpose or a
// transpose-like generic) by composing the transpose permutation into the
// pack's outer_dims_perm / inner_dims_pos / inner tiles.
323 struct FoldProducerPackWithConsumerLinalgTransposeOp
324 :
public OpInterfaceRewritePattern<linalg::LinalgOp> {
327 FoldProducerPackWithConsumerLinalgTransposeOp(
329 : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
330 controlFn(std::move(controlFn)) {}
332 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
333 PatternRewriter &rewriter)
const override {
334 auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
340 if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
// Bail out unless the consumer is a (possibly generic) transpose.
343 FailureOr<SmallVector<int64_t>> maybePerm =
344 getTransposeOpPermutation(linalgOp);
349 auto mixedInnerTiles = packOp.getMixedTiles();
351 auto transposePerm = maybePerm.value();
352 SmallVector<int64_t> newOuterDimsPermVec;
353 SmallVector<int64_t> newInnerDimsPosVec;
354 SmallVector<OpFoldResult> newMixedInnerTilesVec;
355 int64_t srcRank = packOp.getSourceRank();
// The first srcRank entries of the transpose permutation act on the outer
// dims; they must not reach into the tile dims.
357 if (!checkAndPermute(transposePerm,
outerDimsPerm, newOuterDimsPermVec,
359 return rewriter.notifyMatchFailure(
361 "Cannot fold in tensor.pack if a tile dimension was transposed "
362 "with a non-tile dimension in linalg.transpose.")
// The trailing entries permute the tile dims: remap inner tiles and the
// inner dims positions accordingly.
365 for (
unsigned int i = srcRank; i < transposePerm.size(); ++i) {
366 int64_t remappedPosition = transposePerm[i] - srcRank;
367 newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
368 newInnerDimsPosVec.push_back(
innerDimsPos[remappedPosition]);
// Build the transposed destination and replace the transpose with a single
// pack that performs both operations.
371 Value output = packOp.createDestinationTensor(
372 rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
373 newInnerDimsPosVec, newOuterDimsPermVec);
375 rewriter.replaceOpWithNewOp<PackOp>(
376 linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
377 newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
// Folds a producer transpose into a consumer pack op: the transpose
// permutation is absorbed into the pack's outer_dims_perm and inner_dims_pos,
// and the pack then reads directly from the transpose's input.
388 struct FoldConsumerPackWithProducerLinalgTransposeOp
389 :
public OpRewritePattern<PackOp> {
392 FoldConsumerPackWithProducerLinalgTransposeOp(
394 : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
396 LogicalResult matchAndRewrite(PackOp packOp,
397 PatternRewriter &rewriter)
const override {
398 auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
403 if (controlFn && !controlFn(&packOp.getSourceMutable()))
// Bail out unless the producer is a (possibly generic) transpose.
406 FailureOr<SmallVector<int64_t>> maybePerm =
407 getTransposeOpPermutation(linalgOp);
411 auto transposePermutation = maybePerm.value();
// NOTE(review): composition of the pack's existing outer_dims_perm with
// the transpose permutation is elided in this listing.
414 SmallVector<int64_t> newInnerDimsPosVec;
415 SmallVector<int64_t> newOuterDimsPermVec =
416 llvm::to_vector(transposePermutation);
// Redirect each inner dim position through the transpose permutation.
424 newInnerDimsPosVec.push_back(transposePermutation[dim]);
// Re-create the pack on the transpose's input with the composed dims.
426 Value output = packOp.createDestinationTensor(
427 rewriter, packOp.getLoc(), linalgOp->getOperand(0),
428 packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
430 rewriter.replaceOpWithNewOp<PackOp>(
431 packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
432 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
// Folds a producer unpack op into a consumer transpose by composing the
// transpose permutation into the unpack's outer/inner dimension attributes.
443 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
444 :
public OpInterfaceRewritePattern<linalg::LinalgOp> {
447 FoldProducerUnPackWithConsumerLinalgTransposeOp(
449 : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
450 controlFn(std::move(controlFn)) {}
452 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
453 PatternRewriter &rewriter)
const override {
454 auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
460 if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
// Bail out unless the consumer is a (possibly generic) transpose.
463 FailureOr<SmallVector<int64_t>> maybePerm =
464 getTransposeOpPermutation(linalgOp);
// NOTE(review): construction of the new outer_dims_perm from the transpose
// permutation is elided in this listing.
470 SmallVector<int64_t> newInnerDimsPosVec;
471 SmallVector<int64_t> newOuterDimsPermVec =
477 newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
// Unpack straight into the transpose's init, replacing the transpose.
483 rewriter.replaceOpWithNewOp<UnPackOp>(
484 linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
485 newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
// Folds a producer transpose into a consumer unpack op: the inverse of the
// transpose permutation is absorbed into the unpack's dimension attributes so
// the unpack reads directly from the transpose's input.
496 struct FoldConsumerUnPackWithProducerLinalgTransposeOp
497 :
public OpRewritePattern<UnPackOp> {
501 FoldConsumerUnPackWithProducerLinalgTransposeOp(
503 : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
505 LogicalResult matchAndRewrite(UnPackOp unPackOp,
506 PatternRewriter &rewriter)
const override {
507 auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
512 if (controlFn && !controlFn(&unPackOp.getSourceMutable()))
// Bail out unless the producer is a (possibly generic) transpose.
515 FailureOr<SmallVector<int64_t>> maybePerm =
516 getTransposeOpPermutation(linalgOp);
// Result dims are reified so a correctly-shaped destination can be built.
520 SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
// NOTE(review): the inversion of the transpose permutation is elided here.
525 SmallVector<int64_t> inverseTransposePerm =
// destRank = packed rank minus the number of tiled dims.
529 int64_t destRank = unPackOp.getSourceRank() -
innerDimsPos.size();
530 auto mixedInnerTilesVec = unPackOp.getMixedTiles();
531 SmallVector<int64_t> newOuterDimsPermVec;
532 SmallVector<int64_t> newInnerDimsPosVec;
533 SmallVector<OpFoldResult> newMixedInnerTilesVec;
// Outer entries of the inverse permutation must not mix with tile dims.
535 newOuterDimsPermVec, destRank))
536 return rewriter.notifyMatchFailure(
538 "Cannot fold in tensor.unpack if a tile dimension was transposed "
539 "with a non-tile dimension in linalg.transpose.")
// Trailing entries permute the tile dims: remap tiles and positions.
542 for (
unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
543 int64_t remappedPosition = inverseTransposePerm[i] - destRank;
544 newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
545 newInnerDimsPosVec.push_back(
innerDimsPos[remappedPosition]);
// Build an empty destination with the reified result shape and replace the
// unpack with one reading the transpose's input directly.
549 cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
550 Value output = tensor::EmptyOp::create(rewriter, unPackOp->getLoc(),
551 unpackOpResultDims[0], elemType);
553 rewriter.replaceOpWithNewOp<UnPackOp>(
554 unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
555 newMixedInnerTilesVec, newOuterDimsPermVec);
// Replaces a pack whose source is tensor.empty with its own dest: packing
// undefined data is a no-op, unless a padding value would define elements.
566 struct FoldEmptyTensorWithPackOp :
public OpRewritePattern<PackOp> {
569 LogicalResult matchAndRewrite(PackOp packOp,
570 PatternRewriter &rewriter)
const override {
572 auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
// A padding value writes defined data, so the pack cannot be dropped.
578 if (packOp.getPaddingValue())
579 return rewriter.notifyMatchFailure(packOp,
"expects no padding value");
582 rewriter.replaceOp(packOp, packOp.getDest());
// Replaces an unpack whose source is tensor.empty with its own dest:
// unpacking undefined data is a no-op.
590 struct FoldEmptyTensorWithUnPackOp :
public OpRewritePattern<UnPackOp> {
593 LogicalResult matchAndRewrite(UnPackOp unPackOp,
594 PatternRewriter &rewriter)
const override {
596 auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
601 rewriter.replaceOp(unPackOp, unPackOp.getDest());
// Pattern registration fragments. Per the reference entries below, these
// presumably belong to populateFoldIntoPackAndUnpackPatterns,
// populateSimplifyPackAndUnpackPatterns, and
// populateFoldPackUnpackIntoTensorEmptyPatterns respectively — the enclosing
// function signatures are elided in this listing.
611 patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
612 FoldProducerPackWithConsumerLinalgTransposeOp,
613 FoldConsumerPackWithProducerLinalgTransposeOp,
614 FoldConsumerUnPackWithProducerLinalgTransposeOp,
615 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
620 patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
626 patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
SmallVector< int64_t > outerDimsPerm
SmallVector< int64_t > innerDimsPos
std::function< bool(OpOperand *opOperand)> ControlFoldIntoPackUnpackFn
Function type which is used to control folding operations like tensor.pad and tensor....
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that simplify tensor.pack and tensor.unpack operations.
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
@ Type
An inlay hint for a type annotation.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
const FrozenRewritePatternSet & patterns
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...