MLIR 23.0.0git
FoldMemRefAliasOps.cpp
Go to the documentation of this file.
1//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass folds loading/storing from/to subview ops into
10// loading/storing from/to the original memref.
11//
12//===----------------------------------------------------------------------===//
13
#include "mlir/Dialect/MemRef/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <cstdint>
33
34#define DEBUG_TYPE "fold-memref-alias-ops"
35#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
36
37namespace mlir {
38namespace memref {
39#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
40#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
41} // namespace memref
42} // namespace mlir
43
44using namespace mlir;
45
46//===----------------------------------------------------------------------===//
47// Utility functions
48//===----------------------------------------------------------------------===//
49
50/// Deterimine if the last N indices of `reassocitaion` are trivial - that is,
51/// check if they all contain exactly one dimension to collape/expand into.
52static bool
54 int64_t n) {
55 if (n <= 0)
56 return true;
57 return llvm::all_of(
58 reassocs.take_back(n),
59 [&](const ReassociationIndices &indices) { return indices.size() == 1; });
60}
61
62static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
63 if (n <= 0)
64 return true;
65 return llvm::all_of(subview.getStaticStrides().take_back(n),
66 [](int64_t s) { return s == 1; });
67}
68
/// Helpers to access the memref operand for each op.
/// Generic fallback for ops exposing a `getMemref()` accessor.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

// Vector transfer/load/store ops expose the memref as `getBase()`.
static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getBase();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getBase();
}

// GPU MMA matrix ops name the memref by direction (src for loads, dst for
// stores).
static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}
102
103//===----------------------------------------------------------------------===//
104// Patterns
105//===----------------------------------------------------------------------===//
106
namespace {
/// Merges a subview operation with a load/transferRead operation, rebasing the
/// load's indices onto the subview's source memref.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges an expand_shape operation with a load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a collapse_shape operation with a load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a subview operation with a store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges an expand_shape operation with a store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a collapse_shape operation with a store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
167
/// Folds subview(subview(x)) to a single subview(x) by composing the offsets
/// and sizes of the two subviews. Restricted to unit-stride subviews so that
/// offset composition stays a plain addition.
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims: sizes of the inner subview must
    // be re-expanded over dimensions rank-reduced away by the outer one.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op. The result type is kept so consumers are unchanged.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};
210
/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on both the src and dst memrefs of the copy.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operations with load/store like operations unless such a
/// merger would cause the strides between dimensions accessed by that operation
/// to change.
struct AccessOpOfSubViewOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merge a memref.expand_shape operation with an operation that accesses a
/// memref by index unless that operation accesses more than one dimension of
/// memory and any dimension other than the outermost dimension accessed this
/// way would be merged. This prevents issues from arising with, say, a
/// vector.load of a 4x2 vector having the two trailing dimensions of the access
/// get merged.
struct AccessOpOfExpandShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges an operation that accesses a memref by index with a
/// memref.collapse_shape, unless this would break apart a dimension other than
/// the outermost one that an operation accesses. This prevents, for example,
/// transforming a load of a 3x8 vector from a 6x8 memref into a load
/// from a 3x4x2 memref (as this would require special handling and could lead
/// to invalid IR if that higher-dimensional memref comes from a subview) but
/// does permit turning a load of a length-8 vector from a 3x8 memref into a
/// load from a 3x2x8 one.
struct AccessOpOfCollapseShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.subview operations present on the source or destination
/// operands of indexed memory copy operations (DMA operations) into those
/// operations. This is performed unconditionally, since folding in a subview
/// cannot change the starting position of the copy, which is what the
/// memref/index pair represent in DMA operations.
struct IndexedMemCopyOpOfSubViewOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.expand_shape operations that are present on the source or
/// destination of an indexed memory copy/DMA into the memref/index arguments of
/// that DMA. As with subviews, this can be done unconditionally.
struct IndexedMemCopyOpOfExpandShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.collapse_shape operations that are present on the source or
/// destination of an indexed memory copy/DMA into the memref/index arguments of
/// that DMA. As with subviews, this can be done unconditionally.
struct IndexedMemCopyOpOfCollapseShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};
} // namespace
298
299template <typename XferOp>
300static LogicalResult
302 memref::SubViewOp subviewOp) {
303 static_assert(
304 !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value,
305 "must be a vector transfer op");
306 if (xferOp.hasOutOfBoundsDim())
307 return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
308 if (!subviewOp.hasUnitStride()) {
309 return rewriter.notifyMatchFailure(
310 xferOp, "non-1 stride subview, need to track strides in folded memref");
311 }
312 return success();
313}
314
/// Default precondition: non-transfer ops have no extra requirements for
/// subview folding.
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

