MLIR 23.0.0git
PackAndUnpackPatterns.cpp
Go to the documentation of this file.
//===- PackAndUnpackPatterns.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16
17namespace mlir {
18namespace linalg {
19namespace {
20
21/// Returns the number of shape sizes that is either dynamic or greater than 1.
22static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
23 return llvm::count_if(
24 shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
25}
26
27/// Returns success() if there is only 1 dimension size in non-packed domain
28/// being greater than 1 and packing only happens on the dimension.
29/// Note: this method should only be used by pack/unpack to reshape conversion.
30/// It assumes that non-unit inner tile size must be used by the non-unit
31/// dimension.
32static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
33 ArrayRef<int64_t> srcShape,
34 ArrayRef<int64_t> innerPackTileSize) {
35 if (getNumGtOneDims(srcShape) > 1) {
36 return rewriter.notifyMatchFailure(
37 op, "expects non-packed domain to have at most one non-unit dims");
38 }
39 // Non-unit inner tile size must be used by the non-unit dimension. If not, it
40 // will faill on getting reassociation maps.
41 if (getNumGtOneDims(innerPackTileSize) > 1) {
42 return rewriter.notifyMatchFailure(
43 op, "expects at most one non-unit inner tiles");
44 }
45 return success();
46}
47
// If the `linalgOp` represents a transpose, return the permutation vector for
// the transpose. Otherwise, return failure.
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  // A named linalg.transpose carries its permutation directly.
  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  // Otherwise detect a transpose-like op: all loops must be parallel.
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();

  // Exactly one input operand and one init operand.
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  auto mapRange = linalgOp.getIndexingMapsArray();
  // Both indexing maps must be permutations, and they must differ (identical
  // maps would describe a copy, not a transpose).
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  // The body must hold a single operation (the terminator), i.e. the op moves
  // elements without computing on them.
  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // To get the permutation, look at each output index and find which
  // dimension in the input we're reading from for that index.
  return llvm::map_to_vector(outMap.getResults(),
                             [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(expr);
                             });
}
75
/// Packing one-dimensional tensor can be expressed as an expand shape op.
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  /// Wraps `operand` in a tensor.expand_shape to `newOperandType` using the
  /// given reassociation, or returns `operand` unchanged when the type
  /// already matches.
  FailureOr<Value>
  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
               Type newOperandType,
               ArrayRef<ReassociationIndices> reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return tensor::ExpandShapeOp::create(rewriter, loc, newOperandType, operand,
                                         reassociation)
        .getResult();
  }

  /// Returns success() if it is only packing on the innermost dimension.
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    // A non-identity outer permutation reorders dimensions, which a plain
    // expand_shape cannot express.
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    // Require a single inner tile that tiles the innermost source dimension.
    int64_t srcRank = packOp.getSourceRank();
    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    // Padding writes values, which a pure reshape cannot express.
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");
    // TODO: Support Memref PackOp. Temporarily return failure.
    if (!packOp.hasPureTensorSemantics())
      return failure();

    // The rewrite applies when the pack acts on the innermost dim, on an
    // effectively one-dimensional domain, or is a pad-like pack.
    ShapedType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

    // The reassociation proves the pack is expressible as a reshape.
    ShapedType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    FailureOr<Value> expanded =
        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
                     *reassociation);
    if (failed(expanded)) {
      return rewriter.notifyMatchFailure(
          packOp, "unable to expand source of tensor.pack");
    }
    rewriter.replaceOp(packOp, *expanded);
    return success();
  }
};
142
/// Unpacking that amounts to a pure reshape (innermost-dim unpack, an
/// effectively 1-D domain, or an unpad-like unpack) can be expressed as a
/// collapse shape op.
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  /// Wraps `operand` in a tensor.collapse_shape to `newOperandType` using the
  /// given reassociation, or returns `operand` unchanged when the type
  /// already matches.
  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return tensor::CollapseShapeOp::create(rewriter, loc, newOperandType,
                                           operand, reassociation);
  }

  /// Returns success() if it is unpacking on the innermost dimension.
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    // A non-identity outer permutation reorders dimensions, which a plain
    // collapse_shape cannot express.
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    ShapedType sourceType = unpackOp.getSourceType();
    ShapedType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    // Require a single inner tile that unpacks into the innermost dest dim.
    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }

    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    // TODO: Support Memref UnPackOp. Temporarily return failure.
    if (!unpackOp.hasPureTensorSemantics())
      return failure();

    // The rewrite applies when the unpack acts on the innermost dim, on an
    // effectively one-dimensional domain, or is an unpad-like unpack.
    ShapedType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

    // The reassociation proves the unpack is expressible as a reshape.
    ShapedType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};
