FoldMemRefAliasOps.cpp
//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to memref aliasing ops
// (subview, expand_shape, collapse_shape) into loading/storing from/to the
// original memref.
//
//===----------------------------------------------------------------------===//
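//
// As a hand-written illustration (not taken from a test file), the subview
// patterns below rewrite IR of the form
//
//   %sv = memref.subview %src[4, 8] [16, 16] [1, 1]
//       : memref<64x64xf32> to memref<16x16xf32, strided<[64, 1], offset: 264>>
//   %v = memref.load %sv[%i, %j]
//       : memref<16x16xf32, strided<[64, 1], offset: 264>>
//
// into a load on the original memref at offset-adjusted indices, roughly:
//
//   %i0 = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%i]
//   %j0 = affine.apply affine_map<()[s0] -> (s0 + 8)>()[%j]
//   %v = memref.load %src[%i0, %j0] : memref<64x64xf32>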

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Helpers to access the memref operand for each op.
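/// The generic template covers ops that expose a `getMemref()` accessor
/// (e.g. memref.load/store and affine.load/store); the overloads below handle
/// ops that name the operand differently (e.g. `getBase()` on the vector
/// load/store and transfer ops, `getSrcMemref()`/`getDstMemref()` on the
/// nvgpu/gpu ops).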
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getBase();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getBase();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
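///
/// As a hand-written illustration (not from a test), assuming static shapes,
///
///   %0 = memref.expand_shape %src [[0, 1]] output_shape [4, 8]
///       : memref<32xf32> into memref<4x8xf32>
///   %v = memref.load %0[%i, %j] : memref<4x8xf32>
///
/// is rewritten to a load of %src at the linearized index, roughly
/// %i * 8 + %j.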
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
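///
/// As a hand-written illustration (not from a test),
///
///   %0 = memref.collapse_shape %src [[0, 1]]
///       : memref<4x8xf32> into memref<32xf32>
///   %v = memref.load %0[%i] : memref<32xf32>
///
/// is rewritten to a load of %src at delinearized indices, roughly
/// %i floordiv 8 and %i mod 8.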
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
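/// With unit strides, the offsets simply compose: as a hand-written example,
/// a subview with offset 2 and size 8 taken from a subview with offset 3 and
/// size 16 of a memref<32xf32> becomes a single subview of the original
/// memref with offset 5, with the sizes coming from the outer subview.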
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on the src and dst memrefs of the copy.
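///
/// Sketch (not exact IR): when the copy's source is produced by
/// `memref.subview %src[%o0, %o1] [...] [1, 1]`, the rewritten copy reads
/// from %src directly and each source index becomes, roughly, index + offset;
/// the destination side is handled the same way when it comes from a subview.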
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

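/// Given the affine map of an affine.load/store and its access indices,
/// compute the "expanded" indices by applying each result expression of the
/// map to the indices. As a hand-written illustration: with map
/// (d0, d1) -> (d0 + d1, d1) and indices {%a, %b}, this yields {%a + %b, %b},
/// materialized via affine::makeComposedFoldedAffineApply (and folded to
/// constants where possible).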
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}

template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(
      llvm::is_one_of<XferOp, vector::TransferReadOp,
                      vector::TransferWriteOp>::value,
      "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the indices to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the indices to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.load and affine.load guarantee that the starting indices are in
  // bounds, while the vector operations do not. This affects whether our
  // linearization is `disjoint`.
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
    return failure();

  return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
        return success();
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
        return success();
      })
      .Case([&](vector::TransferReadOp op) {
        // We only support minor identity maps in the permutation attribute.
        if (!op.getPermutationMap().isMinorIdentity())
          return failure();

        // We only support the case where the source of the expand shape has
        // rank greater than or equal to the vector rank.
        const int64_t sourceRank = sourceIndices.size();
        const int64_t vectorRank = op.getVectorType().getRank();
        if (sourceRank < vectorRank)
          return failure();

        // We need to construct a new minor identity map since we will have
        // lost some dimensions in folding away the expand shape.
        auto minorIdMap = AffineMap::getMinorIdentityMap(sourceRank, vectorRank,
                                                         op.getContext());

        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), expandShapeOp.getViewSource(),
            sourceIndices, minorIdMap, op.getPadding(), op.getMask(),
            op.getInBounds());
        return success();
      })
      .DefaultUnreachable("unexpected operation");
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the indices to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the indices to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the indices to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.store and affine.store guarantee that the starting indices are in
  // bounds, while the vector operations do not. This affects whether our
  // linearization is `disjoint`.
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the indices to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy that uses the sources of the subviews
  // (or the original operands when no subview is present).
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}
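
// A minimal usage sketch (hypothetical caller, not part of this file):
// clients that want this folding outside of the pass below can populate the
// patterns themselves, e.g.
//
//   RewritePatternSet patterns(op->getContext());
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     return failure();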

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}