//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to memref aliasing ops
// (subview, expand_shape, collapse_shape) into loading/storing from/to the
// original memref.
//
//===----------------------------------------------------------------------===//
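//
// As an illustrative sketch (the exact IR depends on the producing alias op
// and on the access op being folded), a load from a subview such as
//
//   %v = memref.subview %m[%o0, %o1] [4, 4] [1, 1]
//          : memref<12x32xf32> to memref<4x4xf32, strided<[32, 1], offset: ?>>
//   %x = memref.load %v[%i, %j]
//
// is rewritten to load directly from the original memref at the resolved
// indices, conceptually %x = memref.load %m[%o0 + %i, %o1 + %j], with the
// index arithmetic materialized through affine.apply ops.
//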

#include "mlir/Dialect/MemRef/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getBase();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getBase();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
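
// Illustrative sketch of the expand_shape case: indices on the expanded view
// are linearized back onto the source memref, e.g.
//   %e = memref.expand_shape %m [[0, 1]] output_shape [4, 8]
//          : memref<32xf32> into memref<4x8xf32>
//   %x = memref.load %e[%i, %j]
// becomes a load of %m at the linearized index %i * 8 + %j.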

/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
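
// Illustrative sketch of the collapse_shape case: the collapsed index is
// delinearized onto the source memref, e.g.
//   %c = memref.collapse_shape %m [[0, 1]] : memref<4x8xf32> into memref<32xf32>
//   %x = memref.load %c[%i]
// becomes a load of %m at indices (%i floordiv 8, %i mod 8).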

/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};
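
// For example (a sketch, with unit strides throughout):
//   %a = memref.subview %m[2] [8] [1]
//          : memref<16xf32> to memref<8xf32, strided<[1], offset: 2>>
//   %b = memref.subview %a[3] [4] [1]
//          : memref<8xf32, strided<[1], offset: 2>>
//            to memref<4xf32, strided<[1], offset: 5>>
// folds to the single op
//   %b = memref.subview %m[5] [4] [1]
//          : memref<16xf32> to memref<4xf32, strided<[1], offset: 5>>
// since the offsets compose (2 + 3 = 5) and the sizes come from the outer view.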

/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on both the source and destination memrefs of the copy.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

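/// Applies `affineMap` to `indices` and materializes each result as an index
/// Value. This recovers the "actual" accessed indices of an affine load/store;
/// e.g. a map (d0) -> (d0 + 4) applied to %i yields the value %i + 4.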
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}

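/// Returns success if folding `subviewOp` into the vector transfer op
/// `xferOp` is supported: the transfer must not have out-of-bounds dimensions
/// and the subview must have unit strides.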
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(
      llvm::is_one_of<XferOp, vector::TransferReadOp,
                      vector::TransferWriteOp>::value,
      "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.load and affine.load guarantee that indices are in bounds, while
  // the vector operations don't. This determines whether our linearization
  // is `disjoint`.
  if (failed(memref::resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(memref::resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.store and affine.store guarantee that indices are in bounds, while
  // the vector operations don't. This determines whether our linearization
  // is `disjoint`.
  if (failed(memref::resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(memref::resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy that operates on the subviews' source
  // memrefs, using the folded indices.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}
643 
645  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
646  LoadOpOfSubViewOpFolder<memref::LoadOp>,
647  LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
648  LoadOpOfSubViewOpFolder<vector::LoadOp>,
649  LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
650  LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
651  LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
652  StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
653  StoreOpOfSubViewOpFolder<memref::StoreOp>,
654  StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
655  StoreOpOfSubViewOpFolder<vector::StoreOp>,
656  StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
657  StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
658  LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
659  LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
660  LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
661  LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
662  StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
663  StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
664  StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
665  StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
666  LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
667  LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
668  LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
669  LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
670  StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
671  StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
672  StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
673  StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
674  SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
675  patterns.getContext());
676 }

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
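
// As a usage sketch (assuming the standard registration), the rewrites above
// can be exercised standalone via `mlir-opt --fold-memref-alias-ops`.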

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}