//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
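
// For example (illustrative IR; %src, %i, %j are placeholders), a load
// through a subview such as
//
//   %0 = memref.subview %src[4, 8] [8, 16] [1, 1]
//       : memref<64x64xf32> to memref<8x16xf32, strided<[64, 1], offset: 264>>
//   %v = memref.load %0[%i, %j]
//
// is rewritten by the patterns below to index the original memref directly:
//
//   %i0 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
//   %j0 = affine.apply affine_map<(d0) -> (d0 + 8)>(%j)
//   %v = memref.load %src[%i0, %j0] : memref<64x64xf32>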

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <cstdint>

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Determine if the last `n` groups of `reassocs` are trivial - that is,
/// check if each of them contains exactly one dimension to collapse/expand
/// into.
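/// For example, given reassociation [[0, 1], [2], [3]], the last one or two
/// groups form a trivial suffix, but any longer suffix (which would include
/// the [0, 1] group) does not.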
static bool
hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs,
                              int64_t n) {
  if (n <= 0)
    return true;
  return llvm::all_of(
      reassocs.take_back(n),
      [&](const ReassociationIndices &indices) { return indices.size() == 1; });
}

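/// Determine if the last `n` static strides of `subview` are all 1, i.e.
/// folding the subview cannot change the spacing of the innermost `n`
/// dimensions accessed through it.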
static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
  if (n <= 0)
    return true;
  return llvm::all_of(subview.getStaticStrides().take_back(n),
                      [](int64_t s) { return s == 1; });
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getBase();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getBase();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges a subview operation with a load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges an expand_shape operation with a load/transferRead operation.
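/// For example (illustrative IR; names are placeholders), a load such as
///
///   %e = memref.expand_shape %src [[0, 1]] output_shape [4, 8]
///       : memref<32xf32> into memref<4x8xf32>
///   %v = memref.load %e[%i, %j] : memref<4x8xf32>
///
/// becomes a load from the original memref at the linearized index
/// %i * 8 + %j.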
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a collapse_shape operation with a load/transferRead operation.
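/// For example (illustrative IR; names are placeholders), a load such as
///
///   %c = memref.collapse_shape %src [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %v = memref.load %c[%i] : memref<32xf32>
///
/// becomes a load from the original memref at the delinearized indices
/// (%i floordiv 8, %i mod 8).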
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a subview operation with a store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges an expand_shape operation with a store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges a collapse_shape operation with a store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
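/// For example (illustrative IR):
///
///   %0 = memref.subview %arg[2] [32] [1]
///       : memref<64xf32> to memref<32xf32, strided<[1], offset: 2>>
///   %1 = memref.subview %0[4] [8] [1]
///       : memref<32xf32, strided<[1], offset: 2>>
///         to memref<8xf32, strided<[1], offset: 6>>
///
/// becomes a single subview with the offsets merged:
///
///   %1 = memref.subview %arg[6] [8] [1]
///       : memref<64xf32> to memref<8xf32, strided<[1], offset: 6>>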
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
    if (failed(affine::mergeOffsetsSizesAndStrides(
            rewriter, subView.getLoc(), srcSubView, subView,
            srcSubView.getDroppedDims(), newOffsets, newSizes, newStrides)))
      return failure();

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(), newOffsets,
        newSizes, newStrides);
    return success();
  }
};

/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on both the src and dst memrefs of the copy.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operations with load/store-like operations unless such a
/// merger would cause the strides between dimensions accessed by that
/// operation to change.
struct AccessOpOfSubViewOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges a memref.expand_shape operation with an operation that accesses a
/// memref by index, unless that operation accesses more than one dimension of
/// memory and any dimension other than the outermost dimension accessed this
/// way would be merged. This prevents issues from arising with, say, a
/// vector.load of a 4x2 vector having the two trailing dimensions of its
/// access merged.
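/// For example (illustrative IR), folding is rejected for
///
///   %e = memref.expand_shape %src [[0], [1, 2]] output_shape [16, 4, 2]
///       : memref<16x8xf32> into memref<16x4x2xf32>
///   %v = vector.load %e[%i, %j, %k] : memref<16x4x2xf32>, vector<4x2xf32>
///
/// since both trailing dimensions of the access would fold back into a single
/// source dimension.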
struct AccessOpOfExpandShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges an operation that accesses a memref by index with a
/// memref.collapse_shape, unless this would break apart a dimension other than
/// the outermost one that the operation accesses. This prevents, for example,
/// transforming a load of a 3x8 vector from a 3x8 memref into a load
/// from a 3x4x2 memref (as this would require special handling and could lead
/// to invalid IR if that higher-dimensional memref comes from a subview) but
/// does permit turning a load of a length-8 vector from a 6x8 memref into a
/// load from a 3x2x8 one.
struct AccessOpOfCollapseShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.subview operations present on the source or destination
/// operands of indexed memory copy operations (DMA operations) into those
/// operations. This is performed unconditionally, since folding in a subview
/// cannot change the starting position of the copy, which is what the
/// memref/index pair represents in DMA operations.
struct IndexedMemCopyOpOfSubViewOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.expand_shape operations that are present on the source or
/// destination of an indexed memory copy/DMA into the memref/index arguments
/// of that DMA. As with subviews, this can be done unconditionally.
struct IndexedMemCopyOpOfExpandShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.collapse_shape operations that are present on the source or
/// destination of an indexed memory copy/DMA into the memref/index arguments
/// of that DMA. As with subviews, this can be done unconditionally.
struct IndexedMemCopyOpOfCollapseShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};
} // namespace

template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      loadOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  // For vector::TransferReadOp, validate preconditions before creating any IR.
  // resolveSourceIndicesExpandShape creates new ops, so all checks that can
  // fail must happen before that call to avoid "pattern returned failure but
  // IR did change" errors (caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS).
  SmallVector<AffineExpr> transferReadNewResults;
  if (auto transferOp =
          dyn_cast<vector::TransferReadOp>(loadOp.getOperation())) {
    const int64_t vectorRank = transferOp.getVectorType().getRank();
    const int64_t sourceRank =
        cast<MemRefType>(expandShapeOp.getViewSource().getType()).getRank();
    if (sourceRank < vectorRank)
      return failure();

    // We can only fold if the permutation map uses only the least significant
    // dimension from each expanded reassociation group.
    for (AffineExpr result : transferOp.getPermutationMap().getResults()) {
      bool foundExpr = false;
      for (auto reassociationIndices :
           llvm::enumerate(expandShapeOp.getReassociationIndices())) {
        auto reassociation = reassociationIndices.value();
        AffineExpr dim = getAffineDimExpr(
            reassociation[reassociation.size() - 1], rewriter.getContext());
        if (dim == result) {
          transferReadNewResults.push_back(getAffineDimExpr(
              reassociationIndices.index(), rewriter.getContext()));
          foundExpr = true;
          break;
        }
      }
      if (!foundExpr)
        return failure();
    }
  }

  SmallVector<Value> sourceIndices;
  // memref.load guarantees that its indices start inbounds while the vector
  // operations don't. This affects whether our linearization is `disjoint`.
  resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
                                  loadOp.getIndices(), sourceIndices,
                                  isa<memref::LoadOp>(loadOp.getOperation()));

  return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
        return success();
      })
      .Case([&](vector::TransferReadOp op) {
        const int64_t sourceRank = sourceIndices.size();
        auto newMap = AffineMap::get(sourceRank, 0, transferReadNewResults,
                                     op.getContext());
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), expandShapeOp.getViewSource(),
            sourceIndices, newMap, op.getPadding(), op.getMask(),
            op.getInBounds());
        return success();
      })
      .DefaultUnreachable("unexpected operation");
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
                                    loadOp.getIndices(), sourceIndices);
  llvm::TypeSwitch<Operation *>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      storeOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  // memref.store guarantees that its indices start inbounds while the vector
  // operations don't. This affects whether our linearization is `disjoint`.
  resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
                                  storeOp.getIndices(), sourceIndices,
                                  isa<memref::StoreOp>(storeOp.getOperation()));
  llvm::TypeSwitch<Operation *>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
                                    storeOp.getIndices(), sourceIndices);
  llvm::TypeSwitch<Operation *>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

LogicalResult
AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
                                           PatternRewriter &rewriter) const {
  auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
  if (!subview)
    return rewriter.notifyMatchFailure(op, "not accessing a subview");

  SmallVector<int64_t> accessedShape = op.getAccessedShape();
  // Note the subtle difference between accessedShape = {1} and accessedShape =
  // {} here. The former prevents us from folding in a subview that doesn't
  // have a unit stride on the final dimension, while the latter does not
  // (since it indicates scalar accesses).
  int64_t accessedDims = accessedShape.size();
  if (!hasTrailingUnitStrides(subview, accessedDims))
    return rewriter.notifyMatchFailure(
        op, "non-unit stride on accessed dimensions");

  llvm::SmallBitVector droppedDims = subview.getDroppedDims();
  int64_t sourceRank = subview.getSourceType().getRank();

  // Ignore the outermost accessed dimension - we only care about dropped
  // dimensions among the remaining accessed dimensions, as those could break
  // the accessing op's semantics.
  int64_t secondAccessedDim = sourceRank - (accessedDims - 1);
  if (secondAccessedDim < sourceRank) {
    for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) {
      if (droppedDims.test(d))
        return rewriter.notifyMatchFailure(
            op, "reintroducing dropped dimension " + Twine(d) +
                    " would break access op semantics");
    }
  }

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, op.getLoc(), subview.getMixedOffsets(),
      subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);

  std::optional<SmallVector<Value>> newValues =
      op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}

LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  auto expand = op.getAccessedMemref().getDefiningOp<memref::ExpandShapeOp>();
  if (!expand)
    return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");

  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
  ArrayRef<int64_t> accessedShape = rawAccessedShape;
  // Cut off the leading dimension, since we don't care about modifying its
  // strides.
  if (!accessedShape.empty())
    accessedShape = accessedShape.drop_front();

  SmallVector<ReassociationIndices, 4> reassocs =
      expand.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
    return rewriter.notifyMatchFailure(
        op,
        "expand_shape folding would merge semantically important dimensions");

  SmallVector<Value> sourceIndices;
  memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
                                          op.getIndices(), sourceIndices,
                                          op.hasInboundsIndices());

  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
      rewriter, expand.getViewSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}

LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  auto collapse =
      op.getAccessedMemref().getDefiningOp<memref::CollapseShapeOp>();
  if (!collapse)
    return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");

  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
  ArrayRef<int64_t> accessedShape = rawAccessedShape;
  // Cut off the leading dimension, since we don't care about its strides being
  // modified and we know that the dimensions within its reassociation group,
  // if it's non-trivial, must be contiguous.
  if (!accessedShape.empty())
    accessedShape = accessedShape.drop_front();

  SmallVector<ReassociationIndices, 4> reassocs =
      collapse.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
    return rewriter.notifyMatchFailure(op,
                                       "collapse_shape folding would merge "
                                       "semantically important dimensions");

  SmallVector<Value> sourceIndices;
  memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
                                            op.getIndices(), sourceIndices);

  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
      rewriter, collapse.getViewSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}

LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>();
  auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>();
  if (!srcSubview && !dstSubview)
    return rewriter.notifyMatchFailure(
        op, "no subviews found on indexed copy inputs");

  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcSubview) {
    newSrc = srcSubview.getSource();
    newSrcIndices.clear();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
        srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
        op.getSrcIndices(), newSrcIndices);
  }
  if (dstSubview) {
    newDst = dstSubview.getSource();
    newDstIndices.clear();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
        dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
        op.getDstIndices(), newDstIndices);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}

LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>();
  auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>();
  if (!srcExpand && !dstExpand)
    return rewriter.notifyMatchFailure(
        op, "no expand_shapes found on indexed copy inputs");

  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcExpand) {
    newSrc = srcExpand.getViewSource();
    newSrcIndices.clear();
    memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, srcExpand,
                                            op.getSrcIndices(), newSrcIndices,
                                            /*startsInbounds=*/true);
  }
  if (dstExpand) {
    newDst = dstExpand.getViewSource();
    newDstIndices.clear();
    memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, dstExpand,
                                            op.getDstIndices(), newDstIndices,
                                            /*startsInbounds=*/true);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}

LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>();
  auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>();
  if (!srcCollapse && !dstCollapse)
    return rewriter.notifyMatchFailure(
        op, "no collapse_shapes found on indexed copy inputs");

  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcCollapse) {
    newSrc = srcCollapse.getViewSource();
    newSrcIndices.clear();
    memref::resolveSourceIndicesCollapseShape(
        op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices);
  }
  if (dstCollapse) {
    newDst = dstCollapse.getViewSource();
    newDstIndices.clear();
    memref::resolveSourceIndicesCollapseShape(
        op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
                                      copyOp.getSrcIndices().end());

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        copyOp.getSrcIndices(), foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
                                      copyOp.getDstIndices().end());

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        copyOp.getDstIndices(), foldedDstIndices);
  }

  // Replace the copy op with a new copy op that uses the sources of the
  // subviews.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<
      // Interface-based patterns to which we will be migrating.
      AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
      AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
      IndexedMemCopyOpOfExpandShapeOpFolder,
      IndexedMemCopyOpOfCollapseShapeOpFolder,
      // The old way of doing things. Don't add more of these.
      LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
      LoadOpOfSubViewOpFolder<vector::LoadOp>,
      LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
      LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
      LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
      StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
      StoreOpOfSubViewOpFolder<vector::StoreOp>,
      StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
      StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
      LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
      LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
      LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
      StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
      StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
      LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
      LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
      StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
      StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
      SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
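
// Usage sketch: the rewrites in this file can be exercised in isolation with
//
//   mlir-opt --fold-memref-alias-ops input.mlir
//
// (assuming the registered pass name matches DEBUG_TYPE above).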