//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to memref aliasing ops
// (subview, expand_shape, collapse_shape) into loading/storing from/to the
// original memref.
//
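// As a sketch of the basic rewrite (schematic IR, types and layouts
// abbreviated), a load through a unit-stride subview such as
//
//   %0 = memref.subview %src[%o0, %o1] [8, 8] [1, 1]
//   %v = memref.load %0[%i, %j]
//
// is rewritten to address the original memref directly, with the indices
// resolved against the subview offsets (via affine.apply):
//
//   %v = memref.load %src[%o0 + %i, %o1 + %j]
//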
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getBase();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getBase();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
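/// Schematically (types abbreviated), a load through an expand_shape such as
///
///   %0 = memref.expand_shape %src [[0, 1]] output_shape [8, 4]
///          : memref<32xf32> into memref<8x4xf32>
///   %v = memref.load %0[%i, %j]
///
/// becomes a load from the original memref with the indices linearized
/// (here %i * 4 + %j, materialized via affine.apply):
///
///   %v = memref.load %src[%i * 4 + %j]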
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
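/// Schematically (types abbreviated), this is the inverse direction: a load
/// through a collapse_shape such as
///
///   %0 = memref.collapse_shape %src [[0, 1]]
///          : memref<8x4xf32> into memref<32xf32>
///   %v = memref.load %0[%i]
///
/// becomes a load from the original memref with the index delinearized
/// (here %i floordiv 4 and %i mod 4):
///
///   %v = memref.load %src[%i floordiv 4, %i mod 4]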
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
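/// Schematically (strided layouts elided), for unit strides
///
///   %0 = memref.subview %src[%a] [16] [1] : memref<64xf32> to memref<16xf32, ...>
///   %1 = memref.subview %0[%b] [8] [1] : memref<16xf32, ...> to memref<8xf32, ...>
///
/// is folded into a single subview whose offset is the sum of the two
/// offsets:
///
///   %1 = memref.subview %src[%a + %b] [8] [1] : memref<64xf32> to memref<8xf32, ...>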
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on both the src and dst memrefs of the copy.
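/// Schematically (types and async token plumbing abbreviated), a copy whose
/// destination is a subview such as
///
///   %sv = memref.subview %dst[%o] [128] [1]
///   %t = nvgpu.device_async_copy %src[%i], %sv[%j], 4 : ... to ...
///
/// is rewritten to copy directly into the original destination memref:
///
///   %t = nvgpu.device_async_copy %src[%i], %dst[%o + %j], 4 : ... to ...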
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      loadOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  // memref.load guarantees that its indices are in bounds, while the vector
  // operations do not. This affects whether the linearization can be marked
  // `disjoint`.
  resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
                                  loadOp.getIndices(), sourceIndices,
                                  isa<memref::LoadOp>(loadOp.getOperation()));
  return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
        return success();
      })
      .Case([&](vector::TransferReadOp op) {
        // We only support the case where the source of the expand_shape has
        // rank greater than or equal to the vector rank.
        const int64_t vectorRank = op.getVectorType().getRank();
        const int64_t sourceRank = sourceIndices.size();
        if (sourceRank < vectorRank)
          return failure();

        SmallVector<AffineExpr> newResults;
        // We can only fold if the permutation map uses only the least
        // significant dimension of an expanded dimension group.
        for (AffineExpr result : op.getPermutationMap().getResults()) {
          bool foundExpr = false;

          for (auto reassocationIndices :
               llvm::enumerate(expandShapeOp.getReassociationIndices())) {
            auto reassociation = reassocationIndices.value();
            AffineExpr dim = getAffineDimExpr(
                reassociation[reassociation.size() - 1], rewriter.getContext());
            if (dim == result) {
              newResults.push_back(getAffineDimExpr(reassocationIndices.index(),
                                                    rewriter.getContext()));
              foundExpr = true;
              break;
            }
          }
          if (!foundExpr)
            return failure();
        }

        auto newMap =
            AffineMap::get(sourceRank, 0, newResults, op.getContext());

        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), expandShapeOp.getViewSource(),
            sourceIndices, newMap, op.getPadding(), op.getMask(),
            op.getInBounds());
        return success();
      })
      .DefaultUnreachable("unexpected operation");
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
                                    loadOp.getIndices(), sourceIndices);
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      storeOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  // memref.store guarantees that its indices are in bounds, while the vector
  // operations do not. This affects whether the linearization can be marked
  // `disjoint`.
  resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
                                  storeOp.getIndices(), sourceIndices,
                                  isa<memref::StoreOp>(storeOp.getOperation()));
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
                                    storeOp.getIndices(), sourceIndices);
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
                                      copyOp.getSrcIndices().end());

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        copyOp.getSrcIndices(), foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
                                      copyOp.getDstIndices().end());

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        copyOp.getDstIndices(), foldedDstIndices);
  }

  // Replace the copy op with a new copy op that addresses the sources of the
  // subviews (when present) directly.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

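// Minimal usage sketch for clients that want these folds inside their own
// rewrite driver (hypothetical caller; assumes `ctx`, `op`, and
// `signalPassFailure` exist in the surrounding pass):
//   RewritePatternSet patterns(ctx);
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();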
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
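
// The pass can then be exercised from mlir-opt, e.g. (flag name assumed to
// match the generated pass definition; shown here only as a usage sketch):
//   mlir-opt --fold-memref-alias-ops input.mlir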