FoldMemRefAliasOps.cpp
//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of an expand_shape op, returns the indices w.r.t. the source memref of the
/// expand_shape op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
///    : memref<12x42xf32> into memref<2x6x42xf32>
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
/// %2 = load %0[6 * %i1 + %i2, %i3] :
///          memref<12x42xf32>
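///
/// The linearization below is expressed, per reassociation group, with an
/// affine map of the form
///   (d0, ..., d{n-1})[s0, ..., s{n-1}] -> (d0 * s0 + ... + d{n-1} * s{n-1})
/// where the dims are bound to the suffix products of the group's output
/// sizes and the symbols to the load/store indices; for the group [0, 1]
/// above this materializes 6 * %i1 + %i2.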
static LogicalResult
resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                memref::ExpandShapeOp expandShapeOp,
                                ValueRange indices,
                                SmallVectorImpl<Value> &sourceIndices) {
  // Record the rewriter context for constructing ops later.
  MLIRContext *ctx = rewriter.getContext();

  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
  // This is done for the purpose of inferring the output shape via
  // `inferExpandShapeOutputShape`, which will in turn be used for the suffix
  // product calculation later.
  SmallVector<OpFoldResult> srcShape;
  MemRefType srcType = expandShapeOp.getSrcType();

  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
    if (srcType.isDynamicDim(i)) {
      srcShape.push_back(
          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
              .getResult());
    } else {
      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
    }
  }

  auto outputShape = inferExpandShapeOutputShape(
      rewriter, loc, expandShapeOp.getResultType(),
      expandShapeOp.getReassociationIndices(), srcShape);
  if (!outputShape.has_value())
    return failure();

  // Traverse all reassociation groups to determine the appropriate indices
  // corresponding to each group after folding.
  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    // Number of dimensions in the current reassociation group.
    int64_t groupSize = groups.size();

    // Gather the output dimension sizes of this reassociation group for the
    // suffix product calculation.
    SmallVector<OpFoldResult> sizesVal(groupSize);
    for (int64_t i = 0; i < groupSize; ++i) {
      sizesVal[i] = (*outputShape)[groups[i]];
    }

    // Calculate the suffix product of the relevant output dimension sizes.
    SmallVector<OpFoldResult> suffixProduct =
        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);

    // Create affine expression variables for dimensions and symbols in the
    // newly constructed affine map.
    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
    bindDimsList<AffineExpr>(ctx, dims);
    bindSymbolsList<AffineExpr>(ctx, symbols);

    // Linearize the bound dimensions and symbols to construct the resulting
    // affine expression for this index.
    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);

    // Record the load index corresponding to each dimension in the
    // reassociation group. These are later supplied as operands to the affine
    // map used for calculating the relevant index post op folding.
    SmallVector<OpFoldResult> dynamicIndices(groupSize);
    for (int64_t i = 0; i < groupSize; i++)
      dynamicIndices[i] = indices[groups[i]];

    // Supply the suffix product results followed by the load op indices as
    // operands to the map.
    SmallVector<OpFoldResult> mapOperands;
    llvm::append_range(mapOperands, suffixProduct);
    llvm::append_range(mapOperands, dynamicIndices);

    // Creating a maximally folded and composed affine.apply composes better
    // with other transformations without interleaving canonicalization
    // passes.
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc,
        AffineMap::get(/*numDims=*/groupSize,
                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
        mapOperands);

    // Push the index value corresponding to this reassociation group in the
    // folded op.
    sourceIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return success();
}

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t. the source memref of
/// the collapse_shape op. For example
///
/// %0 = ... : memref<2x6x42xf32>
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
///    : memref<2x6x42xf32> into memref<12x42xf32>
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
///          memref<2x6x42xf32>
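///
/// The delinearization below uses floordiv/mod by the suffix products of the
/// group's source sizes (all of which, except the most major dimension of
/// each group, must be static); for the group [0, 1] above this yields
/// %i1 floordiv 6 and %i1 mod 6.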
static LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                  memref::CollapseShapeOp collapseShapeOp,
                                  ValueRange indices,
                                  SmallVectorImpl<Value> &sourceIndices) {
  int64_t cnt = 0;
  SmallVector<Value> tmp(indices.size());
  SmallVector<OpFoldResult> dynamicIndices;
  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
    assert(!groups.empty() && "association indices groups cannot be empty");
    dynamicIndices.push_back(indices[cnt++]);
    int64_t groupSize = groups.size();

    // Calculate the suffix product of all collapse op source dimension sizes
    // except the most major one of each group.
    // We allow the most major source dimension to be dynamic but enforce all
    // others to be known statically.
    SmallVector<int64_t> sizes(groupSize, 1);
    for (int64_t i = 1; i < groupSize; ++i) {
      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
      if (sizes[i] == ShapedType::kDynamic)
        return failure();
    }
    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);

    // Derive the index values along all dimensions of the source
    // corresponding to the index w.r.t. the collapse_shape op's output.
    auto d0 = rewriter.getAffineDimExpr(0);
    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);

    // Construct the AffineApplyOp for each delinearizingExpr.
    for (int64_t i = 0; i < groupSize; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc,
          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
                         delinearizingExprs[i]),
          dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
    dynamicIndices.clear();
  }
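  // If the reassociation is empty (a unit-extent memref collapsed to rank 0),
  // the loop above did not execute; every source index is then the constant 0.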
  if (collapseShapeOp.getReassociationIndices().empty()) {
    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
    int64_t srcRank =
        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
    for (int64_t i = 0; i < srcRank; i++) {
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, zeroAffineMap, dynamicIndices);
      sourceIndices.push_back(
          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
    }
  }
  return success();
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getSource();
}

static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
  return op.getSrcMemref();
}

static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) {
  return op.getDstMemref();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
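/// For example (illustrative, with the added offsets materialized via
/// affine.apply):
///
///   %0 = memref.subview %arg[%o0, %o1] [4, 4] [1, 1]
///        : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: ?>>
///   %1 = memref.load %0[%i0, %i1]
///        : memref<4x4xf32, strided<[16, 1], offset: ?>>
///
/// is rewritten into a load from the original memref:
///
///   %1 = memref.load %arg[%o0 + %i0, %o1 + %i1] : memref<16x16xf32>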
template <typename OpTy>
class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges expand_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Merges collapse_shape operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};

/// Folds subview(subview(x)) to a single subview(x).
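/// For example (illustrative, unit strides only, as required below):
///
///   %0 = memref.subview %arg[%a, %b] [8, 8] [1, 1] : memref<16x16xf32> to ...
///   %1 = memref.subview %0[%c, %d] [4, 4] [1, 1] : ... to ...
///
/// folds to a single subview of %arg with offsets [%a + %c, %b + %d], the
/// sizes [4, 4] of the second subview, and the strides of the first subview.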
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    // TODO: relax unit stride assumption.
    if (!subView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(subView, "requires unit strides");
    }
    if (!srcSubView.hasUnitStride()) {
      return rewriter.notifyMatchFailure(srcSubView, "requires unit strides");
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    llvm::SmallBitVector srcDroppedDims = srcSubView.getDroppedDims();
    affine::resolveSizesIntoOpWithSizes(srcSubView.getMixedSizes(),
                                        subView.getMixedSizes(), srcDroppedDims,
                                        resolvedSizes);

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, subView.getLoc(), srcSubView.getMixedOffsets(),
        srcSubView.getMixedStrides(), srcDroppedDims, subView.getMixedOffsets(),
        resolvedOffsets);

    // Replace the original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        srcSubView.getMixedStrides());

    return success();
  }
};

/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
/// folds subviews on both the src and dst memrefs of the copy.
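/// The src and dst indices are resolved into the subview sources with
/// resolveIndicesIntoOpWithOffsetsAndStrides, in the same way as for the
/// load/store patterns above.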
class NVGPUAsyncCopyOpSubViewOpFolder final
    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
public:
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
} // namespace

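/// For affine load/store ops, applies the op's affine map to its index
/// operands to compute the actual memref indices, materializing one composed
/// affine.apply per map result.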
static SmallVector<Value>
calculateExpandedAccessIndices(AffineMap affineMap,
                               const SmallVector<Value> &indices, Location loc,
                               PatternRewriter &rewriter) {
  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
  SmallVector<Value> expandedIndices;
  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
    expandedIndices.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
  }
  return expandedIndices;
}

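/// Preconditions for folding a subview into a vector transfer op: the
/// transfer must not have out-of-bounds dimensions and the subview must have
/// unit strides.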
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, readOp, subviewOp);
}

static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl(rewriter, writeOp, subviewOp);
}

template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
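      // For rank-reducing subviews, the transfer's permutation map is expanded
      // back to the rank of the original source, projecting out the dropped
      // dimensions.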
      .Case([&](vector::TransferReadOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](gpu::SubgroupMmaLoadMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(loadOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(loadOp.getIndices().begin(),
                             loadOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineLoadOp =
          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
    AffineMap affineMap = affineLoadOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, loadOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](affine::AffineLoadOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
      })
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();

  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
      sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Case([&](gpu::SubgroupMmaStoreMatrixOp op) {
        rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>(
            op, op.getSrc(), subViewOp.getSource(), sourceIndices,
            op.getLeadDimension(), op.getTransposeAttr());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();

  if (!expandShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesExpandShape(
          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), expandShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto collapseShapeOp = getMemRefOperand(storeOp)
                             .template getDefiningOp<memref::CollapseShapeOp>();

  if (!collapseShapeOp)
    return failure();

  SmallVector<Value> indices(storeOp.getIndices().begin(),
                             storeOp.getIndices().end());
  // For affine ops, we need to apply the map to the operands to get the
  // "actual" indices.
  if (auto affineStoreOp =
          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
    AffineMap affineMap = affineStoreOp.getAffineMap();
    auto expandedIndices = calculateExpandedAccessIndices(
        affineMap, indices, storeOp.getLoc(), rewriter);
    indices.assign(expandedIndices.begin(), expandedIndices.end());
  }
  SmallVector<Value> sourceIndices;
  if (failed(resolveSourceIndicesCollapseShape(
          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
    return failure();
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](affine::AffineStoreOp op) {
        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices);
      })
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::StoreOp op) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
            sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
  return success();
}

LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {

  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp =
      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp =
      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();

  if (!(srcSubViewOp || dstSubViewOp))
    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
                                               "source or destination");

  // If the source is a subview, we need to resolve the indices.
  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
                                copyOp.getSrcIndices().end());
  SmallVector<Value> foldedSrcIndices(srcindices);

  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        srcindices, foldedSrcIndices);
  }

  // If the destination is a subview, we need to resolve the indices.
  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
                                copyOp.getDstIndices().end());
  SmallVector<Value> foldedDstIndices(dstindices);

  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        dstindices, foldedDstIndices);
  }

  // Replace the copy op with a new copy op that operates on the sources of
  // the subview ops.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
      foldedDstIndices,
      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
      copyOp.getBypassL1Attr());

  return success();
}

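/// Appends the patterns above for folding memref alias ops (subview,
/// expand_shape, collapse_shape) into their consumer load/store ops.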
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
               LoadOpOfSubViewOpFolder<memref::LoadOp>,
               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
               StoreOpOfSubViewOpFolder<memref::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
  return std::make_unique<FoldMemRefAliasOpsPass>();
}