MLIR 23.0.0git
FoldMemRefAliasOps.cpp
Go to the documentation of this file.
//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
13
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <cstdint>
32
33#define DEBUG_TYPE "fold-memref-alias-ops"
34#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
35
36namespace mlir {
37namespace memref {
38#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
39#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
40} // namespace memref
41} // namespace mlir
42
43using namespace mlir;
44
45//===----------------------------------------------------------------------===//
46// Utility functions
47//===----------------------------------------------------------------------===//
48
49/// Deterimine if the last N indices of `reassocitaion` are trivial - that is,
50/// check if they all contain exactly one dimension to collape/expand into.
51static bool
53 int64_t n) {
54 if (n <= 0)
55 return true;
56 return llvm::all_of(
57 reassocs.take_back(n),
58 [&](const ReassociationIndices &indices) { return indices.size() == 1; });
59}
60
61static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
62 if (n <= 0)
63 return true;
64 return llvm::all_of(subview.getStaticStrides().take_back(n),
65 [](int64_t s) { return s == 1; });
66}
67
68/// Helpers to access the memref operand for each op.
69template <typename LoadOrStoreOpTy>
70static Value getMemRefOperand(LoadOrStoreOpTy op) {
71 return op.getMemref();
72}
73
74static Value getMemRefOperand(vector::TransferReadOp op) {
75 return op.getBase();
76}
77
78static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
79 return op.getSrcMemref();
80}
81
82static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
83
84static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }
85
86static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
87
88static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }
89
90static Value getMemRefOperand(vector::TransferWriteOp op) {
91 return op.getBase();
92}
93
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
97
98namespace {
99/// Merges subview operation with load/transferRead operation.
100template <typename OpTy>
101class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
102public:
103 using OpRewritePattern<OpTy>::OpRewritePattern;
104
105 LogicalResult matchAndRewrite(OpTy loadOp,
106 PatternRewriter &rewriter) const override;
107};
108
109/// Merges expand_shape operation with load/transferRead operation.
110template <typename OpTy>
111class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
112public:
113 using OpRewritePattern<OpTy>::OpRewritePattern;
114
115 LogicalResult matchAndRewrite(OpTy loadOp,
116 PatternRewriter &rewriter) const override;
117};
118
119/// Merges collapse_shape operation with load/transferRead operation.
120template <typename OpTy>
121class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
122public:
123 using OpRewritePattern<OpTy>::OpRewritePattern;
124
125 LogicalResult matchAndRewrite(OpTy loadOp,
126 PatternRewriter &rewriter) const override;
127};
128
129/// Merges subview operation with store/transferWriteOp operation.
130template <typename OpTy>
131class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
132public:
133 using OpRewritePattern<OpTy>::OpRewritePattern;
134
135 LogicalResult matchAndRewrite(OpTy storeOp,
136 PatternRewriter &rewriter) const override;
137};
138
139/// Merges expand_shape operation with store/transferWriteOp operation.
140template <typename OpTy>
141class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
142public:
143 using OpRewritePattern<OpTy>::OpRewritePattern;
144
145 LogicalResult matchAndRewrite(OpTy storeOp,
146 PatternRewriter &rewriter) const override;
147};
148
149/// Merges collapse_shape operation with store/transferWriteOp operation.
150template <typename OpTy>
151class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
152public:
153 using OpRewritePattern<OpTy>::OpRewritePattern;
154
155 LogicalResult matchAndRewrite(OpTy storeOp,
156 PatternRewriter &rewriter) const override;
157};
158
159/// Folds subview(subview(x)) to a single subview(x).
160class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
161public:
162 using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
163
164 LogicalResult matchAndRewrite(memref::SubViewOp subView,
165 PatternRewriter &rewriter) const override {
166 auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
167 if (!srcSubView)
168 return failure();
169
170 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
172 rewriter, subView.getLoc(), srcSubView, subView,
173 srcSubView.getDroppedDims(), newOffsets, newSizes, newStrides)))
174 return failure();
175
176 // Replace original op.
177 rewriter.replaceOpWithNewOp<memref::SubViewOp>(
178 subView, subView.getType(), srcSubView.getSource(), newOffsets,
179 newSizes, newStrides);
180 return success();
181 }
182};
183
184/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
185/// is folds subview on src and dst memref of the copy.
186class NVGPUAsyncCopyOpSubViewOpFolder final
187 : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
188public:
189 using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;
190
191 LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
192 PatternRewriter &rewriter) const override;
193};
194
195/// Merges subview operations with load/store like operations unless such a
196/// merger would cause the strides between dimensions accessed by that operaton
197/// to change.
198struct AccessOpOfSubViewOpFolder final
199 : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
200 using Base::Base;
201
202 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
203 PatternRewriter &rewriter) const override;
204};
205
206/// Merge a memref.expand_shape operation with an operation that accesses a
207/// memref by index unless that operation accesss more than one dimension of
208/// memory and any dimension other than the outermost dimension accessed this
209/// way would be merged. This prevents issuses from arising with, say, a
210/// vector.load of a 4x2 vector having the two trailing dimensions of the access
211/// get merged.
212struct AccessOpOfExpandShapeOpFolder final
213 : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
214 using Base::Base;
215
216 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
217 PatternRewriter &rewriter) const override;
218};
219
220/// Merges an operation that accesses a memref by index with a
221/// memref.collapse_shape, unless this would break apart a dimension other than
222/// the outermost one that an operation accesses. This prevents, for example,
223/// transforming a load of a 3x8 vector from a 6x8 memref into a load
224/// from a 3x4x2 memref (as this would require special handling and could lead
225/// to invalid IR if that higher-dimensional memref comes from a subview) but
226/// does permit turning a load of a length-8 vector from a 3x8 memref into a
227/// load from a 3x2x8 one.
228struct AccessOpOfCollapseShapeOpFolder final
229 : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
230 using Base::Base;
231
232 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
233 PatternRewriter &rewriter) const override;
234};
235
236/// Merges memref.subview operations present on the source or destination
237/// operands of indexed memory copy operations (DMA operations) into those
238/// operations. This is perfromed unconditionally, since folding in a subview
239/// cannot change the starting position of the copy, which is what the
240/// memref/index pair represent in DMA operations.
241struct IndexedMemCopyOpOfSubViewOpFolder final
242 : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
243 using Base::Base;
244
245 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
246 PatternRewriter &rewriter) const override;
247};
248
249/// Merges memref.expand_shape operations that are present on the source or
250/// destination of an indexed memory copy/DMA into the memref/index arguments of
251/// that DMA. As with subviews, this can be done unconditionally.
252struct IndexedMemCopyOpOfExpandShapeOpFolder final
253 : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
254 using Base::Base;
255
256 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
257 PatternRewriter &rewriter) const override;
258};
259
260/// Merges memref.collapse_shape operations that are present on the source or
261/// destination of an indexed memory copy/DMA into the memref/index arguments of
262/// that DMA. As with subviews, this can be done unconditionally.
263struct IndexedMemCopyOpOfCollapseShapeOpFolder final
264 : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
265 using Base::Base;
266
267 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
268 PatternRewriter &rewriter) const override;
269};
270} // namespace
271
272template <typename XferOp>
273static LogicalResult
275 memref::SubViewOp subviewOp) {
276 static_assert(
277 !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value,
278 "must be a vector transfer op");
279 if (xferOp.hasOutOfBoundsDim())
280 return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
281 if (!subviewOp.hasUnitStride()) {
282 return rewriter.notifyMatchFailure(
283 xferOp, "non-1 stride subview, need to track strides in folded memref");
284 }
285 return success();
286}
287
288static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
289 Operation *op,
290 memref::SubViewOp subviewOp) {
291 return success();
292}
293
294static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
295 vector::TransferReadOp readOp,
296 memref::SubViewOp subviewOp) {
297 return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
298}
299
300static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
301 vector::TransferWriteOp writeOp,
302 memref::SubViewOp subviewOp) {
303 return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
304}
305
306template <typename OpTy>
307LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
308 OpTy loadOp, PatternRewriter &rewriter) const {
309 auto subViewOp =
310 getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
311
312 if (!subViewOp)
313 return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
314
315 LogicalResult preconditionResult =
316 preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
317 if (failed(preconditionResult))
318 return preconditionResult;
319
320 SmallVector<Value> sourceIndices;
322 rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
323 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
324 loadOp.getIndices(), sourceIndices);
325
327 .Case([&](memref::LoadOp op) {
328 rewriter.replaceOpWithNewOp<memref::LoadOp>(
329 loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
330 })
331 .Case([&](vector::LoadOp op) {
332 rewriter.replaceOpWithNewOp<vector::LoadOp>(
333 op, op.getType(), subViewOp.getSource(), sourceIndices);
334 })
335 .Case([&](vector::MaskedLoadOp op) {
336 rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
337 op, op.getType(), subViewOp.getSource(), sourceIndices,
338 op.getMask(), op.getPassThru());
339 })
340 .Case([&](vector::TransferReadOp op) {
341 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
342 op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
343 AffineMapAttr::get(expandDimsToRank(
344 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
345 subViewOp.getDroppedDims())),
346 op.getPadding(), op.getMask(), op.getInBoundsAttr());
347 })
348 .Case([&](nvgpu::LdMatrixOp op) {
349 rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
350 op, op.getType(), subViewOp.getSource(), sourceIndices,
351 op.getTranspose(), op.getNumTiles());
352 })
353 .DefaultUnreachable("unexpected operation");
354 return success();
355}
356
357template <typename OpTy>
358LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
359 OpTy loadOp, PatternRewriter &rewriter) const {
360 auto expandShapeOp =
361 getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
362
363 if (!expandShapeOp)
364 return failure();
365
366 // For vector::TransferReadOp, validate preconditions before creating any IR.
367 // resolveSourceIndicesExpandShape creates new ops, so all checks that can
368 // fail must happen before that call to avoid "pattern returned failure but
369 // IR did change" errors (caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS).
370 SmallVector<AffineExpr> transferReadNewResults;
371 if (auto transferOp =
372 dyn_cast<vector::TransferReadOp>(loadOp.getOperation())) {
373 const int64_t vectorRank = transferOp.getVectorType().getRank();
374 const int64_t sourceRank =
375 cast<MemRefType>(expandShapeOp.getViewSource().getType()).getRank();
376 if (sourceRank < vectorRank)
377 return failure();
378
379 // We can only fold if the permutation map uses only the least significant
380 // dimension from each expanded reassociation group.
381 for (AffineExpr result : transferOp.getPermutationMap().getResults()) {
382 bool foundExpr = false;
383 for (auto reassocationIndices :
384 llvm::enumerate(expandShapeOp.getReassociationIndices())) {
385 auto reassociation = reassocationIndices.value();
387 reassociation[reassociation.size() - 1], rewriter.getContext());
388 if (dim == result) {
389 transferReadNewResults.push_back(getAffineDimExpr(
390 reassocationIndices.index(), rewriter.getContext()));
391 foundExpr = true;
392 break;
393 }
394 }
395 if (!foundExpr)
396 return failure();
397 }
398 }
399
400 SmallVector<Value> sourceIndices;
401 // memref.load guarantees that indexes start inbounds while the vector
402 // operations don't. This impacts if our linearization is `disjoint`
403 resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
404 loadOp.getIndices(), sourceIndices,
405 isa<memref::LoadOp>(loadOp.getOperation()));
406
408 .Case([&](memref::LoadOp op) {
409 rewriter.replaceOpWithNewOp<memref::LoadOp>(
410 loadOp, expandShapeOp.getViewSource(), sourceIndices,
411 op.getNontemporal());
412 return success();
413 })
414 .Case([&](vector::LoadOp op) {
415 rewriter.replaceOpWithNewOp<vector::LoadOp>(
416 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
417 op.getNontemporal());
418 return success();
419 })
420 .Case([&](vector::MaskedLoadOp op) {
421 rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
422 op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
423 op.getMask(), op.getPassThru());
424 return success();
425 })
426 .Case([&](vector::TransferReadOp op) {
427 const int64_t sourceRank = sourceIndices.size();
428 auto newMap = AffineMap::get(sourceRank, 0, transferReadNewResults,
429 op.getContext());
430 rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
431 op, op.getVectorType(), expandShapeOp.getViewSource(),
432 sourceIndices, newMap, op.getPadding(), op.getMask(),
433 op.getInBounds());
434 return success();
435 })
436 .DefaultUnreachable("unexpected operation");
438
439template <typename OpTy>
440LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
441 OpTy loadOp, PatternRewriter &rewriter) const {
442 auto collapseShapeOp = getMemRefOperand(loadOp)
443 .template getDefiningOp<memref::CollapseShapeOp>();
445 if (!collapseShapeOp)
446 return failure();
448 SmallVector<Value> sourceIndices;
449 resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
450 loadOp.getIndices(), sourceIndices);
452 .Case([&](memref::LoadOp op) {
453 rewriter.replaceOpWithNewOp<memref::LoadOp>(
454 loadOp, collapseShapeOp.getViewSource(), sourceIndices,
455 op.getNontemporal());
456 })
457 .Case([&](vector::LoadOp op) {
458 rewriter.replaceOpWithNewOp<vector::LoadOp>(
459 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
460 op.getNontemporal());
461 })
462 .Case([&](vector::MaskedLoadOp op) {
463 rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
464 op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
465 op.getMask(), op.getPassThru());
466 })
467 .DefaultUnreachable("unexpected operation");
468 return success();
469}
470
471template <typename OpTy>
472LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
473 OpTy storeOp, PatternRewriter &rewriter) const {
474 auto subViewOp =
475 getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
476
477 if (!subViewOp)
478 return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
479
480 LogicalResult preconditionResult =
481 preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
482 if (failed(preconditionResult))
483 return preconditionResult;
487 rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
488 subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
489 storeOp.getIndices(), sourceIndices);
492 .Case([&](memref::StoreOp op) {
493 rewriter.replaceOpWithNewOp<memref::StoreOp>(
494 op, op.getValue(), subViewOp.getSource(), sourceIndices,
495 op.getNontemporal());
496 })
497 .Case([&](vector::TransferWriteOp op) {
498 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
499 op, op.getValue(), subViewOp.getSource(), sourceIndices,
500 AffineMapAttr::get(expandDimsToRank(
501 op.getPermutationMap(), subViewOp.getSourceType().getRank(),
502 subViewOp.getDroppedDims())),
503 op.getMask(), op.getInBoundsAttr());
504 })
505 .Case([&](vector::StoreOp op) {
506 rewriter.replaceOpWithNewOp<vector::StoreOp>(
507 op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
508 })
509 .Case([&](vector::MaskedStoreOp op) {
510 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
511 op, subViewOp.getSource(), sourceIndices, op.getMask(),
512 op.getValueToStore());
513 })
514 .DefaultUnreachable("unexpected operation");
515 return success();
516}
517
518template <typename OpTy>
519LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
520 OpTy storeOp, PatternRewriter &rewriter) const {
521 auto expandShapeOp =
522 getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
523
524 if (!expandShapeOp)
525 return failure();
526
527 SmallVector<Value> sourceIndices;
528 // memref.store guarantees that indexes start inbounds while the vector
529 // operations don't. This impacts if our linearization is `disjoint`
530 resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
531 storeOp.getIndices(), sourceIndices,
532 isa<memref::StoreOp>(storeOp.getOperation()));
534 .Case([&](memref::StoreOp op) {
535 rewriter.replaceOpWithNewOp<memref::StoreOp>(
536 storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
537 sourceIndices, op.getNontemporal());
538 })
539 .Case([&](vector::StoreOp op) {
540 rewriter.replaceOpWithNewOp<vector::StoreOp>(
541 op, op.getValueToStore(), expandShapeOp.getViewSource(),
542 sourceIndices, op.getNontemporal());
543 })
544 .Case([&](vector::MaskedStoreOp op) {
545 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
546 op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
547 op.getValueToStore());
548 })
549 .DefaultUnreachable("unexpected operation");
550 return success();
551}
552
553template <typename OpTy>
554LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
555 OpTy storeOp, PatternRewriter &rewriter) const {
556 auto collapseShapeOp = getMemRefOperand(storeOp)
557 .template getDefiningOp<memref::CollapseShapeOp>();
558
559 if (!collapseShapeOp)
560 return failure();
561
562 SmallVector<Value> sourceIndices;
563 resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
564 storeOp.getIndices(), sourceIndices);
566 .Case([&](memref::StoreOp op) {
567 rewriter.replaceOpWithNewOp<memref::StoreOp>(
568 storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
569 sourceIndices, op.getNontemporal());
570 })
571 .Case([&](vector::StoreOp op) {
572 rewriter.replaceOpWithNewOp<vector::StoreOp>(
573 op, op.getValueToStore(), collapseShapeOp.getViewSource(),
574 sourceIndices, op.getNontemporal());
575 })
576 .Case([&](vector::MaskedStoreOp op) {
577 rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
578 op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
579 op.getValueToStore());
580 })
581 .DefaultUnreachable("unexpected operation");
582 return success();
583}
584
585LogicalResult
586AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
587 PatternRewriter &rewriter) const {
588 auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
589 if (!subview)
590 return rewriter.notifyMatchFailure(op, "not accessing a subview");
591
592 SmallVector<int64_t> accessedShape = op.getAccessedShape();
593 // Note the subtle difference between accesedShape = {1} and accessedShape =
594 // {} here. The former prevents us from fdolding in a subview that doesn't
595 // have a unit stride on the final dimension, while the latter does not (since
596 // it indices scalar accesss).
597 int64_t accessedDims = accessedShape.size();
598 if (!hasTrailingUnitStrides(subview, accessedDims))
599 return rewriter.notifyMatchFailure(
600 op, "non-unit stride on accessed dimensions");
601
602 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
603 int64_t sourceRank = subview.getSourceType().getRank();
604
605 // Ignore outermost access dimension - we only care about dropped dimensions
606 // between the accessed op's results, as those could break the accessing op's
607 // sematics.
608 int64_t secondAccessedDim = sourceRank - (accessedDims - 1);
609 if (secondAccessedDim < sourceRank) {
610 for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) {
611 if (droppedDims.test(d))
612 return rewriter.notifyMatchFailure(
613 op, "reintroducing dropped dimension " + Twine(d) +
614 " would break access op semantics");
615 }
616 }
617
618 SmallVector<Value> sourceIndices;
620 rewriter, op.getLoc(), subview.getMixedOffsets(),
621 subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);
622
623 std::optional<SmallVector<Value>> newValues =
624 op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
625 if (newValues)
626 rewriter.replaceOp(op, *newValues);
627 return success();
628}
629
630LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
631 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
632 auto expand = op.getAccessedMemref().getDefiningOp<memref::ExpandShapeOp>();
633 if (!expand)
634 return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");
635
636 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
637 ArrayRef<int64_t> accessedShape = rawAccessedShape;
638 // Cut off the leading dimension, since we don't care about monifying its
639 // strides.
640 if (!accessedShape.empty())
641 accessedShape = accessedShape.drop_front();
642
643 SmallVector<ReassociationIndices, 4> reassocs =
644 expand.getReassociationIndices();
645 if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
646 return rewriter.notifyMatchFailure(
647 op,
648 "expand_shape folding would merge semanvtically important dimensions");
649
650 SmallVector<Value> sourceIndices;
651 memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
652 op.getIndices(), sourceIndices,
653 op.hasInboundsIndices());
654
655 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
656 rewriter, expand.getViewSource(), sourceIndices);
657 if (newValues)
658 rewriter.replaceOp(op, *newValues);
659 return success();
660}
661
662LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
663 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
664 auto collapse =
665 op.getAccessedMemref().getDefiningOp<memref::CollapseShapeOp>();
666 if (!collapse)
667 return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");
668
669 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
670 ArrayRef<int64_t> accessedShape = rawAccessedShape;
671 // Cut off the leading dimension, since we don't care about its strides being
672 // modified and we know that the dimensions within its reassociation group, if
673 // it's non-trivial, must be contiguous.
674 if (!accessedShape.empty())
675 accessedShape = accessedShape.drop_front();
676
677 SmallVector<ReassociationIndices, 4> reassocs =
678 collapse.getReassociationIndices();
679 if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
680 return rewriter.notifyMatchFailure(op,
681 "collapse_shape folding would merge "
682 "semanvtically important dimensions");
683
684 SmallVector<Value> sourceIndices;
685 memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
686 op.getIndices(), sourceIndices);
687
688 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
689 rewriter, collapse.getViewSource(), sourceIndices);
690 if (newValues)
691 rewriter.replaceOp(op, *newValues);
692 return success();
693}
694
695LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
696 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
697 auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>();
698 auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>();
699 if (!srcSubview && !dstSubview)
700 return rewriter.notifyMatchFailure(
701 op, "no subviews found on indexed copy inputs");
702
703 Value newSrc = op.getSrc();
704 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
705 Value newDst = op.getDst();
706 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
707 if (srcSubview) {
708 newSrc = srcSubview.getSource();
709 newSrcIndices.clear();
711 rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
712 srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
713 op.getSrcIndices(), newSrcIndices);
714 }
715 if (dstSubview) {
716 newDst = dstSubview.getSource();
717 newDstIndices.clear();
719 rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
720 dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
721 op.getDstIndices(), newDstIndices);
722 }
723 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
724 newDstIndices);
725 return success();
726}
727
728LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
729 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
730 auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>();
731 auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>();
732 if (!srcExpand && !dstExpand)
733 return rewriter.notifyMatchFailure(
734 op, "no expand_shapes found on indexed copy inputs");
735
736 Value newSrc = op.getSrc();
737 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
738 Value newDst = op.getDst();
739 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
740 if (srcExpand) {
741 newSrc = srcExpand.getViewSource();
742 newSrcIndices.clear();
743 memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, srcExpand,
744 op.getSrcIndices(), newSrcIndices,
745 /*startsInbounds=*/true);
746 }
747 if (dstExpand) {
748 newDst = dstExpand.getViewSource();
749 newDstIndices.clear();
750 memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, dstExpand,
751 op.getDstIndices(), newDstIndices,
752 /*startsInbounds=*/true);
753 }
754 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
755 newDstIndices);
756 return success();
757}
758
759LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
760 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
761 auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>();
762 auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>();
763 if (!srcCollapse && !dstCollapse)
764 return rewriter.notifyMatchFailure(
765 op, "no collapse_shapes found on indexed copy inputs");
766
767 Value newSrc = op.getSrc();
768 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
769 Value newDst = op.getDst();
770 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
771 if (srcCollapse) {
772 newSrc = srcCollapse.getViewSource();
773 newSrcIndices.clear();
775 op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices);
776 }
777 if (dstCollapse) {
778 newDst = dstCollapse.getViewSource();
779 newDstIndices.clear();
781 op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices);
782 }
783 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
784 newDstIndices);
785 return success();
786}
787
788LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
789 nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
790
791 LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");
792
793 auto srcSubViewOp =
794 copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
795 auto dstSubViewOp =
796 copyOp.getDst().template getDefiningOp<memref::SubViewOp>();
797
798 if (!(srcSubViewOp || dstSubViewOp))
799 return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
800 "source or destination");
801
802 // If the source is a subview, we need to resolve the indices.
803 SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
804 copyOp.getSrcIndices().end());
805
806 if (srcSubViewOp) {
807 LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
809 rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
810 srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
811 copyOp.getSrcIndices(), foldedSrcIndices);
812 }
813
814 // If the destination is a subview, we need to resolve the indices.
815 SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
816 copyOp.getDstIndices().end());
817
818 if (dstSubViewOp) {
819 LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
821 rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
822 dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
823 copyOp.getDstIndices(), foldedDstIndices);
824 }
825
826 // Replace the copy op with a new copy op that uses the source and destination
827 // of the subview.
828 rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
829 copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
830 (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
831 foldedDstIndices,
832 (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
833 foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
834 copyOp.getBypassL1Attr());
835
836 return success();
837}
838
840 patterns.add<
841 // Interface-based patterns to which we will be migrating.
842 AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
843 AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
844 IndexedMemCopyOpOfExpandShapeOpFolder,
845 IndexedMemCopyOpOfCollapseShapeOpFolder,
846 // The old way of doing things. Don't add more of these.
847 LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
848 LoadOpOfSubViewOpFolder<vector::LoadOp>,
849 LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
850 LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
851 StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
852 StoreOpOfSubViewOpFolder<vector::StoreOp>,
853 StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
854 LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
855 LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
856 LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
857 StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
858 StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
859 LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
860 LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
861 StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
862 StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
863 SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
864 patterns.getContext());
865}
866
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
870
871namespace {
872
873struct FoldMemRefAliasOpsPass final
874 : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
875 void runOnOperation() override;
876};
877
878} // namespace
879
880void FoldMemRefAliasOpsPass::runOnOperation() {
881 RewritePatternSet patterns(&getContext());
883 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
884}
return success()
#define DBGS()
Definition Hoisting.cpp:32
b getContext())
static bool hasTrivialReassociationSuffix(ArrayRef< ReassociationIndices > reassocs, int64_t n)
Determine if the last N indices of reassociation are trivial - that is, check if they all contain ex...
static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n)
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, Operation *op, memref::SubViewOp subviewOp)
static LogicalResult preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp, memref::SubViewOp subviewOp)
static Value getMemRefOperand(LoadOrStoreOpTy op)
Helpers to access the memref operand for each op.
Base type for affine expression.
Definition AffineExpr.h:68
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
MLIRContext * getContext() const
Definition Builders.h:56
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > producerOffsets, ArrayRef< OpFoldResult > producerSizes, ArrayRef< OpFoldResult > producerStrides, const llvm::SmallBitVector &droppedProducerDims, ArrayRef< OpFoldResult > consumerOffsets, ArrayRef< OpFoldResult > consumerSizes, ArrayRef< OpFoldResult > consumerStrides, SmallVector< OpFoldResult > &combinedOffsets, SmallVector< OpFoldResult > &combinedSizes, SmallVector< OpFoldResult > &combinedStrides)
Fills the combinedOffsets, combinedSizes and combinedStrides to use when combining a producer slice i...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...