/// Transfer-read overload: delegates to the shared transfer-op checks.
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

/// Transfer-write overload: delegates to the shared transfer-op checks.
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}
332
333template <typename OpTy>
334LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
335 OpTy loadOp, PatternRewriter &rewriter) const {
336 auto subViewOp =
337 getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
338
339 if (!subViewOp)
340 return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
341
342 LogicalResult preconditionResult =
343 preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
344 if (failed(preconditionResult))
345 return preconditionResult;
346
347 SmallVector<Value> sourceIndices;
349 rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
350 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
351 loadOp.getIndices(), sourceIndices);
352
354 .Case([&](memref::LoadOp op) {
355 rewriter.replaceOpWithNewOp<memref::LoadOp>(
356 loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
357 })
358 .Case([&](vector::LoadOp op) {
359 rewriter.replaceOpWithNewOp<vector::LoadOp>(
360 op, op.getType(), subViewOp.getSource(), sourceIndices);
361 })
362 .Case([&](vector::MaskedLoadOp op) {
363 rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
364 op, op.getType(), subViewOp.getSource(), sourceIndices,
365 op.getMask(), op.getPassThru());
366 })
367 .Case([&](vector::TransferReadOp op) {
368 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
369 op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
370 AffineMapAttr::get(expandDimsToRank(
371 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
372 subViewOp.getDroppedDims())),
373 op.getPadding(), op.getMask(), op.getInBoundsAttr());
374 })
375 .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
376 rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
377 op, op.getType(), subViewOp.getSource(), sourceIndices,
378 op.getLeadDimension(), op.getTransposeAttr());
379 })
380 .Case([&](nvgpu::LdMatrixOp op) {
381 rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
382 op, op.getType(), subViewOp.getSource(), sourceIndices,
383 op.getTranspose(), op.getNumTiles());
384 })
385 .DefaultUnreachable("unexpected operation");
386 return success();
387}
388
389template <typename OpTy>
390LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
391 OpTy loadOp, PatternRewriter &rewriter) const {
392 auto expandShapeOp =
393 getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
394
395 if (!expandShapeOp)
396 return failure();
397
398 // For vector::TransferReadOp, validate preconditions before creating any IR.
399 // resolveSourceIndicesExpandShape creates new ops, so all checks that can
400 // fail must happen before that call to avoid "pattern returned failure but
401 // IR did change" errors (caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS).
402 SmallVector<AffineExpr> transferReadNewResults;
403 if (auto transferOp =
404 dyn_cast<vector::TransferReadOp>(loadOp.getOperation())) {
405 const int64_t vectorRank = transferOp.getVectorType().getRank();
406 const int64_t sourceRank =
407 cast<MemRefType>(expandShapeOp.getViewSource().getType()).getRank();
408 if (sourceRank < vectorRank)
409 return failure();
410
411 // We can only fold if the permutation map uses only the least significant
412 // dimension from each expanded reassociation group.
413 for (AffineExpr result : transferOp.getPermutationMap().getResults()) {
414 bool foundExpr = false;
415 for (auto reassocationIndices :
416 llvm::enumerate(expandShapeOp.getReassociationIndices())) {
417 auto reassociation = reassocationIndices.value();
419 reassociation[reassociation.size() - 1], rewriter.getContext());
420 if (dim == result) {
421 transferReadNewResults.push_back(getAffineDimExpr(
422 reassocationIndices.index(), rewriter.getContext()));
423 foundExpr = true;
424 break;
425 }
426 }
427 if (!foundExpr)
428 return failure();
429 }
430 }
431
432 SmallVector<Value> sourceIndices;
433 // memref.load guarantees that indexes start inbounds while the vector
434 // operations don't. This impacts if our linearization is `disjoint`
435 resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
436 loadOp.getIndices(), sourceIndices,
437 isa<memref::LoadOp>(loadOp.getOperation()));
438
440 .Case([&](memref::LoadOp op) {
441 rewriter.replaceOpWithNewOp<memref::LoadOp>(
442 loadOp, expandShapeOp.getViewSource(), sourceIndices,
443 op.getNontemporal());
444 return success();
445 })
446 .Case([&](vector::LoadOp op) {
447 rewriter.replaceOpWithNewOp<vector::LoadOp>(
448 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
449 op.getNontemporal());
450 return success();
451 })
452 .Case([&](vector::MaskedLoadOp op) {
453 rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
454 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
455 op.getMask(), op.getPassThru());
456 return success();
457 })
458 .Case([&](vector::TransferReadOp op) {
459 const int64_t sourceRank = sourceIndices.size();
460 auto newMap = AffineMap::get(sourceRank, 0, transferReadNewResults,
461 op.getContext());
462 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
463 op, op.getVectorType(), expandShapeOp.getViewSource(),
464 sourceIndices, newMap, op.getPadding(), op.getMask(),
465 op.getInBounds());
466 return success();
467 })
468 .DefaultUnreachable("unexpected operation");
469}
470
471template <typename OpTy>
472LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
473 OpTy loadOp, PatternRewriter &rewriter) const {
474 auto collapseShapeOp = getMemRefOperand(loadOp)
475 .template getDefiningOp<memref::CollapseShapeOp>();
476
477 if (!collapseShapeOp)
478 return failure();
479
480 SmallVector<Value> sourceIndices;
481 resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
482 loadOp.getIndices(), sourceIndices);
484 .Case([&](memref::LoadOp op) {
485 rewriter.replaceOpWithNewOp<memref::LoadOp>(
486 loadOp, collapseShapeOp.getViewSource(), sourceIndices,
487 op.getNontemporal());
488 })
489 .Case([&](vector::LoadOp op) {
490 rewriter.replaceOpWithNewOp<vector::LoadOp>(
491 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
492 op.getNontemporal());
493 })
494 .Case([&](vector::MaskedLoadOp op) {
495 rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
496 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
497 op.getMask(), op.getPassThru());
498 })
499 .DefaultUnreachable("unexpected operation");
500 return success();
501}
502
503template <typename OpTy>
504LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
505 OpTy storeOp, PatternRewriter &rewriter) const {
506 auto subViewOp =
507 getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
508
509 if (!subViewOp)
510 return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
511
512 LogicalResult preconditionResult =
513 preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
514 if (failed(preconditionResult))
515 return preconditionResult;
516
517 SmallVector<Value> sourceIndices;
519 rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
520 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
521 storeOp.getIndices(), sourceIndices);
522
524 .Case([&](memref::StoreOp op) {
525 rewriter.replaceOpWithNewOp<memref::StoreOp>(
526 op, op.getValue(), subViewOp.getSource(), sourceIndices,
527 op.getNontemporal());
528 })
529 .Case([&](vector::TransferWriteOp op) {
530 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
531 op, op.getValue(), subViewOp.getSource(), sourceIndices,
532 AffineMapAttr::get(expandDimsToRank(
533 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
534 subViewOp.getDroppedDims())),
535 op.getMask(), op.getInBoundsAttr());
536 })
537 .Case([&](vector::StoreOp op) {
538 rewriter.replaceOpWithNewOp<vector::StoreOp>(
539 op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
540 })
541 .Case([&](vector::MaskedStoreOp op) {
542 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
543 op, subViewOp.getSource(), sourceIndices, op.getMask(),
544 op.getValueToStore());
545 })
546 .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
547 rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
548 op, op.getSrc(), subViewOp.getSource(), sourceIndices,
549 op.getLeadDimension(), op.getTransposeAttr());
550 })
551 .DefaultUnreachable("unexpected operation");
552 return success();
553}
554
555template <typename OpTy>
556LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
557 OpTy storeOp, PatternRewriter &rewriter) const {
558 auto expandShapeOp =
559 getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
560
561 if (!expandShapeOp)
562 return failure();
563
564 SmallVector<Value> sourceIndices;
565 // memref.store guarantees that indexes start inbounds while the vector
566 // operations don't. This impacts if our linearization is `disjoint`
567 resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
568 storeOp.getIndices(), sourceIndices,
569 isa<memref::StoreOp>(storeOp.getOperation()));
571 .Case([&](memref::StoreOp op) {
572 rewriter.replaceOpWithNewOp<memref::StoreOp>(
573 storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
574 sourceIndices, op.getNontemporal());
575 })
576 .Case([&](vector::StoreOp op) {
577 rewriter.replaceOpWithNewOp<vector::StoreOp>(
578 op, op.getValueToStore(), expandShapeOp.getViewSource(),
579 sourceIndices, op.getNontemporal());
580 })
581 .Case([&](vector::MaskedStoreOp op) {
582 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
583 op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
584 op.getValueToStore());
585 })
586 .DefaultUnreachable("unexpected operation");
587 return success();
588}
589
590template <typename OpTy>
591LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
592 OpTy storeOp, PatternRewriter &rewriter) const {
593 auto collapseShapeOp = getMemRefOperand(storeOp)
594 .template getDefiningOp<memref::CollapseShapeOp>();
595
596 if (!collapseShapeOp)
597 return failure();
598
599 SmallVector<Value> sourceIndices;
600 resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
601 storeOp.getIndices(), sourceIndices);
603 .Case([&](memref::StoreOp op) {
604 rewriter.replaceOpWithNewOp<memref::StoreOp>(
605 storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
606 sourceIndices, op.getNontemporal());
607 })
608 .Case([&](vector::StoreOp op) {
609 rewriter.replaceOpWithNewOp<vector::StoreOp>(
610 op, op.getValueToStore(), collapseShapeOp.getViewSource(),
611 sourceIndices, op.getNontemporal());
612 })
613 .Case([&](vector::MaskedStoreOp op) {
614 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
615 op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
616 op.getValueToStore());
617 })
618 .DefaultUnreachable("unexpected operation");
619 return success();
620}
621
LogicalResult
AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
                                           PatternRewriter &rewriter) const {
  auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
  if (!subview)
    return rewriter.notifyMatchFailure(op, "not accessing a subview");

  SmallVector<int64_t> accessedShape = op.getAccessedShape();
  // Note the subtle difference between accessedShape = {1} and accessedShape =
  // {} here. The former prevents us from folding in a subview that doesn't
  // have a unit stride on the final dimension, while the latter does not (since
  // it indicates a scalar access).
  int64_t accessedDims = accessedShape.size();
  if (!hasTrailingUnitStrides(subview, accessedDims))
    return rewriter.notifyMatchFailure(
        op, "non-unit stride on accessed dimensions");

  llvm::SmallBitVector droppedDims = subview.getDroppedDims();
  int64_t sourceRank = subview.getSourceType().getRank();

  // Ignore outermost access dimension - we only care about dropped dimensions
  // between the accessed op's results, as those could break the accessing op's
  // semantics.
  int64_t secondAccessedDim = sourceRank - (accessedDims - 1);
  if (secondAccessedDim < sourceRank) {
    for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) {
      if (droppedDims.test(d))
        return rewriter.notifyMatchFailure(
            op, "reintroducing dropped dimension " + Twine(d) +
                    " would break access op semantics");
    }
  }

  // Rebase indices through the subview's offsets/strides.
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, op.getLoc(), subview.getMixedOffsets(),
      subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);

  // The interface either updates the op in place (nullopt) or hands back
  // replacement values.
  std::optional<SmallVector<Value>> newValues =
      op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}
