//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
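
// As an illustration of the rewrite this pass performs (IR sketch only; the
// exact offsets, strides, and layouts depend on the input program):
//
//   %sv = memref.subview %src[%o0, %o1] [4, 4] [1, 1]
//       : memref<12x42xf32> to memref<4x4xf32, strided<[42, 1], offset: ?>>
//   %v = memref.load %sv[%i, %j]
//       : memref<4x4xf32, strided<[42, 1], offset: ?>>
//
// becomes a load that addresses the original memref directly, roughly:
//
//   %v = memref.load %src[%o0 + %i, %o1 + %j] : memref<12x42xf32>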

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
///    : memref<12x42xf32> into memref<2x6x42xf32>
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
/// %2 = load %0[6 * %i1 + %i2, %i3] :
///    memref<12x42xf32>
static LogicalResult
resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                memref::ExpandShapeOp expandShapeOp,
                                ValueRange indices,
                                SmallVectorImpl<Value> &sourceIndices) {
  // The implementation below uses the computeSuffixProduct method, which only
  // allows int64_t values (i.e., static shapes). Bail out if the result has
  // dynamic shapes.
  if (!expandShapeOp.getResultType().hasStaticShape())
    return failure();

  MLIRContext *ctx = rewriter.getContext();
  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    int64_t groupSize = groups.size();

    // Construct the expression for the index value w.r.t. the expand_shape op
    // source, corresponding to the indices w.r.t. the expand_shape op result.
    SmallVector<int64_t> sizes(groupSize);
    for (int64_t i = 0; i < groupSize; ++i)
      sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
    SmallVector<AffineExpr> dims(groupSize);
    bindDimsList(ctx, MutableArrayRef{dims});
    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);

    /// Apply permutation and create AffineApplyOp.
    SmallVector<OpFoldResult> dynamicIndices(groupSize);
    for (int64_t i = 0; i < groupSize; i++)
      dynamicIndices[i] = indices[groups[i]];

    // Creating a maximally folded and composed affine.apply composes better
    // with other transformations without interleaving canonicalization passes.
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc,
        AffineMap::get(/*numDims=*/groupSize,
                       /*numSymbols=*/0, srcIndexExpr),
        dynamicIndices);
    sourceIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return success();
}

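// As a concrete illustration of the linearization above (values assumed, not
// taken from any particular test): for the reassociation group [0, 1] of
// memref<2x6x42xf32>, sizes = [2, 6], so computeSuffixProduct yields [6, 1]
// and srcIndexExpr is d0 * 6 + d1. A result-side access (%i1, %i2) therefore
// maps to the source index 6 * %i1 + %i2, matching the example in the
// documentation comment.
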
/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape op. For example
///
/// %0 = ... : memref<2x6x42xf32>
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
///    : memref<2x6x42xf32> into memref<12x42xf32>
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
///    memref<2x6x42xf32>
static LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  int64_t cnt = 0;
  SmallVector<Value> tmp(indices.size());
  SmallVector<OpFoldResult> dynamicIndices;
  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    dynamicIndices.push_back(indices[cnt++]);
    int64_t groupSize = groups.size();

    // Calculate the suffix product for all collapse op source dimension sizes
    // except the most major one of each group.
    // We allow the most major source dimension to be dynamic but enforce all
    // others to be known statically.
    SmallVector<int64_t> sizes(groupSize, 1);
    for (int64_t i = 1; i < groupSize; ++i) {
      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
      if (sizes[i] == ShapedType::kDynamic)
        return failure();
    }
    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);

    // Derive the index values along all dimensions of the source corresponding
    // to the index w.r.t. the collapse_shape op output.
    auto d0 = rewriter.getAffineDimExpr(0);
    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);

    // Construct the AffineApplyOp for each delinearizingExpr.
    for (int64_t i = 0; i < groupSize; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc,
          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
                         delinearizingExprs[i]),
          dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
    dynamicIndices.clear();
  }
  if (collapseShapeOp.getReassociationIndices().empty()) {
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    for (int64_t i = 0; i < srcRank; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, zeroAffineMap, dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}

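// To make the delinearization above concrete (illustrative numbers only): for
// the group [0, 1] of memref<2x6x42xf32>, sizes = [1, 6] (the outermost size
// is left as 1 because it may be dynamic), so suffixProduct = [6, 1] and the
// delinearizing expressions are d0 floordiv 6 and d0 mod 6. A collapsed index
// %i1 therefore expands to (%i1 floordiv 6, %i1 mod 6), matching the example
// in the documentation comment.
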
/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getSource();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

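// For instance (offsets made up for illustration), a chain such as
//
//   %a = memref.subview %m[4, 8] [8, 8] [1, 1]
//       : memref<16x16xf32> to memref<8x8xf32, strided<[16, 1], offset: 72>>
//   %b = memref.subview %a[2, 1] [4, 4] [1, 1]
//       : memref<8x8xf32, strided<[16, 1], offset: 72>>
//         to memref<4x4xf32, strided<[16, 1], offset: 105>>
//
// folds to a single subview of %m with the composed offsets (4 + 2, 8 + 1):
//
//   %b = memref.subview %m[6, 9] [4, 4] [1, 1]
//       : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: 105>>
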
/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on the src and dst memrefs of the copy.
class NvgpuAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}

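// For example (hypothetical map, not tied to any test): given the access map
// (d0, d1) -> (d0 + 1, d1 * 2) of an affine.load and indices (%i, %j), the
// loop above materializes one affine.apply per map result, yielding the
// expanded indices (%i + 1, %j * 2), which can then be used as plain memref
// indices.
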
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(
      !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value,
      "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

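// As an example of the vector.transfer_read case above (sketch only; types and
// the strided layout depend on the actual subview), folding through a
// rank-reducing subview also expands the permutation map with expandDimsToRank:
//
//   %sv = memref.subview %m[%o, 0, 0] [1, 16, 16] [1, 1, 1]
//       : memref<4x16x16xf32> to memref<16x16xf32, strided<[16, 1], offset: ?>>
//   %v = vector.transfer_read %sv[%i, %j], %pad {in_bounds = [true, true]}
//       : memref<16x16xf32, strided<[16, 1], offset: ?>>, vector<16x16xf32>
//
// becomes, roughly,
//
//   %v = vector.transfer_read %m[%o, %i, %j], %pad {in_bounds = [true, true]}
//       : memref<4x16x16xf32>, vector<16x16xf32>
//
// where the identity permutation map (d0, d1) -> (d0, d1) is expanded to
// (d0, d1, d2) -> (d1, d2) because the subview drops the leading dimension.
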
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp.getOperation())
      .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
                                                  expandShapeOp.getViewSource(),
                                                  sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the affine map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp.getOperation())
      .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
        rewriter.replaceOpWithNewOp<decltype(op)>(
            storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

LogicalResult NvgpuAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy op that uses the source and
  // destination of any folded subview.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

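// As an illustration of this folding (simplified IR; shapes, offsets, and the
// exact op syntax are chosen for the example), a copy whose destination is a
// subview of a shared-memory buffer:
//
//   %sv = memref.subview %dst[2, 0] [16, 32] [1, 1]
//       : memref<64x32xf32, 3>
//         to memref<16x32xf32, strided<[32, 1], offset: 64>, 3>
//   %t = nvgpu.device_async_copy %src[%i, %j], %sv[%k, %c0], 8
//       : memref<128x32xf32> to memref<16x32xf32, strided<[32, 1], offset: 64>, 3>
//
// is rewritten to copy into %dst directly, with the destination indices
// resolved against the subview offsets (roughly %k + 2 and %c0):
//
//   %t = nvgpu.device_async_copy %src[%i, %j], %dst[%k + 2, %c0], 8
//       : memref<128x32xf32> to memref<64x32xf32, 3>
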
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               SubViewOfSubViewFolder, NvgpuAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
  return std::make_unique<FoldMemRefAliasOpsPass>();
}
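
// A typical way to use these patterns outside the standalone pass (sketch
// only, assuming the usual greedy pattern-application boilerplate in the
// caller):
//
//   RewritePatternSet patterns(ctx);
//   memref::populateFoldMemRefAliasOpPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     return failure();
//
// The standalone pass can also be run from the command line, e.g.
// `mlir-opt --fold-memref-alias-ops input.mlir`.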