FoldMemRefAliasOps.cpp
//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
///    : memref<12x42xf32> into memref<2x6x42xf32>
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
/// %2 = load %0[6 * %i1 + %i2, %i3] :
///    memref<12x42xf32>
static LogicalResult resolveSourceIndicesExpandShape(
    Location loc, PatternRewriter &rewriter,
    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();

  // Traverse all reassociation groups to determine the appropriate indices
  // corresponding to each one of them after the fold.
  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
    assert(!group.empty() && "association indices groups cannot be empty");
    int64_t groupSize = group.size();
    if (groupSize == 1) {
      sourceIndices.push_back(indices[group[0]]);
      continue;
    }
    SmallVector<OpFoldResult> groupBasis =
        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
    SmallVector<Value> groupIndices =
        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
    sourceIndices.push_back(collapsedIndex);
  }
  return success();
}
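
// For the doc-comment example above, the rewrite emits (illustrative sketch,
// not verbatim compiler output) an affine.linearize_index per collapsed
// group:
//
//   %idx = affine.linearize_index disjoint [%i1, %i2] by (2, 6) : index
//   %2 = memref.load %0[%idx, %i3] : memref<12x42xf32>
//
// The `disjoint` flag is only set when the consumer guarantees in-bounds
// indices (memref.load / affine.load); see the callers below.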

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape op. For example
///
/// %0 = ... : memref<2x6x42xf32>
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
///    : memref<2x6x42xf32> into memref<12x42xf32>
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
///    memref<2x6x42xf32>
static LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  // Note: collapse_shape requires a strided memref, so extracting the strided
  // metadata of the source is always valid here.
  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
      loc, collapseShapeOp.getSrc());
  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
  for (auto [index, group] :
       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
    assert(!group.empty() && "association indices groups cannot be empty");
    int64_t groupSize = group.size();

    if (groupSize == 1) {
      sourceIndices.push_back(index);
      continue;
    }

    SmallVector<OpFoldResult> basis =
        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
        loc, index, basis, /*hasOuterBound=*/true);
    llvm::append_range(sourceIndices, delinearize.getResults());
  }
  if (collapseShapeOp.getReassociationIndices().empty()) {
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
    for (int64_t i = 0; i < srcRank; i++) {
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}
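
// For the doc-comment example above, the rewrite emits (illustrative sketch,
// assuming a static 2x6x42 source) an affine.delinearize_index that splits
// the collapsed index along the group's source sizes:
//
//   %i:2 = affine.delinearize_index %i1 into (2, 6) : index, index
//   %2 = memref.load %0[%i#0, %i#1, %i2] : memref<2x6x42xf32>
//
// Dynamic source sizes come from memref.extract_strided_metadata above.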

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getBase();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getBase();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};
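
// Illustrative sketch of the subview(subview) fold (hypothetical shapes and
// offsets, unit strides only):
//
//   %a = memref.subview %base[4, 0] [8, 16] [1, 1] : ... to memref<8x16xf32, ...>
//   %b = memref.subview %a[2, 1] [4, 8] [1, 1]     : ... to memref<4x8xf32, ...>
//
// becomes a single subview on %base with composed offsets (4 + 2, 0 + 1):
//
//   %b = memref.subview %base[6, 1] [4, 8] [1, 1] : ... to memref<4x8xf32, ...>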

/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on the src and dst memrefs of the copy.
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}
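
// Illustrative sketch: for an affine access such as
//   affine.load %m[%i + 4, %j floordiv 2]
// the map (d0, d1) -> (d0 + 4, d1 floordiv 2) is applied result-by-result to
// (%i, %j), materializing the "actual" indices (constant-folded where
// possible) before they are resolved against the producing alias op.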

template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
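
// Illustrative sketch of the fold (hypothetical shapes/offsets; the
// rank-reduced dim contributes its constant offset directly):
//
//   %sv = memref.subview %src[%o0, 4] [1, 32] [1, 1]
//       : memref<64x64xf32> to memref<32xf32, strided<[1], offset: ?>>
//   %v  = memref.load %sv[%i] : memref<32xf32, strided<[1], offset: ?>>
//
// becomes a load on the original memref with resolved indices:
//
//   %idx = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
//   %v   = memref.load %src[%o0, %idx] : memref<64x64xf32>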

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.load and affine.load guarantee that indices are in bounds while
  // the vector operations don't. This affects whether our linearization
  // is `disjoint`.
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}
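
// Illustrative sketch (hypothetical shapes): a vector.load through an
// expand_shape does not guarantee in-bounds indices, so the linearization is
// emitted without the `disjoint` flag:
//
//   %e = memref.expand_shape %src [[0, 1]] output_shape [4, 8]
//       : memref<32xf32> into memref<4x8xf32>
//   %v = vector.load %e[%i, %j] : memref<4x8xf32>, vector<8xf32>
//  -->
//   %idx = affine.linearize_index [%i, %j] by (4, 8) : index
//   %v   = vector.load %src[%idx] : memref<32xf32>, vector<8xf32>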

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  // memref.store and affine.store guarantee that indices are in bounds while
  // the vector operations don't. This affects whether our linearization
  // is `disjoint`.
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new one that reads from / writes to the
  // original memrefs behind the subviews.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}
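
// Illustrative sketch (hypothetical shapes): a device_async_copy whose dst is
// a subview of a shared-memory buffer is rewritten to copy into the original
// buffer, with the subview offset folded into the copy's dst indices:
//
//   %sv = memref.subview %smem[%o, 0] [1, 32] [1, 1] : ... to memref<32xf32, ...>
//   %t  = nvgpu.device_async_copy %gmem[%i], %sv[%c0], 8 : ... to ...
//  -->
//   %t  = nvgpu.device_async_copy %gmem[%i], %smem[%o, %c0], 8 : ... to ...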

void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
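
// Usage note (a sketch, not part of the original source): downstream passes
// can pull these rewrites in via memref::populateFoldMemRefAliasOpPatterns(),
// or the standalone pass can be run from the command line, e.g.
// `mlir-opt --fold-memref-alias-ops input.mlir`.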