PackAndUnpackPatterns.cpp
//===- PackAndUnpackPatterns.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace linalg {
namespace {

/// Returns the number of shape sizes that are dynamic or greater than 1.
static int64_t getNumGtOneDims(ArrayRef<int64_t> shape) {
  return llvm::count_if(
      shape, [](int64_t v) { return ShapedType::isDynamic(v) || v > 1; });
}

/// Returns success() if at most one dimension size in the non-packed domain
/// is greater than 1 and packing only happens on that dimension.
/// Note: this method should only be used by the pack/unpack-to-reshape
/// conversion patterns. It assumes that a non-unit inner tile size must be
/// used by the non-unit dimension.
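/// E.g., srcShape = {1, 1, 32} with inner tile sizes {16} qualifies, whereas
/// srcShape = {2, 32} (two non-unit dims) or inner tile sizes {2, 16} do not.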
static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
                                ArrayRef<int64_t> srcShape,
                                ArrayRef<int64_t> innerPackTileSize) {
  if (getNumGtOneDims(srcShape) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects non-packed domain to have at most one non-unit dims");
  }
  // A non-unit inner tile size must be used by the non-unit dimension. If not,
  // getting the reassociation maps will fail.
  if (getNumGtOneDims(innerPackTileSize) > 1) {
    return rewriter.notifyMatchFailure(
        op, "expects at most one non-unit inner tiles");
  }
  return success();
}

// If the `linalgOp` represents a transpose, return the permutation vector for
// the transpose. Otherwise, return failure.
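// E.g., for a transpose-like generic with input indexing map
// (d0, d1) -> (d1, d0) and an identity output map, this returns {1, 0}.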
static FailureOr<SmallVector<int64_t>>
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
    return SmallVector<int64_t>(transposeOp.getPermutation());
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return failure();

  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return failure();
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
      mapRange.front() == mapRange.back()) {
    return failure();
  }
  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
    return failure();
  AffineMap outMap = mapRange.back();
  AffineMap inMap = mapRange.front();
  // To get the permutation, look at each output index and find which
  // dimension in the input we're reading from for that index.
  return llvm::map_to_vector(outMap.getResults(),
                             [&](AffineExpr expr) -> int64_t {
                               return *inMap.getResultPosition(expr);
                             });
}

/// Packing a one-dimensional tensor can be expressed as an expand shape op.
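/// As a rough illustration (schematic IR, names and types are illustrative):
///   %packed = linalg.pack %src inner_dims_pos = [0] inner_tiles = [32]
///       into %dest : tensor<256xf32> -> tensor<8x32xf32>
/// becomes
///   %expanded = tensor.expand_shape %src [[0, 1]] output_shape [8, 32]
///       : tensor<256xf32> into tensor<8x32xf32>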
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  FailureOr<Value>
  insertExpand(RewriterBase &rewriter, Location loc, Value operand,
               Type newOperandType,
               ArrayRef<ReassociationIndices> reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return tensor::ExpandShapeOp::create(rewriter, loc, newOperandType, operand,
                                         reassociation)
        .getResult();
  }

  /// Returns success() if it is only packing on the innermost dimension.
  LogicalResult isPackOnInnerMostDim(RewriterBase &rewriter,
                                     PackOp packOp) const {
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          packOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    int64_t srcRank = packOp.getSourceRank();
    ArrayRef<int64_t> dimsPos = packOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != srcRank)) {
      return rewriter.notifyMatchFailure(
          packOp, "expects packing at the innermost dimension");
    }
    return success();
  }

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    RankedTensorType sourceType = packOp.getSourceType();
    if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
        failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
                          packOp.getStaticTiles())) &&
        !packOp.isLikePad()) {
      return failure();
    }

    RankedTensorType destType = packOp.getDestType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    FailureOr<Value> expanded =
        insertExpand(rewriter, packOp.getLoc(), packOp.getSource(), destType,
                     *reassociation);
    if (failed(expanded)) {
      return rewriter.notifyMatchFailure(
          packOp, "unable to expand source of tensor.pack");
    }
    rewriter.replaceOp(packOp, *expanded);
    return success();
  }
};

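/// Unpacking a one-dimensional tensor can be expressed as a collapse shape op.
/// As a rough illustration (schematic IR):
///   %unpacked = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [32]
///       into %dest : tensor<8x32xf32> -> tensor<256xf32>
/// becomes
///   %collapsed = tensor.collapse_shape %src [[0, 1]]
///       : tensor<8x32xf32> into tensor<256xf32>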
struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
                       Type newOperandType, ArrayAttr reassociation) const {
    if (operand.getType() == newOperandType)
      return operand;
    return tensor::CollapseShapeOp::create(rewriter, loc, newOperandType,
                                           operand, reassociation);
  }

  /// Returns success() if it is unpacking on the innermost dimension.
  LogicalResult isUnpackOnInnerMostDim(RewriterBase &rewriter,
                                       UnPackOp unpackOp) const {
    auto outerDimsPerm = unpackOp.getOuterDimsPerm();
    if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
      return rewriter.notifyMatchFailure(
          unpackOp,
          "expects outer_dims_perm is empty or an identity permutation");
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    RankedTensorType destType = unpackOp.getDestType();
    if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
      return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");

    ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
    if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
      return rewriter.notifyMatchFailure(
          unpackOp, "expects unpacking on the innermost dimension");
    }

    return success();
  }

  LogicalResult matchAndRewrite(UnPackOp unpackOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType destType = unpackOp.getDestType();
    if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
        failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
                          unpackOp.getStaticTiles())) &&
        !unpackOp.isLikeUnPad()) {
      return failure();
    }

    RankedTensorType sourceType = unpackOp.getSourceType();
    auto reassociation =
        getReassociationIndicesForReshape(sourceType, destType);
    if (!reassociation)
      return failure();
    Value collapsed = insertCollapse(
        rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
        getReassociationIndicesAttribute(rewriter, *reassociation));
    rewriter.replaceOp(unpackOp, collapsed);
    return success();
  }
};

/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low padding, or if `pack` has no padding value.
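/// As a rough illustration (schematic IR, the pad region is elided):
///   %padded = tensor.pad %src low[0, 0] high[0, 2] { ... tensor.yield %cst }
///       : tensor<64x30xf32> to tensor<64x32xf32>
///   %packed = linalg.pack %padded inner_dims_pos = [1] inner_tiles = [16]
///       into %dest : tensor<64x32xf32> -> tensor<64x2x16xf32>
/// folds to
///   %packed = linalg.pack %src padding_value(%cst : f32)
///       inner_dims_pos = [1] inner_tiles = [16]
///       into %dest : tensor<64x30xf32> -> tensor<64x2x16xf32>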
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
public:
  FoldPadWithPackOp(MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
      : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();

    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
      return failure();

    // User controlled folding function.
    if (controlFn && !controlFn(&packOp.getSourceMutable()))
      return failure();

    Value constantPaddingValue = padOp.getConstantPaddingValue();
    if (!constantPaddingValue)
      return failure();

    if (auto paddingValue = packOp.getPaddingValue())
      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
        return failure();

    // Folding is not allowed if it would introduce artificial padding.
    // Folding is also disabled in the case of dynamic dimensions and/or tile
    // sizes - that is because it would be impossible to compute the padding
    // size and hence to establish whether "artificial" padding would be
    // created.
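    // For example, with a tile size of 16, a padded source dim of 32 (outer
    // dim 2) and a high pad of 2: paddingSize = 2 * 16 - 32 = 0 and
    // 0 + 2 < 16, so folding is allowed; a high pad of 16 or more would be
    // rejected because the total padding would cover a full tile.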
    RankedTensorType unpackedType = packOp.getSourceType();
    SmallVector<int64_t> outerShapeWithoutTranspose =
        getPackedOuterShapeWithoutTransposition(packOp);
    for (auto [pos, tileSize, high] :
         llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
                         padOp.getMixedHighPad())) {
      if (unpackedType.isDynamicDim(pos))
        return failure();
      if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
        return failure();
      if (ShapedType::isDynamic(tileSize))
        return failure();
      std::optional<int64_t> cstHigh = getConstantIntValue(high);
      if (!cstHigh)
        return failure();
      int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
                            unpackedType.getDimSize(pos);
      // Do not fold the op if it requires artificial padding.
      if (paddingSize + cstHigh.value() >= tileSize)
        return failure();
    }

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
        packOp.getMixedTiles(), constantPaddingValue,
        packOp.getOuterDimsPerm());
    return success();
  }

private:
  ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold an `unpack` -> `extract_slice` into the `unpack` since it already
/// has extract_slice semantics.
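/// As a rough illustration (schematic IR):
///   %unpacked = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [32]
///       into %dest : tensor<4x32xf32> -> tensor<128xf32>
///   %slice = tensor.extract_slice %unpacked[0] [100] [1]
///       : tensor<128xf32> to tensor<100xf32>
/// folds to
///   %empty = tensor.empty() : tensor<100xf32>
///   %unpacked = linalg.unpack %src inner_dims_pos = [0] inner_tiles = [32]
///       into %empty : tensor<4x32xf32> -> tensor<100xf32>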
struct FoldUnpackWithExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
public:
  FoldUnpackWithExtractSliceOp(MLIRContext *context,
                               ControlFoldIntoPackUnpackFn controlFn)
      : OpRewritePattern<tensor::ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
    if (!unpackOp)
      return failure();

    // User controlled folding function.
    if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
      return failure();

    if (!unpackOp.canFoldSliceOp(sliceOp))
      return failure();

    // Create a new empty output tensor.
    Type elementType = unpackOp.getDestType().getElementType();
    Value output = tensor::EmptyOp::create(
        rewriter, sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
    rewriter.replaceOpWithNewOp<UnPackOp>(
        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
    return success();
  }

private:
  ControlFoldIntoPackUnpackFn controlFn;
};

// Applies 'permutation' on 'inVec' and stores the result in 'resVec'.
// 'inVec' may be empty, in which case the mapping is one-to-one with the
// permutation. `rank` sets the boundary for the permutation, i.e., the
// permutation values in the range [0, rank) must themselves be smaller than
// `rank`; if not, return false.
// E.g., permutation {1, 0, 3, 2} with rank 2 is allowed since the values in
// permutation[:rank] do not exceed the rank, whereas permutation {1, 3, 0, 2}
// is not allowed since `3` exceeds the rank in the given range.
static bool checkAndPermute(ArrayRef<int64_t> permutation,
                            ArrayRef<int64_t> inVec,
                            SmallVectorImpl<int64_t> &resVec, int64_t rank) {

  for (unsigned int i = 0; i < rank; ++i) {
    int64_t remappedPosition = permutation[i];
    if (remappedPosition >= rank)
      return false;
    if (!inVec.empty())
      remappedPosition = inVec[remappedPosition];
    resVec.push_back(remappedPosition);
  }

  return true;
}

/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
/// semantics.
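/// As a rough illustration (schematic IR):
///   %packed = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
///       into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
///   %transposed = linalg.transpose ins(%packed : tensor<16x8x8x32xf32>)
///       outs(%init : tensor<8x16x8x32xf32>) permutation = [1, 0, 2, 3]
/// folds to
///   %packed = linalg.pack %src outer_dims_perm = [1, 0]
///       inner_dims_pos = [0, 1] inner_tiles = [8, 32]
///       into %newDest : tensor<128x256xf32> -> tensor<8x16x8x32xf32>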
struct FoldProducerPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {

public:
  FoldProducerPackWithConsumerLinalgTransposeOp(
      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
      : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();

    if (!packOp)
      return failure();

    // User controlled folding function.
    if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto innerDimsPos = packOp.getInnerDimsPos();
    auto mixedInnerTiles = packOp.getMixedTiles();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    const auto &transposePerm = maybePerm.value();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    int64_t srcRank = packOp.getSourceRank();

    if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                         srcRank))
      return rewriter.notifyMatchFailure(
          linalgOp,
          "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process the transpose permutation for the tiled inner dimensions.
    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
      int64_t remappedPosition = transposePerm[i] - srcRank;
      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    Value output = packOp.createDestinationTensor(
        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
        newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }

private:
  ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
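/// As a rough illustration (schematic IR):
///   %transposed = linalg.transpose ins(%src : tensor<256x128xf32>)
///       outs(%init : tensor<128x256xf32>) permutation = [1, 0]
///   %packed = linalg.pack %transposed inner_dims_pos = [0, 1]
///       inner_tiles = [8, 32] into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
/// folds to
///   %packed = linalg.pack %src outer_dims_perm = [1, 0] inner_dims_pos = [1, 0]
///       inner_tiles = [8, 32] into %newDest : tensor<256x128xf32> -> tensor<16x8x8x32xf32>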
struct FoldConsumerPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<PackOp> {

public:
  FoldConsumerPackWithProducerLinalgTransposeOp(
      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
      : OpRewritePattern<PackOp>(context), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    // User controlled folding function.
    if (controlFn && !controlFn(&packOp.getSourceMutable()))
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto transposePermutation = maybePerm.value();
    auto outerDimsPerm = packOp.getOuterDimsPerm();
    auto innerDimsPos = packOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        llvm::to_vector(transposePermutation);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Can't use applyPermutationToVector for newInnerDimsPosVec since the
    // input and permutation ranks won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(transposePermutation[dim]);

    Value output = packOp.createDestinationTensor(
        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
        packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);

    rewriter.replaceOpWithNewOp<PackOp>(
        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);

    return success();
  }

private:
  ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
/// transpose semantics.
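/// As a rough illustration (schematic IR):
///   %unpacked = linalg.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
///       into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
///   %transposed = linalg.transpose ins(%unpacked : tensor<128x256xf32>)
///       outs(%init : tensor<256x128xf32>) permutation = [1, 0]
/// folds to
///   %unpacked = linalg.unpack %src outer_dims_perm = [1, 0] inner_dims_pos = [1, 0]
///       inner_tiles = [8, 32] into %init : tensor<16x8x8x32xf32> -> tensor<256x128xf32>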
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {

public:
  FoldProducerUnPackWithConsumerLinalgTransposeOp(
      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
      : OpInterfaceRewritePattern<linalg::LinalgOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();

    if (!unPackOp)
      return failure();

    // User controlled folding function.
    if (controlFn && !controlFn(&linalgOp->getOpOperand(0)))
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<int64_t> newOuterDimsPermVec =
        invertPermutationVector(maybePerm.value());

    // Can't use applyPermutationToVector for newInnerDimsPosVec since the
    // input and permutation ranks won't necessarily be equal in all cases.
    for (auto dim : innerDimsPos)
      newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);

    if (!outerDimsPerm.empty())
      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

    // Reuse the destination of the transpose op.
    rewriter.replaceOpWithNewOp<UnPackOp>(
        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
        newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);

    return success();
  }

private:
  ControlFoldIntoPackUnpackFn controlFn;
};

/// Fold 'transpose' -> 'unpack' into 'unpack' since 'unpack' already has
/// transpose semantics.
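/// As a rough illustration (schematic IR):
///   %transposed = linalg.transpose ins(%src : tensor<8x16x8x32xf32>)
///       outs(%init : tensor<16x8x8x32xf32>) permutation = [1, 0, 2, 3]
///   %unpacked = linalg.unpack %transposed inner_dims_pos = [0, 1]
///       inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
/// folds to
///   %unpacked = linalg.unpack %src outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
///       inner_tiles = [8, 32] into %empty : tensor<8x16x8x32xf32> -> tensor<128x256xf32>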
struct FoldConsumerUnPackWithProducerLinalgTransposeOp
    : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

public:
  FoldConsumerUnPackWithProducerLinalgTransposeOp(
      MLIRContext *context, ControlFoldIntoPackUnpackFn controlFn)
      : OpRewritePattern<UnPackOp>(context), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
    if (!linalgOp)
      return failure();

    // User controlled folding function.
    if (controlFn && !controlFn(&unPackOp.getSourceMutable()))
      return failure();

    FailureOr<SmallVector<int64_t>> maybePerm =
        getTransposeOpPermutation(linalgOp);
    if (failed(maybePerm))
      return failure();

    SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
    if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
      return failure();
    }

    SmallVector<int64_t> inverseTransposePerm =
        invertPermutationVector(maybePerm.value());
    auto outerDimsPerm = unPackOp.getOuterDimsPerm();
    auto innerDimsPos = unPackOp.getInnerDimsPos();
    int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
    auto mixedInnerTilesVec = unPackOp.getMixedTiles();
    SmallVector<int64_t> newOuterDimsPermVec;
    SmallVector<int64_t> newInnerDimsPosVec;
    SmallVector<OpFoldResult> newMixedInnerTilesVec;
    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                         newOuterDimsPermVec, destRank))
      return rewriter.notifyMatchFailure(
          unPackOp,
          "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");

    // Process the transpose permutation for the tiled inner dimensions.
    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
      newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
    }

    auto elemType =
        cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
    Value output = tensor::EmptyOp::create(rewriter, unPackOp->getLoc(),
                                           unpackOpResultDims[0], elemType);

    rewriter.replaceOpWithNewOp<UnPackOp>(
        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
        newMixedInnerTilesVec, newOuterDimsPermVec);

    return success();
  }

private:
  ControlFoldIntoPackUnpackFn controlFn;
};

/// tensor.empty does not define any tensor contents, so an unpadded pack
/// can be folded away.
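/// E.g. (schematic IR), a pack whose source is a tensor.empty:
///   %e = tensor.empty() : tensor<128x256xf32>
///   %packed = linalg.pack %e inner_dims_pos = [0, 1] inner_tiles = [8, 32]
///       into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
/// is simply replaced by %dest.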
struct FoldEmptyTensorWithPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp packOp,
                                PatternRewriter &rewriter) const override {
    // Check for tensor.empty source.
    auto emptyOp = packOp.getSource().getDefiningOp<tensor::EmptyOp>();
    if (!emptyOp)
      return failure();

    // Check for padding.
    // Packing with padding cannot be simply removed.
    if (packOp.getPaddingValue())
      return rewriter.notifyMatchFailure(packOp, "expects no padding value");

    // Replace the pack directly with its destination.
    rewriter.replaceOp(packOp, packOp.getDest());

    return success();
  }
};

/// tensor.empty does not define any tensor contents, so an unpack
/// can be folded away.
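/// E.g. (schematic IR):
///   %e = tensor.empty() : tensor<16x8x8x32xf32>
///   %unpacked = linalg.unpack %e inner_dims_pos = [0, 1] inner_tiles = [8, 32]
///       into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
/// is simply replaced by %dest.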
struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    // Check for tensor.empty source.
    auto emptyOp = unPackOp.getSource().getDefiningOp<tensor::EmptyOp>();
    if (!emptyOp)
      return failure();

    // Replace the unpack directly with its destination.
    rewriter.replaceOp(unPackOp, unPackOp.getDest());

    return success();
  }
};

} // namespace

void populateFoldIntoPackAndUnpackPatterns(
    RewritePatternSet &patterns, const ControlFoldIntoPackUnpackFn &controlFn) {
  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
                  FoldProducerPackWithConsumerLinalgTransposeOp,
                  FoldConsumerPackWithProducerLinalgTransposeOp,
                  FoldConsumerUnPackWithProducerLinalgTransposeOp,
                  FoldProducerUnPackWithConsumerLinalgTransposeOp>(
      patterns.getContext(), controlFn);
}

void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
  patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
      patterns.getContext());
}

void populateFoldPackUnpackIntoTensorEmptyPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldEmptyTensorWithPackOp, FoldEmptyTensorWithUnPackOp>(
      patterns.getContext());
}

} // namespace linalg
} // namespace mlir