//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
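//
// As an illustrative sketch (types and SSA names here are made up for
// exposition, not taken from a test), a load through a subview such as
//
//   %sv = memref.subview %src[%o0, %o1] [4, 4] [1, 1]
//       : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: ?>>
//   %v = memref.load %sv[%i, %j]
//
// becomes a load of the original memref with the subview offsets folded into
// the indices, roughly %src[%o0 + %i, %o1 + %j] (materialized via
// affine.apply).
//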

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getBase();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getBase();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
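/// As an illustrative sketch (shapes assumed for exposition): a load through
///   %e = memref.expand_shape %src [[0, 1]] output_shape [3, 4]
///       : memref<12xf32> into memref<3x4xf32>
/// at %e[%i, %j] becomes a load of %src at the linearized index %i * 4 + %j.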
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
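/// As an illustrative sketch (shapes assumed for exposition): a load through
///   %c = memref.collapse_shape %src [[0, 1]]
///       : memref<3x4xf32> into memref<12xf32>
/// at %c[%i] becomes a load of %src at the delinearized indices
/// (%i floordiv 4, %i mod 4).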
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
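/// With unit strides the offsets simply add. As an illustrative sketch
/// (shapes assumed): taking subview %src[2, 4] [8, 8] [1, 1] and then a
/// subview of that at [1, 1] [4, 4] [1, 1] folds to the single
/// subview %src[3, 5] [4, 4] [1, 1].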
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

/// Folds subviews on the src and dst memrefs of an nvgpu.device_async_copy
/// into the copy itself.
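/// As an illustrative sketch (operands assumed): if the src (and/or dst) of
/// the copy is produced by memref.subview, the copy is rewritten to operate
/// on the subview's source memref directly, with the subview offsets folded
/// into the copy indices.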
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(
      llvm::is_one_of<XferOp, vector::TransferReadOp,
                      vector::TransferWriteOp>::value,
      "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      loadOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  // memref.load guarantees that its indices are in bounds, while the vector
  // operations don't. This affects whether our linearization is `disjoint`.
  resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
                                  loadOp.getIndices(), sourceIndices,
                                  isa<memref::LoadOp>(loadOp.getOperation()));

  return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
        return success();
      })
      .Case([&](vector::TransferReadOp op) {
        // We only support minor identity maps in the permutation attribute.
        if (!op.getPermutationMap().isMinorIdentity())
          return failure();

        // We only support the case where the source of the expand shape has
        // rank greater than or equal to the vector rank.
        const int64_t sourceRank = sourceIndices.size();
        const int64_t vectorRank = op.getVectorType().getRank();
        if (sourceRank < vectorRank)
          return failure();

        // We need to construct a new minor identity map since we will have
        // lost some dimensions in folding away the expand shape.
        auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank,
                                                         op.getContext());
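        // E.g. (illustrative): if a 1-D source was expanded to 3-D and the
        // vector rank is 1, the minor identity (d0, d1, d2) -> (d2) on the
        // expanded memref becomes (d0) -> (d0) on the folded source.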

        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), expandShapeOp.getViewSource(),
            sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
            op.getInBounds());
        return success();
      })
      .DefaultUnreachable("unexpected operation");
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
                                    loadOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      storeOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  // memref.store guarantees that its indices are in bounds, while the vector
  // operations don't. This affects whether our linearization is `disjoint`.
  resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
                                  storeOp.getIndices(), sourceIndices,
                                  isa<memref::StoreOp>(storeOp.getOperation()));

  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> sourceIndices;
  resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
                                    storeOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
                                      copyOp.getSrcIndices().end());

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        copyOp.getSrcIndices(), foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
                                      copyOp.getDstIndices().end());

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        copyOp.getDstIndices(), foldedDstIndices);
  }

  // Replace the copy op with a new one that operates on the sources of the
  // subviews, using the resolved indices.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
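
// A minimal sketch of driving these patterns outside the pass (mirrors the
// pass body above; `op` and `ctx` stand for a suitable root operation and its
// MLIRContext):
//
//   RewritePatternSet patterns(ctx);
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));
//
// The same rewrite is also exposed as the `fold-memref-alias-ops` pass.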