666
LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  auto expand = op.getAccessedMemref().getDefiningOp<memref::ExpandShapeOp>();
  if (!expand)
    return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");

  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
  ArrayRef<int64_t> accessedShape = rawAccessedShape;
  // Cut off the leading dimension, since we don't care about modifying its
  // strides.
  if (!accessedShape.empty())
    accessedShape = accessedShape.drop_front();

  // Non-outermost accessed dimensions must map 1:1 through the reassociation,
  // otherwise folding would merge dimensions the access relies on.
  SmallVector<ReassociationIndices, 4> reassocs =
      expand.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
    return rewriter.notifyMatchFailure(
        op,
        "expand_shape folding would merge semanvtically important dimensions");

  SmallVector<Value> sourceIndices;
  memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
                                          op.getIndices(), sourceIndices,
                                          op.hasInboundsIndices());

  // The interface either updates the op in place (nullopt) or hands back
  // replacement values.
  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
      rewriter, expand.getViewSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}
698
LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  auto collapse =
      op.getAccessedMemref().getDefiningOp<memref::CollapseShapeOp>();
  if (!collapse)
    return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");

  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
  ArrayRef<int64_t> accessedShape = rawAccessedShape;
  // Cut off the leading dimension, since we don't care about its strides being
  // modified and we know that the dimensions within its reassociation group, if
  // it's non-trivial, must be contiguous.
  if (!accessedShape.empty())
    accessedShape = accessedShape.drop_front();

  // Non-outermost accessed dimensions must map 1:1 through the reassociation,
  // otherwise folding would split dimensions the access relies on.
  SmallVector<ReassociationIndices, 4> reassocs =
      collapse.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
    return rewriter.notifyMatchFailure(op,
                                       "collapse_shape folding would merge "
                                       "semanvtically important dimensions");

  SmallVector<Value> sourceIndices;
  memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
                                            op.getIndices(), sourceIndices);

  // The interface either updates the op in place (nullopt) or hands back
  // replacement values.
  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
      rewriter, collapse.getViewSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}