204
205/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
206/// the pad op has zero low paddings, or if `pack` has no padding values.
207struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
208public:
209 FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
210 : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
211
212 LogicalResult matchAndRewrite(PackOp packOp,
213 PatternRewriter &rewriter) const override {
214 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
215
216 if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
217 return failure();
218
219 // User controlled folding function.
220 if (controlFn && !controlFn(&packOp.getSourceMutable()))
221 return failure();
222
223 Value constantPaddingValue = padOp.getConstantPaddingValue();
224 if (!constantPaddingValue)
225 return failure();
226
227 if (auto paddingValue = packOp.getPaddingValue())
228 if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
229 return failure();
230
231 // Folding is not allowed if it were to introduce artificial padding.
232 // Folding is also disabled in the case of dynamic dimensions and/or tile
233 // sizes - that is because it would be impossible to compute the padding
234 // size and hence to establish whether "artificial" padding would be
235 // created.
236 ShapedType unpackedType = packOp.getSourceType();
237 SmallVector<int64_t> outerShapeWithoutTranspose =
239 for (auto [pos, tileSize, high] :
240 llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
241 padOp.getMixedHighPad())) {
242 if (unpackedType.isDynamicDim(pos))
243 return failure();
244 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
245 return failure();
246 if (ShapedType::isDynamic(tileSize))
247 return failure();
248 std::optional<int64_t> cstHigh = getConstantIntValue(high);
249 if (!cstHigh)
250 return failure();
251 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
252 unpackedType.getDimSize(pos);
253 // Do not fold the op if it requires artificial padding.
254 if (paddingSize + cstHigh.value() >= tileSize)
255 return failure();
256 }
257
258 rewriter.replaceOpWithNewOp<PackOp>(
259 packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
260 packOp.getMixedTiles(), constantPaddingValue,
261 packOp.getOuterDimsPerm());
262 return success();
263 }
264
265private:
267};
268
269/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
270/// has extract_slice semantics.
271struct FoldUnpackWithExtractSliceOp
272 : public OpRewritePattern<tensor::ExtractSliceOp> {
273public:
274 FoldUnpackWithExtractSliceOp(MLIRContext *context,
276 : OpRewritePattern<tensor::ExtractSliceOp>(context),
277 controlFn(std::move(controlFn)) {}
278
279 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
280 PatternRewriter &rewriter) const override {
281 auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
282 if (!unpackOp)
283 return failure();
284
285 // TODO: Support Memref UnPackOp. Temporarily return failure.
286 if (!unpackOp.hasPureTensorSemantics())
287 return failure();
288
289 // User controlled folding function.
290 if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
291 return failure();
292
293 if (!unpackOp.canFoldSliceOp(sliceOp))
294 return failure();
295
296 // Create a new empty output tensor.
297 Type elementType = unpackOp.getDestType().getElementType();
298 Value output = tensor::EmptyOp::create(
299 rewriter, sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
300 rewriter.replaceOpWithNewOp<UnPackOp>(
301 sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
302 unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
303 return success();
304 }
305
306private:
308};
309
310// Applies 'permutation' on 'inVec' and stores the result in resVec.
311// 'inVec' may be empty, in that case it's one-to-one mapping with permutation.
312// `rank` sets the boundary for permutation i.e., the permutation dim can't be
313// greater than the rank specified. If it's so then return false.
314// For e.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
315// permutation[:rank] doesn't exceed rank, whereas, permutation {1, 3, 0, 2} is
316// not allowed since `3` exceeds the value of the rank in the given range.
317static bool checkAndPermute(ArrayRef<int64_t> permutation,
318 ArrayRef<int64_t> inVec,
319 SmallVectorImpl<int64_t> &resVec, int64_t rank) {
320
321 for (unsigned int i = 0; i < rank; ++i) {
322 int64_t remappedPosition = permutation[i];
323 if (remappedPosition >= rank)
324 return false;
325 if (!inVec.empty())
326 remappedPosition = inVec[remappedPosition];
327 resVec.push_back(remappedPosition);
328 }
329
330 return true;
331}
332
333/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
334/// semantics.
335struct FoldProducerPackWithConsumerLinalgTransposeOp
336 : public OpInterfaceRewritePattern<linalg::LinalgOp> {
337
338public:
339 FoldProducerPackWithConsumerLinalgTransposeOp(
340 MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
341 : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
342 controlFn(std::move(controlFn)) {}
343
344 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
345 PatternRewriter &rewriter) const override {
346 auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
347
348 if (!packOp)
349 return failure();
350
351 // TODO: Support Memref PackOp. Temporarily return failure.
352 if (!packOp.hasPureTensorSemantics())
353 return failure();
354
355 // User controlled folding function.
356 if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
357 return failure();
358
359 FailureOr<SmallVector<int64_t>> maybePerm =
360 getTransposeOpPermutation(linalgOp);
361 if (failed(maybePerm))
362 return failure();
363
364 auto innerDimsPos = packOp.getInnerDimsPos();
365 auto mixedInnerTiles = packOp.getMixedTiles();
366 auto outerDimsPerm = packOp.getOuterDimsPerm();
367 const auto &transposePerm = maybePerm.value();
368 SmallVector<int64_t> newOuterDimsPermVec;
369 SmallVector<int64_t> newInnerDimsPosVec;
370 SmallVector<OpFoldResult> newMixedInnerTilesVec;
371 int64_t srcRank = packOp.getSourceRank();
372
373 if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
374 srcRank))
375 return rewriter.notifyMatchFailure(
376 linalgOp,
377 "Cannot fold in tensor.pack if a tile dimension was transposed "
378 "with a non-tile dimension in linalg.transpose.");
379
380 // Process transpose operation for tiled inner dimensions
381 for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
382 int64_t remappedPosition = transposePerm[i] - srcRank;
383 newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
384 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
385 }
386
387 Value output = packOp.createDestinationTensor(
388 rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
389 newInnerDimsPosVec, newOuterDimsPermVec);
390
391 rewriter.replaceOpWithNewOp<PackOp>(
392 linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
393 newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
394
395 return success();
396 }
397
398private:
400};
401
402/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
403/// semantics.
404struct FoldConsumerPackWithProducerLinalgTransposeOp
405 : public OpRewritePattern<PackOp> {
406
407public:
408 FoldConsumerPackWithProducerLinalgTransposeOp(
409 MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
410 : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}
411
412 LogicalResult matchAndRewrite(PackOp packOp,
413 PatternRewriter &rewriter) const override {
414 // TODO: Support Memref PackOp. Temporarily return failure.
415 if (!packOp.hasPureTensorSemantics())
416 return failure();
417
418 auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
419 if (!linalgOp)
420 return failure();
421
422 // User controlled folding function.
423 if (controlFn && !controlFn(&packOp.getSourceMutable()))
424 return failure();
425
426 FailureOr<SmallVector<int64_t>> maybePerm =
427 getTransposeOpPermutation(linalgOp);
428 if (failed(maybePerm))
429 return failure();
430
431 auto transposePermutation = maybePerm.value();
432 auto outerDimsPerm = packOp.getOuterDimsPerm();
433 auto innerDimsPos = packOp.getInnerDimsPos();
434 SmallVector<int64_t> newInnerDimsPosVec;
435 SmallVector<int64_t> newOuterDimsPermVec =
436 llvm::to_vector(transposePermutation);
437
438 if (!outerDimsPerm.empty())
439 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
440
441 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
442 // permutation rank won't necessarily be equal in all cases.
443 for (auto dim : innerDimsPos)
444 newInnerDimsPosVec.push_back(transposePermutation[dim]);
445
446 Value output = packOp.createDestinationTensor(
447 rewriter, packOp.getLoc(), linalgOp->getOperand(0),
448 packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
449
450 rewriter.replaceOpWithNewOp<PackOp>(
451 packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
452 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
453
454 return success();
455 }
456
457private:
459};
460
461/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
462/// transpose semantics.
463struct FoldProducerUnPackWithConsumerLinalgTransposeOp
464 : public OpInterfaceRewritePattern<linalg::LinalgOp> {
465
466public:
467 FoldProducerUnPackWithConsumerLinalgTransposeOp(
468 MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
469 : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
470 controlFn(std::move(controlFn)) {}
471
472 LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
473 PatternRewriter &rewriter) const override {
474 auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
475
476 if (!unPackOp)
477 return failure();
478
479 // TODO: Support Memref UnPackOp. Temporarily return failure.
480 if (!unPackOp.hasPureTensorSemantics())
481 return failure();
482
483 // User controlled folding function.
484 if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
485 return failure();
486
487 FailureOr<SmallVector<int64_t>> maybePerm =
488 getTransposeOpPermutation(linalgOp);
489 if (failed(maybePerm))
490 return failure();
491
492 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
493 auto innerDimsPos = unPackOp.getInnerDimsPos();
494 SmallVector<int64_t> newInnerDimsPosVec;
495 SmallVector<int64_t> newOuterDimsPermVec =
496 invertPermutationVector(maybePerm.value());
497
498 // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
499 // permutation rank won't necessarily be equal in all cases.
500 for (auto dim : innerDimsPos)
501 newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
502
503 if (!outerDimsPerm.empty())
504 applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
505
506 // Reuse the destination of the transpose op.
507 rewriter.replaceOpWithNewOp<UnPackOp>(
508 linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
509 newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
510
511 return success();
512 }
513
514private:
516};
517
518/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
519/// transpose semantics.
520struct FoldConsumerUnPackWithProducerLinalgTransposeOp
521 : public OpRewritePattern<UnPackOp> {
522 using OpRewritePattern<UnPackOp>::OpRewritePattern;
523
524public:
525 FoldConsumerUnPackWithProducerLinalgTransposeOp(
526 MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
527 : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}
528
529 LogicalResult matchAndRewrite(UnPackOp unPackOp,
530 PatternRewriter &rewriter) const override {
531 // TODO: Support Memref UnPackOp. Temporarily return failure.
532 if (!unPackOp.hasPureTensorSemantics())
533 return failure();
534
535 auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
536 if (!linalgOp)
537 return failure();
538
539 // User controlled folding function.
540 if (controlFn && !controlFn(&unPackOp.getSourceMutable()))
541 return failure();
542
543 FailureOr<SmallVector<int64_t>> maybePerm =
544 getTransposeOpPermutation(linalgOp);
545 if (failed(maybePerm))
546 return failure();
547
548 SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
549 if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
550 return failure();
551 }
552
553 SmallVector<int64_t> inverseTransposePerm =
554 invertPermutationVector(maybePerm.value());
555 auto outerDimsPerm = unPackOp.getOuterDimsPerm();
556 auto innerDimsPos = unPackOp.getInnerDimsPos();
557 int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
558 auto mixedInnerTilesVec = unPackOp.getMixedTiles();
559 SmallVector<int64_t> newOuterDimsPermVec;
560 SmallVector<int64_t> newInnerDimsPosVec;
561 SmallVector<OpFoldResult> newMixedInnerTilesVec;
562 if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
563 newOuterDimsPermVec, destRank))
564 return rewriter.notifyMatchFailure(
565 unPackOp,
566 "Cannot fold in tensor.unpack if a tile dimension was transposed "
567 "with a non-tile dimension in linalg.transpose.");
568
569 // Process transpose operation for tiled inner dimensions
570 for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
571 int64_t remappedPosition = inverseTransposePerm[i] - destRank;
572 newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
573 newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
574 }
575
576 auto elemType =
577 cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
578 Value output = tensor::EmptyOp::create(rewriter, unPackOp->getLoc(),
579 unpackOpResultDims[0], elemType);
580
581 rewriter.replaceOpWithNewOp<UnPackOp>(
582 unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
583 newMixedInnerTilesVec, newOuterDimsPermVec);
584
585 return success();
586 }
587
588private:
590};
591
592/// tensor.empty does not define any tensor contents, so an unpadded pack
593/// can be folded away.
594struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
595 using OpRewritePattern<PackOp>::OpRewritePattern;
596
597 LogicalResult matchAndRewrite(PackOp packOp,
598 PatternRewriter &rewriter) const override {
599 // TODO: Support Memref PackOp. Temporarily return failure.
600 if (!packOp.hasPureTensorSemantics())
601 return failure();
602
603 // Check for tensor.empty source.
604 auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
605 if (!emptyOp)
606 return failure();
607
608 // Check for padding.
609 // Packing with padding cannot be simply removed.
610 if (packOp.getPaddingValue())
611 return rewriter.notifyMatchFailure(packOp, "expects no padding value");
612
613 // Replace the pack directly with its destination.
614 rewriter.replaceOp(packOp, packOp.getDest());
615
616 return success();
617 }
618};
619
620/// tensor.empty does not define any tensor contents, so an unpack
621/// can be folded away.
622struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
623 using OpRewritePattern<UnPackOp>::OpRewritePattern;
624
625 LogicalResult matchAndRewrite(UnPackOp unPackOp,
626 PatternRewriter &rewriter) const override {
627 // TODO: Support Memref UnPackOp. Temporarily return failure.
628 if (!unPackOp.hasPureTensorSemantics())
629 return failure();
630
631 // Check for tensor.empty source.
632 auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
633 if (!emptyOp)
634 return failure();
635
636 // Replace the unpack directly with its destination.
637 rewriter.replaceOp(unPackOp, unPackOp.getDest());
638
639 return success();
640 }
641};
642
643} // namespace
644
647 patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
648 FoldProducerPackWithConsumerLinalgTransposeOp,
649 FoldConsumerPackWithProducerLinalgTransposeOp,
650 FoldConsumerUnPackWithProducerLinalgTransposeOp,
651 FoldProducerUnPackWithConsumerLinalgTransposeOp>(
652 patterns.getContext(), controlFn);
653}
654
656 patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
657 patterns.getContext());
658}
659
662 patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
663 patterns.getContext());
664}
665
666} // namespace linalg
667} // namespace mlir
return success()
ArrayAttr()
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that simplify tensor.pack and tensor.unpack operations.
void populateFoldPackUnpackIntoTensorEmptyPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like linalg.pack and linalg....
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn=nullptr)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
std::function< bool(OpOperand *opOperand)> ControlFoldIntoPackUnpackFn
Function type which is used to control folding operations like tensor.pad and tensor....
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
std::optional< SmallVector< ReassociationIndices > > getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType)
Return the reassociations maps to use to reshape given the source type and the target type when possi...
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
const FrozenRewritePatternSet & patterns
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.