FoldMemRefAliasOps.cpp
//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
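//
// Example (illustrative sketch, not taken from the lit tests): a load from a
// rank-reducing, unit-stride subview such as
//
//   %sv = memref.subview %A[%i, 4] [1, 16] [1, 1]
//       : memref<64x64xf32> to memref<16xf32, strided<[1], offset: ?>>
//   %v = memref.load %sv[%j] : memref<16xf32, strided<[1], offset: ?>>
//
// is rewritten to access the original memref directly, with the subview
// offset folded into the access index (typically materialized as an
// affine.apply):
//
//   %idx = affine.apply affine_map<(d0) -> (d0 + 4)>(%j)
//   %v = memref.load %A[%i, %idx] : memref<64x64xf32>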

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getBase();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getBase();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
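/// For example (illustrative sketch): with
///   %e = memref.expand_shape %A [[0, 1]] output_shape [4, 8]
///       : memref<32xf32> into memref<4x8xf32>
///   %v = memref.load %e[%i, %j] : memref<4x8xf32>
/// the load is rewritten against %A with a linearized index, conceptually
/// %i * 8 + %j, emitted through the affine/memref index-resolution helpers.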
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
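/// For example (illustrative sketch): with
///   %c = memref.collapse_shape %A [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %v = memref.load %c[%i] : memref<32xf32>
/// the load is rewritten against %A, with %i delinearized back into indices
/// for the original 4x8 shape, conceptually (%i floordiv 8, %i mod 8).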
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
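/// For example (illustrative sketch, unit strides assumed):
///   %0 = memref.subview %A[4, 4] [16, 16] [1, 1] : memref<64x64xf32> to ...
///   %1 = memref.subview %0[2, 2] [8, 8] [1, 1] : ... to ...
/// becomes a single subview of %A with the offsets composed:
///   %1 = memref.subview %A[6, 6] [8, 8] [1, 1] : memref<64x64xf32> to ...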
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on the src and dst memrefs of the copy.
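/// For example (illustrative sketch): in
///   %sv = memref.subview %src[%i, 0] [1, 8] [1, 1]
///       : memref<64x8xf32> to memref<8xf32, strided<[1], offset: ?>>
///   %t = nvgpu.device_async_copy %sv[%c0], %dst[%k], 8
///       : memref<8xf32, strided<[1], offset: ?>> to memref<128xf32, 3>
/// the copy is rewritten to read from %src directly, with the subview offsets
/// folded into the copy's source indices (and likewise for the destination).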
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

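/// Applies each result of `affineMap` to `indices` (via
/// makeComposedFoldedAffineApply) and materializes the results as index
/// values. For example, for an affine access with map
/// (d0, d1) -> (d0 + 4, d1 floordiv 2) and operands (%i, %j), this returns
/// values computing %i + 4 and %j floordiv 2, i.e. the "actual" indices used
/// by the folders below.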
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}

template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}
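
// For example (illustrative): the preconditions above reject a transfer on a
// strided subview such as
//   %sv = memref.subview %A[0, 0] [4, 4] [2, 2] : memref<16x16xf32> to ...
// because the folded access would have to carry the non-unit strides;
// transfers with out-of-bounds dimensions are rejected likewise.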

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.load and affine.load guarantee that indices start inbounds,
  // while the vector operations don't. This affects whether our linearization
  // is `disjoint`.
  if (failed(memref::resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(memref::resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.store and affine.store guarantee that indices start inbounds,
  // while the vector operations don't. This affects whether our linearization
  // is `disjoint`.
  if (failed(memref::resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(memref::resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy op that uses the sources of the
  // subviews as its source and destination memrefs.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
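
//===----------------------------------------------------------------------===//
// Usage sketch (for reference; assumes the pass registration above and the
// standard greedy pattern-driver API):
//
//   // In a custom pipeline or transform, reuse the patterns directly:
//   RewritePatternSet patterns(ctx);
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   (void)applyPatternsGreedily(op, std::move(patterns));
//
//   // Or run the registered pass from the command line, e.g.:
//   //   mlir-opt --fold-memref-alias-ops input.mlir
//===----------------------------------------------------------------------===//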