731
LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>();
  auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>();
  if (!srcSubview && !dstSubview)
    return rewriter.notifyMatchFailure(
        op, "no subviews found on indexed copy inputs");

  // Start from the current memrefs/indices; either side may be folded
  // independently.
  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcSubview) {
    newSrc = srcSubview.getSource();
    newSrcIndices.clear();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
        srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
        op.getSrcIndices(), newSrcIndices);
  }
  if (dstSubview) {
    newDst = dstSubview.getSource();
    newDstIndices.clear();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
        dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
        op.getDstIndices(), newDstIndices);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}
764
LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>();
  auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>();
  if (!srcExpand && !dstExpand)
    return rewriter.notifyMatchFailure(
        op, "no expand_shapes found on indexed copy inputs");

  // Start from the current memrefs/indices; either side may be folded
  // independently.
  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcExpand) {
    newSrc = srcExpand.getViewSource();
    newSrcIndices.clear();
    memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, srcExpand,
                                            op.getSrcIndices(), newSrcIndices,
                                            /*startsInbounds=*/true);
  }
  if (dstExpand) {
    newDst = dstExpand.getViewSource();
    newDstIndices.clear();
    memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, dstExpand,
                                            op.getDstIndices(), newDstIndices,
                                            /*startsInbounds=*/true);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}
