MLIR  19.0.0git
FoldMemRefAliasOps.cpp
Go to the documentation of this file.
//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
32 
33 #define DEBUG_TYPE "fold-memref-alias-ops"
34 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
35 
36 namespace mlir {
37 namespace memref {
38 #define GEN_PASS_DEF_FOLDMEMREFALIASOPS
39 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
40 } // namespace memref
41 } // namespace mlir
42 
43 using namespace mlir;
44 
45 //===----------------------------------------------------------------------===//
46 // Utility functions
47 //===----------------------------------------------------------------------===//
48 
49 /// Given the 'indices' of a load/store operation where the memref is a result
50 /// of a expand_shape op, returns the indices w.r.t to the source memref of the
51 /// expand_shape op. For example
52 ///
53 /// %0 = ... : memref<12x42xf32>
54 /// %1 = memref.expand_shape %0 [[0, 1], [2]]
55 /// : memref<12x42xf32> into memref<2x6x42xf32>
56 /// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
57 ///
58 /// could be folded into
59 ///
60 /// %2 = load %0[6 * i1 + i2, %i3] :
61 /// memref<12x42xf32>
62 static LogicalResult
64  memref::ExpandShapeOp expandShapeOp,
65  ValueRange indices,
66  SmallVectorImpl<Value> &sourceIndices) {
67  // Record the rewriter context for constructing ops later.
68  MLIRContext *ctx = rewriter.getContext();
69 
70  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
71  // This is done for the purpose of inferring the output shape via
72  // `inferExpandOutputShape` which will in turn be used for suffix product
73  // calculation later.
75  MemRefType srcType = expandShapeOp.getSrcType();
76 
77  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
78  if (srcType.isDynamicDim(i)) {
79  srcShape.push_back(
80  rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
81  .getResult());
82  } else {
83  srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
84  }
85  }
86 
87  auto outputShape = inferExpandShapeOutputShape(
88  rewriter, loc, expandShapeOp.getResultType(),
89  expandShapeOp.getReassociationIndices(), srcShape);
90  if (!outputShape.has_value())
91  return failure();
92 
93  // Traverse all reassociation groups to determine the appropriate indices
94  // corresponding to each one of them post op folding.
95  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
96  assert(!groups.empty() && "association indices groups cannot be empty");
97  // Flag to indicate the presence of dynamic dimensions in current
98  // reassociation group.
99  int64_t groupSize = groups.size();
100 
101  // Group output dimensions utilized in this reassociation group for suffix
102  // product calculation.
103  SmallVector<OpFoldResult> sizesVal(groupSize);
104  for (int64_t i = 0; i < groupSize; ++i) {
105  sizesVal[i] = (*outputShape)[groups[i]];
106  }
107 
108  // Calculate suffix product of relevant output dimension sizes.
109  SmallVector<OpFoldResult> suffixProduct =
110  memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
111 
112  // Create affine expression variables for dimensions and symbols in the
113  // newly constructed affine map.
114  SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
115  bindDimsList<AffineExpr>(ctx, dims);
116  bindSymbolsList<AffineExpr>(ctx, symbols);
117 
118  // Linearize binded dimensions and symbols to construct the resultant
119  // affine expression for this indice.
120  AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
121 
122  // Record the load index corresponding to each dimension in the
123  // reassociation group. These are later supplied as operands to the affine
124  // map used for calulating relevant index post op folding.
125  SmallVector<OpFoldResult> dynamicIndices(groupSize);
126  for (int64_t i = 0; i < groupSize; i++)
127  dynamicIndices[i] = indices[groups[i]];
128 
129  // Supply suffix product results followed by load op indices as operands
130  // to the map.
131  SmallVector<OpFoldResult> mapOperands;
132  llvm::append_range(mapOperands, suffixProduct);
133  llvm::append_range(mapOperands, dynamicIndices);
134 
135  // Creating maximally folded and composed affine.apply composes better
136  // with other transformations without interleaving canonicalization
137  // passes.
139  rewriter, loc,
140  AffineMap::get(/*numDims=*/groupSize,
141  /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
142  mapOperands);
143 
144  // Push index value in the op post folding corresponding to this
145  // reassociation group.
146  sourceIndices.push_back(
147  getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
148  }
149  return success();
150 }
151 
152 /// Given the 'indices' of a load/store operation where the memref is a result
153 /// of a collapse_shape op, returns the indices w.r.t to the source memref of
154 /// the collapse_shape op. For example
155 ///
156 /// %0 = ... : memref<2x6x42xf32>
157 /// %1 = memref.collapse_shape %0 [[0, 1], [2]]
158 /// : memref<2x6x42xf32> into memref<12x42xf32>
159 /// %2 = load %1[%i1, %i2] : memref<12x42xf32>
160 ///
161 /// could be folded into
162 ///
163 /// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
164 /// memref<2x6x42xf32>
165 static LogicalResult
167  memref::CollapseShapeOp collapseShapeOp,
168  ValueRange indices,
169  SmallVectorImpl<Value> &sourceIndices) {
170  int64_t cnt = 0;
171  SmallVector<Value> tmp(indices.size());
172  SmallVector<OpFoldResult> dynamicIndices;
173  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
174  assert(!groups.empty() && "association indices groups cannot be empty");
175  dynamicIndices.push_back(indices[cnt++]);
176  int64_t groupSize = groups.size();
177 
178  // Calculate suffix product for all collapse op source dimension sizes
179  // except the most major one of each group.
180  // We allow the most major source dimension to be dynamic but enforce all
181  // others to be known statically.
182  SmallVector<int64_t> sizes(groupSize, 1);
183  for (int64_t i = 1; i < groupSize; ++i) {
184  sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
185  if (sizes[i] == ShapedType::kDynamic)
186  return failure();
187  }
188  SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
189 
190  // Derive the index values along all dimensions of the source corresponding
191  // to the index wrt to collapsed shape op output.
192  auto d0 = rewriter.getAffineDimExpr(0);
193  SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
194 
195  // Construct the AffineApplyOp for each delinearizingExpr.
196  for (int64_t i = 0; i < groupSize; i++) {
198  rewriter, loc,
199  AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
200  delinearizingExprs[i]),
201  dynamicIndices);
202  sourceIndices.push_back(
203  getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
204  }
205  dynamicIndices.clear();
206  }
207  if (collapseShapeOp.getReassociationIndices().empty()) {
208  auto zeroAffineMap = rewriter.getConstantAffineMap(0);
209  int64_t srcRank =
210  cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
211  for (int64_t i = 0; i < srcRank; i++) {
213  rewriter, loc, zeroAffineMap, dynamicIndices);
214  sourceIndices.push_back(
215  getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
216  }
217  }
218  return success();
219 }
220 
221 /// Helpers to access the memref operand for each op.
222 template <typename LoadOrStoreOpTy>
223 static Value getMemRefOperand(LoadOrStoreOpTy op) {
224  return op.getMemref();
225 }
226 
227 static Value getMemRefOperand(vector::TransferReadOp op) {
228  return op.getSource();
229 }
230 
231 static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
232  return op.getSrcMemref();
233 }
234 
235 static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
236 
237 static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }
238 
239 static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
240 
241 static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }
242 
243 static Value getMemRefOperand(vector::TransferWriteOp op) {
244  return op.getSource();
245 }
246 
247 static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
248  return op.getSrcMemref();
249 }
250 
251 static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
252  return op.getDstMemref();
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // Patterns
257 //===----------------------------------------------------------------------===//
258 
259 namespace {
260 /// Merges subview operation with load/transferRead operation.
261 template <typename OpTy>
262 class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
263 public:
265 
266  LogicalResult matchAndRewrite(OpTy loadOp,
267  PatternRewriter &rewriter) const override;
268 };
269 
270 /// Merges expand_shape operation with load/transferRead operation.
271 template <typename OpTy>
272 class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
273 public:
275 
276  LogicalResult matchAndRewrite(OpTy loadOp,
277  PatternRewriter &rewriter) const override;
278 };
279 
280 /// Merges collapse_shape operation with load/transferRead operation.
281 template <typename OpTy>
282 class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
283 public:
285 
286  LogicalResult matchAndRewrite(OpTy loadOp,
287  PatternRewriter &rewriter) const override;
288 };
289 
290 /// Merges subview operation with store/transferWriteOp operation.
291 template <typename OpTy>
292 class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
293 public:
295 
296  LogicalResult matchAndRewrite(OpTy storeOp,
297  PatternRewriter &rewriter) const override;
298 };
299 
300 /// Merges expand_shape operation with store/transferWriteOp operation.
301 template <typename OpTy>
302 class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
303 public:
305 
306  LogicalResult matchAndRewrite(OpTy storeOp,
307  PatternRewriter &rewriter) const override;
308 };
309 
310 /// Merges collapse_shape operation with store/transferWriteOp operation.
311 template <typename OpTy>
312 class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
313 public:
315 
316  LogicalResult matchAndRewrite(OpTy storeOp,
317  PatternRewriter &rewriter) const override;
318 };
319 
320 /// Folds subview(subview(x)) to a single subview(x).
321 class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
322 public:
324 
325  LogicalResult matchAndRewrite(memref::SubViewOp subView,
326  PatternRewriter &rewriter) const override {
327  auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
328  if (!srcSubView)
329  return failure();
330 
331  // TODO: relax unit stride assumption.
332  if (!subView.hasUnitStride()) {
333  return rewriter.notifyMatchFailure(subView, "requires unit strides");
334  }
335  if (!srcSubView.hasUnitStride()) {
336  return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
337  }
338 
339  // Resolve sizes according to dropped dims.
340  SmallVector<OpFoldResult> resolvedSizes;
341  llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
342  affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
343  subView.getMixedSizes(), srcDroppedDims,
344  resolvedSizes);
345 
346  // Resolve offsets according to source offsets and strides.
347  SmallVector<Value> resolvedOffsets;
349  rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
350  srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
351  resolvedOffsets);
352 
353  // Replace original op.
354  rewriter.replaceOpWithNewOp<memref::SubViewOp>(
355  subView, subView.getType(), srcSubView.getSource(),
356  getAsOpFoldResult(resolvedOffsets), resolvedSizes,
357  srcSubView.getMixedStrides());
358 
359  return success();
360  }
361 };
362 
363 /// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
364 /// is folds subview on src and dst memref of the copy.
365 class NVGPUAsyncCopyOpSubViewOpFolder final
366  : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
367 public:
369 
370  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
371  PatternRewriter &rewriter) const override;
372 };
373 } // namespace
374 
375 static SmallVector<Value>
377  const SmallVector<Value> &indices, Location loc,
378  PatternRewriter &rewriter) {
379  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
380  llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
381  SmallVector<Value> expandedIndices;
382  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
384  rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
385  expandedIndices.push_back(
386  getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
387  }
388  return expandedIndices;
389 }
390 
391 template <typename XferOp>
392 static LogicalResult
394  memref::SubViewOp subviewOp) {
395  static_assert(
396  !llvm::is_one_of<vector::TransferReadOp, vector::TransferWriteOp>::value,
397  "must be a vector transfer op");
398  if (xferOp.hasOutOfBoundsDim())
399  return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
400  if (!subviewOp.hasUnitStride()) {
401  return rewriter.notifyMatchFailure(
402  xferOp, "non-1 stride subview, need to track strides in folded memref");
403  }
404  return success();
405 }
406 
408  Operation *op,
409  memref::SubViewOp subviewOp) {
410  return success();
411 }
412 
414  vector::TransferReadOp readOp,
415  memref::SubViewOp subviewOp) {
416  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
417 }
418 
420  vector::TransferWriteOp writeOp,
421  memref::SubViewOp subviewOp) {
422  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
423 }
424 
425 template <typename OpTy>
426 LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
427  OpTy loadOp, PatternRewriter &rewriter) const {
428  auto subViewOp =
429  getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
430 
431  if (!subViewOp)
432  return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
433 
434  LogicalResult preconditionResult =
435  preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
436  if (failed(preconditionResult))
437  return preconditionResult;
438 
439  SmallVector<Value> indices(loadOp.getIndices().begin(),
440  loadOp.getIndices().end());
441  // For affine ops, we need to apply the map to get the operands to get the
442  // "actual" indices.
443  if (auto affineLoadOp =
444  dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
445  AffineMap affineMap = affineLoadOp.getAffineMap();
446  auto expandedIndices = calculateExpandedAccessIndices(
447  affineMap, indices, loadOp.getLoc(), rewriter);
448  indices.assign(expandedIndices.begin(), expandedIndices.end());
449  }
450  SmallVector<Value> sourceIndices;
452  rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
453  subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
454  sourceIndices);
455 
457  .Case([&](affine::AffineLoadOp op) {
458  rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
459  loadOp, subViewOp.getSource(), sourceIndices);
460  })
461  .Case([&](memref::LoadOp op) {
462  rewriter.replaceOpWithNewOp<memref::LoadOp>(
463  loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
464  })
465  .Case([&](vector::LoadOp op) {
466  rewriter.replaceOpWithNewOp<vector::LoadOp>(
467  op, op.getType(), subViewOp.getSource(), sourceIndices);
468  })
469  .Case([&](vector::MaskedLoadOp op) {
470  rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
471  op, op.getType(), subViewOp.getSource(), sourceIndices,
472  op.getMask(), op.getPassThru());
473  })
474  .Case([&](vector::TransferReadOp op) {
475  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
476  op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
478  op.getPermutationMap(), subViewOp.getSourceType().getRank(),
479  subViewOp.getDroppedDims())),
480  op.getPadding(), op.getMask(), op.getInBoundsAttr());
481  })
482  .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
483  rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
484  op, op.getType(), subViewOp.getSource(), sourceIndices,
485  op.getLeadDimension(), op.getTransposeAttr());
486  })
487  .Case([&](nvgpu::LdMatrixOp op) {
488  rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
489  op, op.getType(), subViewOp.getSource(), sourceIndices,
490  op.getTranspose(), op.getNumTiles());
491  })
492  .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
493  return success();
494 }
495 
496 template <typename OpTy>
497 LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
498  OpTy loadOp, PatternRewriter &rewriter) const {
499  auto expandShapeOp =
500  getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
501 
502  if (!expandShapeOp)
503  return failure();
504 
505  SmallVector<Value> indices(loadOp.getIndices().begin(),
506  loadOp.getIndices().end());
507  // For affine ops, we need to apply the map to get the operands to get the
508  // "actual" indices.
509  if (auto affineLoadOp =
510  dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
511  AffineMap affineMap = affineLoadOp.getAffineMap();
512  auto expandedIndices = calculateExpandedAccessIndices(
513  affineMap, indices, loadOp.getLoc(), rewriter);
514  indices.assign(expandedIndices.begin(), expandedIndices.end());
515  }
516  SmallVector<Value> sourceIndices;
518  loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
519  return failure();
521  .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
522  rewriter.replaceOpWithNewOp<decltype(op)>(
523  loadOp, expandShapeOp.getViewSource(), sourceIndices);
524  })
525  .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
526  return success();
527 }
528 
529 template <typename OpTy>
530 LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
531  OpTy loadOp, PatternRewriter &rewriter) const {
532  auto collapseShapeOp = getMemRefOperand(loadOp)
533  .template getDefiningOp<memref::CollapseShapeOp>();
534 
535  if (!collapseShapeOp)
536  return failure();
537 
538  SmallVector<Value> indices(loadOp.getIndices().begin(),
539  loadOp.getIndices().end());
540  // For affine ops, we need to apply the map to get the operands to get the
541  // "actual" indices.
542  if (auto affineLoadOp =
543  dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
544  AffineMap affineMap = affineLoadOp.getAffineMap();
545  auto expandedIndices = calculateExpandedAccessIndices(
546  affineMap, indices, loadOp.getLoc(), rewriter);
547  indices.assign(expandedIndices.begin(), expandedIndices.end());
548  }
549  SmallVector<Value> sourceIndices;
551  loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
552  return failure();
554  .Case<affine::AffineLoadOp, memref::LoadOp>([&](auto op) {
555  rewriter.replaceOpWithNewOp<decltype(op)>(
556  loadOp, collapseShapeOp.getViewSource(), sourceIndices);
557  })
558  .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
559  return success();
560 }
561 
562 template <typename OpTy>
563 LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
564  OpTy storeOp, PatternRewriter &rewriter) const {
565  auto subViewOp =
566  getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
567 
568  if (!subViewOp)
569  return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
570 
571  LogicalResult preconditionResult =
572  preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
573  if (failed(preconditionResult))
574  return preconditionResult;
575 
576  SmallVector<Value> indices(storeOp.getIndices().begin(),
577  storeOp.getIndices().end());
578  // For affine ops, we need to apply the map to get the operands to get the
579  // "actual" indices.
580  if (auto affineStoreOp =
581  dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
582  AffineMap affineMap = affineStoreOp.getAffineMap();
583  auto expandedIndices = calculateExpandedAccessIndices(
584  affineMap, indices, storeOp.getLoc(), rewriter);
585  indices.assign(expandedIndices.begin(), expandedIndices.end());
586  }
587  SmallVector<Value> sourceIndices;
589  rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
590  subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
591  sourceIndices);
592 
594  .Case([&](affine::AffineStoreOp op) {
595  rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
596  op, op.getValue(), subViewOp.getSource(), sourceIndices);
597  })
598  .Case([&](memref::StoreOp op) {
599  rewriter.replaceOpWithNewOp<memref::StoreOp>(
600  op, op.getValue(), subViewOp.getSource(), sourceIndices,
601  op.getNontemporal());
602  })
603  .Case([&](vector::TransferWriteOp op) {
604  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
605  op, op.getValue(), subViewOp.getSource(), sourceIndices,
607  op.getPermutationMap(), subViewOp.getSourceType().getRank(),
608  subViewOp.getDroppedDims())),
609  op.getMask(), op.getInBoundsAttr());
610  })
611  .Case([&](vector::StoreOp op) {
612  rewriter.replaceOpWithNewOp<vector::StoreOp>(
613  op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
614  })
615  .Case([&](vector::MaskedStoreOp op) {
616  rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
617  op, subViewOp.getSource(), sourceIndices, op.getMask(),
618  op.getValueToStore());
619  })
620  .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
621  rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
622  op, op.getSrc(), subViewOp.getSource(), sourceIndices,
623  op.getLeadDimension(), op.getTransposeAttr());
624  })
625  .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
626  return success();
627 }
628 
629 template <typename OpTy>
630 LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
631  OpTy storeOp, PatternRewriter &rewriter) const {
632  auto expandShapeOp =
633  getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
634 
635  if (!expandShapeOp)
636  return failure();
637 
638  SmallVector<Value> indices(storeOp.getIndices().begin(),
639  storeOp.getIndices().end());
640  // For affine ops, we need to apply the map to get the operands to get the
641  // "actual" indices.
642  if (auto affineStoreOp =
643  dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
644  AffineMap affineMap = affineStoreOp.getAffineMap();
645  auto expandedIndices = calculateExpandedAccessIndices(
646  affineMap, indices, storeOp.getLoc(), rewriter);
647  indices.assign(expandedIndices.begin(), expandedIndices.end());
648  }
649  SmallVector<Value> sourceIndices;
651  storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
652  return failure();
654  .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
655  rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
656  expandShapeOp.getViewSource(),
657  sourceIndices);
658  })
659  .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
660  return success();
661 }
662 
663 template <typename OpTy>
664 LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
665  OpTy storeOp, PatternRewriter &rewriter) const {
666  auto collapseShapeOp = getMemRefOperand(storeOp)
667  .template getDefiningOp<memref::CollapseShapeOp>();
668 
669  if (!collapseShapeOp)
670  return failure();
671 
672  SmallVector<Value> indices(storeOp.getIndices().begin(),
673  storeOp.getIndices().end());
674  // For affine ops, we need to apply the map to get the operands to get the
675  // "actual" indices.
676  if (auto affineStoreOp =
677  dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
678  AffineMap affineMap = affineStoreOp.getAffineMap();
679  auto expandedIndices = calculateExpandedAccessIndices(
680  affineMap, indices, storeOp.getLoc(), rewriter);
681  indices.assign(expandedIndices.begin(), expandedIndices.end());
682  }
683  SmallVector<Value> sourceIndices;
685  storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
686  return failure();
688  .Case<affine::AffineStoreOp, memref::StoreOp>([&](auto op) {
689  rewriter.replaceOpWithNewOp<decltype(op)>(
690  storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
691  sourceIndices);
692  })
693  .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
694  return success();
695 }
696 
697 LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
698  nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
699 
700  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");
701 
702  auto srcSubViewOp =
703  copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
704  auto dstSubViewOp =
705  copyOp.getDst().template getDefiningOp<memref::SubViewOp>();
706 
707  if (!(srcSubViewOp || dstSubViewOp))
708  return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
709  "source or destination");
710 
711  // If the source is a subview, we need to resolve the indices.
712  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
713  copyOp.getSrcIndices().end());
714  SmallVector<Value> foldedSrcIndices(srcindices);
715 
716  if (srcSubViewOp) {
717  LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
719  rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
720  srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
721  srcindices, foldedSrcIndices);
722  }
723 
724  // If the destination is a subview, we need to resolve the indices.
725  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
726  copyOp.getDstIndices().end());
727  SmallVector<Value> foldedDstIndices(dstindices);
728 
729  if (dstSubViewOp) {
730  LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
732  rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
733  dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
734  dstindices, foldedDstIndices);
735  }
736 
737  // Replace the copy op with a new copy op that uses the source and destination
738  // of the subview.
739  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
740  copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
741  (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
742  foldedDstIndices,
743  (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
744  foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
745  copyOp.getBypassL1Attr());
746 
747  return success();
748 }
749 
751  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
752  LoadOpOfSubViewOpFolder<memref::LoadOp>,
753  LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
754  LoadOpOfSubViewOpFolder<vector::LoadOp>,
755  LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
756  LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
757  LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
758  StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
759  StoreOpOfSubViewOpFolder<memref::StoreOp>,
760  StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
761  StoreOpOfSubViewOpFolder<vector::StoreOp>,
762  StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
763  StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
764  LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
765  LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
766  StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
767  StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
768  LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
769  LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
770  StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
771  StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
772  SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
773  patterns.getContext());
774 }
775 
776 //===----------------------------------------------------------------------===//
777 // Pass registration
778 //===----------------------------------------------------------------------===//
779 
780 namespace {
781 
782 struct FoldMemRefAliasOpsPass final
783  : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
784  void runOnOperation() override;
785 };
786 
787 } // namespace
788 
789 void FoldMemRefAliasOpsPass::runOnOperation() {
790  RewritePatternSet patterns(&getContext());
792  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
793 }
794 
795 std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
796  return std::make_unique<FoldMemRefAliasOpsPass>();
797 }
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, Operation *op, memref::SubViewOp subviewOp)
static LogicalResult preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp, memref::SubViewOp subviewOp)
static SmallVector< Value > calculateExpandedAccessIndices(AffineMap affineMap, const SmallVector< Value > &indices, Location loc, PatternRewriter &rewriter)
static LogicalResult resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
static Value getMemRefOperand(LoadOrStoreOpTy op)
Helpers to access the memref operand for each op.
static LogicalResult resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
#define DBGS()
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:47
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumResults() const
Definition: AffineMap.cpp:386
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:615
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:371
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:385
MLIRContext * getContext() const
Definition: Builders.h:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void resolveSizesIntoOpWithSizes(ArrayRef< OpFoldResult > sourceSizes, ArrayRef< OpFoldResult > destSizes, const llvm::SmallBitVector &rankReducedSourceDims, SmallVectorImpl< OpFoldResult > &resolvedSizes)
Given sourceSizes, destSizes and information about which dimensions are dropped by the source: rankRe...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1188
std::unique_ptr< Pass > createFoldMemRefAliasOpsPass()
Creates an operation pass that folds memref aliasing ops into consumer load/store ops.
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
SmallVector< OpFoldResult > computeSuffixProductIRBlock(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > sizes)
Given a set of sizes, return the suffix product.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
Definition: AffineMap.cpp:915
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
SmallVector< int64_t > computeSuffixProduct(ArrayRef< int64_t > sizes)
Given a set of sizes, return the suffix product.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:103
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
int64_t linearize(ArrayRef< int64_t > offsets, ArrayRef< int64_t > basis)
Return the linearized index of 'offsets' w.r.t.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358