795
796LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
797 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
798 auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>();
799 auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>();
800 if (!srcCollapse && !dstCollapse)
801 return rewriter.notifyMatchFailure(
802 op, "no collapse_shapes found on indexed copy inputs");
803
804 Value newSrc = op.getSrc();
805 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
806 Value newDst = op.getDst();
807 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
808 if (srcCollapse) {
809 newSrc = srcCollapse.getViewSource();
810 newSrcIndices.clear();
812 op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices);
813 }
814 if (dstCollapse) {
815 newDst = dstCollapse.getViewSource();
816 newDstIndices.clear();
818 op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices);
819 }
820 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
821 newDstIndices);
822 return success();
823}
824
LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
                                      copyOp.getSrcIndices().end());

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        copyOp.getSrcIndices(), foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
                                      copyOp.getDstIndices().end());

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        copyOp.getDstIndices(), foldedDstIndices);
  }

  // Replace the copy op with a new copy op that uses the source and destination
  // of the subview. Note the op's operand order: dst before src.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}
875
877 patterns.add<
878 // Interface-based patterns to which we will be migrating.
879 AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
880 AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
881 IndexedMemCopyOpOfExpandShapeOpFolder,
882 IndexedMemCopyOpOfCollapseShapeOpFolder,
883 // The old way of doing things. Don't add more of these.
884 LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
885 LoadOpOfSubViewOpFolder<vector::LoadOp>,
886 LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
887 LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
888 LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
889 StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
890 StoreOpOfSubViewOpFolder<vector::StoreOp>,
891 StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
892 StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
893 LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
894 LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
895 LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
896 StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
897 StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
898 LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
899 LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
900 StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
901 StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
902 SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
903 patterns.getContext());
904}
905
906//===----------------------------------------------------------------------===//
907// Pass registration
908//===----------------------------------------------------------------------===//
909
910namespace {
911
912struct FoldMemRefAliasOpsPass final
913 : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
914 void runOnOperation() override;
915};
916
917} // namespace
918
919void FoldMemRefAliasOpsPass::runOnOperation() {
920 RewritePatternSet patterns(&getContext());
922 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
923}
return success()
#define DBGS()
Definition Hoisting.cpp:32
b getContext())
static bool hasTrivialReassociationSuffix(ArrayRef< ReassociationIndices > reassocs, int64_t n)
Determine if the last N indices of `reassociation` are trivial - that is, check if they all contain ex...
static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n)
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, Operation *op, memref::SubViewOp subviewOp)
static LogicalResult preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp, memref::SubViewOp subviewOp)
static Value getMemRefOperand(LoadOrStoreOpTy op)
Helpers to access the memref operand for each op.
Base type for affine expression.
Definition AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext * getContext() const
Definition Builders.h:56
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...