//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/Debug.h"
#include <cstdint>

#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")

namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Determine if the last `n` entries of `reassocs` are trivial - that is,
/// check if they each contain exactly one dimension to collapse/expand into.
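/// For example, the reassociation [[0, 1], [2], [3]] has a trivial suffix of
/// length 1 or 2 (the groups [2] and [3] are singletons) but not of length 3,
/// since [0, 1] collapses/expands two dimensions.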
static bool
hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs,
                              int64_t n) {
  if (n <= 0)
    return true;
  if (n > static_cast<int64_t>(reassocs.size()))
    return false;
  return llvm::all_of(
      reassocs.take_back(n),
      [&](const ReassociationIndices &indices) { return indices.size() == 1; });
}

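/// Determine if the last `n` static strides of `subview` are all 1. For
/// example, a subview with static strides [2, 1, 1] has trailing unit strides
/// for n = 1 or n = 2, but not for n = 3.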
static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
  if (n <= 0)
    return true;
  ArrayRef<int64_t> strides = subview.getStaticStrides();
  if (n > static_cast<int64_t>(strides.size()))
    return false;
  return llvm::all_of(strides.take_back(n), [](int64_t s) { return s == 1; });
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Folds subview(subview(x)) to a single subview(x).
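/// For example (an illustrative sketch; shapes and offsets are made up):
///
///   %a = memref.subview %m[4, 8] [8, 8] [1, 1]
///       : memref<16x16xf32> to memref<8x8xf32, strided<[16, 1], offset: 72>>
///   %b = memref.subview %a[1, 2] [4, 4] [1, 1]
///       : memref<8x8xf32, strided<[16, 1], offset: 72>>
///         to memref<4x4xf32, strided<[16, 1], offset: 90>>
///
/// folds to a single subview with composed offsets:
///
///   %b = memref.subview %m[5, 10] [4, 4] [1, 1]
///       : memref<16x16xf32> to memref<4x4xf32, strided<[16, 1], offset: 90>>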
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!srcSubView)
      return failure();

    SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
    if (failed(affine::mergeOffsetsSizesAndStrides(
            rewriter, subView.getLoc(), srcSubView, subView,
            srcSubView.getDroppedDims(), newOffsets, newSizes, newStrides)))
      return failure();

    // Replace original op.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), srcSubView.getSource(), newOffsets,
        newSizes, newStrides);
    return success();
  }
};

/// Merges subview operations with load/store-like operations unless such a
/// merger would cause the strides between dimensions accessed by that
/// operation to change.
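/// For example (an illustrative sketch), with %ik = %i + %k and %jl = %j + %l
/// materialized via affine.apply:
///
///   %sv = memref.subview %m[%i, %j] [4, 8] [1, 1]
///       : memref<16x16xf32> to memref<4x8xf32, strided<[16, 1], offset: ?>>
///   %v = vector.load %sv[%k, %l]
///       : memref<4x8xf32, strided<[16, 1], offset: ?>>, vector<2x8xf32>
///
/// folds to:
///
///   %v = vector.load %m[%ik, %jl] : memref<16x16xf32>, vector<2x8xf32>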
struct AccessOpOfSubViewOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges a memref.expand_shape operation with an operation that accesses a
/// memref by index, unless that operation accesses more than one dimension of
/// memory and any dimension other than the outermost dimension accessed this
/// way would be merged. This prevents issues from arising with, say, a
/// vector.load of a 4x2 vector having the two trailing dimensions of the
/// access merged.
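/// For example (an illustrative sketch), folding
///
///   %e = memref.expand_shape %m [[0], [1, 2]] output_shape [8, 4, 2]
///       : memref<8x8xf32> into memref<8x4x2xf32>
///   %v = vector.load %e[%i, %j, %k] : memref<8x4x2xf32>, vector<4x2xf32>
///
/// is rejected, since the vector's trailing 4x2 dimensions would be merged
/// back into a single source dimension, while a vector<2xf32> load from %e
/// would fold to a load from %m at [%i, %j * 2 + %k].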
struct AccessOpOfExpandShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges an operation that accesses a memref by index with a
/// memref.collapse_shape, unless this would break apart a dimension other than
/// the outermost one that the operation accesses. This prevents, for example,
/// transforming a load of a 3x8 vector from a 6x8 memref into a load
/// from a 6x4x2 memref (as this would require special handling and could lead
/// to invalid IR if that higher-dimensional memref comes from a subview) but
/// does permit turning a load of a length-8 vector from a 6x8 memref into a
/// load from a 3x2x8 one.
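/// In IR terms, the permitted case looks like (an illustrative sketch):
///
///   %c = memref.collapse_shape %m [[0, 1], [2]]
///       : memref<3x2x8xf32> into memref<6x8xf32>
///   %v = vector.load %c[%i, %j] : memref<6x8xf32>, vector<8xf32>
///
/// which folds to a load from %m at [%i floordiv 2, %i mod 2, %j], with the
/// delinearized indices materialized via affine.apply.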
struct AccessOpOfCollapseShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.subview operations present on the source or destination
/// operands of indexed memory copy operations (DMA operations) into those
/// operations. This is performed unconditionally, since folding in a subview
/// cannot change the starting position of the copy, which is what the
/// memref/index pair represents in DMA operations.
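/// For example (an illustrative sketch), a DMA whose source is
///
///   %sv = memref.subview %src[2, 4] [8, 8] [1, 1]
///       : memref<16x32xf32> to memref<8x8xf32, strided<[32, 1], offset: 68>>
///
/// read at indices [%i, %j] becomes a DMA reading %src directly at
/// [%i + 2, %j + 4]: the same starting element, so the copy is unchanged.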
struct IndexedMemCopyOpOfSubViewOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.expand_shape operations that are present on the source or
/// destination of an indexed memory copy/DMA into the memref/index arguments
/// of that DMA. As with subviews, this can be done unconditionally.
struct IndexedMemCopyOpOfExpandShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.collapse_shape operations that are present on the source or
/// destination of an indexed memory copy/DMA into the memref/index arguments
/// of that DMA. As with subviews, this can be done unconditionally.
struct IndexedMemCopyOpOfCollapseShapeOpFolder final
    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.subview ops on the base argument of vector transfer
/// operations into the base and indices of that transfer if:
/// - the subview has unit strides on the transfer dimensions, and
/// - all the transfer dimensions are in bounds.
/// The pattern updates the transfer's permutation map to account for dropped
/// dimensions in rank-reducing subviews.
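/// For example (an illustrative sketch of the rank-reducing case):
///
///   %sv = memref.subview %m[%i, 0, %j] [4, 1, 8] [1, 1, 1]
///       : memref<16x4x16xf32> to memref<4x8xf32, strided<[64, 1], offset: ?>>
///   %v = vector.transfer_read %sv[%k, %l], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32, strided<[64, 1], offset: ?>>, vector<4x8xf32>
///
/// folds to a transfer_read of %m at [%i + %k, 0, %j + %l] with permutation
/// map (d0, d1, d2) -> (d0, d2), which skips the dropped unit dimension.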
struct TransferOpOfSubViewOpFolder final
    : OpInterfaceRewritePattern<VectorTransferOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(VectorTransferOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.expand_shape ops that create the base of a vector transfer
/// operation into the base and indices of that transfer. Does not act when a
/// dimension is potentially out of bounds, when one of the transfer dimensions
/// would need to be strided after folding away the expansion, or when it would
/// merge two dimensions that are both transfer dimensions.
/// TODO: become more sophisticated about length-1 dimensions that are the
/// result of an expansion becoming broadcasts.
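/// For example (an illustrative sketch), a transfer_read of vector<8xf32> at
/// [%i, %j, %k] from
///
///   %e = memref.expand_shape %m [[0, 1], [2]] output_shape [4, 4, 8]
///       : memref<16x8xf32> into memref<4x4x8xf32>
///
/// folds to a transfer_read of %m at [%i * 4 + %j, %k], since the transfer
/// dimension (the trailing 8) is the last entry of its reassociation group
/// and so stays contiguous.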
struct TransferOpOfExpandShapeOpFolder final
    : OpInterfaceRewritePattern<VectorTransferOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(VectorTransferOpInterface op,
                                PatternRewriter &rewriter) const override;
};

/// Merges memref.collapse_shape ops that create the base of a vector transfer
/// operation into the base and indices of that transfer. Does not act when the
/// permutation map is not a minor identity, when a dimension could perform
/// out-of-bounds accesses, or when folding would break apart a transfer
/// dimension.
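/// For example (an illustrative sketch), a transfer_read of vector<8xf32> at
/// [%i, %j] from
///
///   %c = memref.collapse_shape %m [[0, 1], [2]]
///       : memref<4x4x8xf32> into memref<16x8xf32>
///
/// folds to a transfer_read of %m at [%i floordiv 4, %i mod 4, %j], while a
/// read of vector<2x8xf32> is rejected because its leading transfer dimension
/// would be split across the collapsed pair.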
struct TransferOpOfCollapseShapeOpFolder final
    : OpInterfaceRewritePattern<VectorTransferOpInterface> {
  using Base::Base;

  LogicalResult matchAndRewrite(VectorTransferOpInterface op,
                                PatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
                                           PatternRewriter &rewriter) const {
  auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
  if (!subview)
    return rewriter.notifyMatchFailure(op, "not accessing a subview");

  SmallVector<int64_t> accessedShape = op.getAccessedShape();
  // Note the subtle difference between accessedShape = {1} and accessedShape =
  // {} here. The former prevents us from folding in a subview that doesn't
  // have a unit stride on the final dimension, while the latter does not
  // (since it indicates a scalar access).
  int64_t accessedDims = accessedShape.size();
  if (!hasTrailingUnitStrides(subview, accessedDims))
    return rewriter.notifyMatchFailure(
        op, "non-unit stride on accessed dimensions");

  llvm::SmallBitVector droppedDims = subview.getDroppedDims();
  int64_t sourceRank = subview.getSourceType().getRank();

  // Ignore the outermost accessed dimension - we only care about dimensions
  // dropped between the remaining accessed dimensions, as those could break
  // the accessing op's semantics.
  int64_t secondAccessedDim = sourceRank - (accessedDims - 1);
  if (secondAccessedDim < sourceRank) {
    for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) {
      if (droppedDims.test(d))
        return rewriter.notifyMatchFailure(
            op, "reintroducing dropped dimension " + Twine(d) +
                    " would break access op semantics");
    }
  }

  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, op.getLoc(), subview.getMixedOffsets(),
      subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);

  std::optional<SmallVector<Value>> newValues =
      op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}

LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  auto expand = op.getAccessedMemref().getDefiningOp<memref::ExpandShapeOp>();
  if (!expand)
    return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");

  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
  ArrayRef<int64_t> accessedShape = rawAccessedShape;
  if (expand.getSrcType().getRank() <
      static_cast<int64_t>(accessedShape.size()))
    return rewriter.notifyMatchFailure(
        op, "expand_shape source rank is too small for the accessed shape");

  // Cut off the leading dimension, since we don't care about modifying its
  // strides.
  if (!accessedShape.empty())
    accessedShape = accessedShape.drop_front();

  SmallVector<ReassociationIndices, 4> reassocs =
      expand.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
    return rewriter.notifyMatchFailure(
        op,
        "expand_shape folding would merge semantically important dimensions");

  SmallVector<Value> sourceIndices;
  memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
                                          op.getIndices(), sourceIndices,
                                          op.hasInboundsIndices());

  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
      rewriter, expand.getViewSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}

LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  auto collapse =
      op.getAccessedMemref().getDefiningOp<memref::CollapseShapeOp>();
  if (!collapse)
    return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");

  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
  ArrayRef<int64_t> accessedShape = rawAccessedShape;
  if (collapse.getSrcType().getRank() <
      static_cast<int64_t>(accessedShape.size()))
    return rewriter.notifyMatchFailure(
        op, "collapse_shape source rank is too small for the accessed shape");

  // Cut off the leading dimension, since we don't care about its strides being
  // modified and we know that the dimensions within its reassociation group,
  // if it's non-trivial, must be contiguous.
  if (!accessedShape.empty())
    accessedShape = accessedShape.drop_front();

  SmallVector<ReassociationIndices, 4> reassocs =
      collapse.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
    return rewriter.notifyMatchFailure(op, "collapse_shape folding would split "
                                           "semantically important dimensions");

  SmallVector<Value> sourceIndices;
  memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
                                            op.getIndices(), sourceIndices,
                                            op.hasInboundsIndices());

  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
      rewriter, collapse.getViewSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}

LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>();
  auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>();
  if (!srcSubview && !dstSubview)
    return rewriter.notifyMatchFailure(
        op, "no subviews found on indexed copy inputs");

  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcSubview) {
    newSrc = srcSubview.getSource();
    newSrcIndices.clear();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
        srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
        op.getSrcIndices(), newSrcIndices);
  }
  if (dstSubview) {
    newDst = dstSubview.getSource();
    newDstIndices.clear();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
        dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
        op.getDstIndices(), newDstIndices);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}

LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>();
  auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>();
  if (!srcExpand && !dstExpand)
    return rewriter.notifyMatchFailure(
        op, "no expand_shapes found on indexed copy inputs");

  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcExpand) {
    newSrc = srcExpand.getViewSource();
    newSrcIndices.clear();
    memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, srcExpand,
                                            op.getSrcIndices(), newSrcIndices,
                                            /*startsInbounds=*/true);
  }
  if (dstExpand) {
    newDst = dstExpand.getViewSource();
    newDstIndices.clear();
    memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, dstExpand,
                                            op.getDstIndices(), newDstIndices,
                                            /*startsInbounds=*/true);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}

LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>();
  auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>();
  if (!srcCollapse && !dstCollapse)
    return rewriter.notifyMatchFailure(
        op, "no collapse_shapes found on indexed copy inputs");

  Value newSrc = op.getSrc();
  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value newDst = op.getDst();
  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
  if (srcCollapse) {
    newSrc = srcCollapse.getViewSource();
    newSrcIndices.clear();
    memref::resolveSourceIndicesCollapseShape(
        op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices,
        /*startsInbounds=*/true);
  }
  if (dstCollapse) {
    newDst = dstCollapse.getViewSource();
    newDstIndices.clear();
    memref::resolveSourceIndicesCollapseShape(
        op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices,
        /*startsInbounds=*/true);
  }
  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
                          newDstIndices);
  return success();
}

LogicalResult
TransferOpOfSubViewOpFolder::matchAndRewrite(VectorTransferOpInterface op,
                                             PatternRewriter &rewriter) const {
  auto subview = op.getBase().getDefiningOp<memref::SubViewOp>();
  if (!subview)
    return rewriter.notifyMatchFailure(op, "not accessing a subview");

  AffineMap perm = op.getPermutationMap();
  // Note: no identity permutation check here, since subview folding can handle
  // complex permutations because it doesn't merge or split any individual
  // dimension.
  if (op.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(op, "out of bounds dimension");
  VectorType vecTy = op.getVectorType();
  // Conservatively require unit strides on the last N dimensions, where N is
  // the vector rank, so that folding cannot make a transfer dimension strided.
  if (!hasTrailingUnitStrides(subview, vecTy.getRank()))
    return rewriter.notifyMatchFailure(subview, "non-unit stride within last " +
                                                    Twine(vecTy.getRank()) +
                                                    " dimensions");

  AffineMap newPerm = expandDimsToRank(perm, subview.getSourceType().getRank(),
                                       subview.getDroppedDims());

  if (failed(op.mayUpdateStartingPosition(subview.getSourceType(), newPerm)))
    return rewriter.notifyMatchFailure(subview,
                                       "failed op-specific preconditions");

  SmallVector<Value> newIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, op.getLoc(), subview.getMixedOffsets(),
      subview.getMixedStrides(), subview.getDroppedDims(), op.getIndices(),
      newIndices);
  op.updateStartingPosition(rewriter, subview.getSource(), newIndices,
                            AffineMapAttr::get(newPerm));
  return success();
}

LogicalResult TransferOpOfExpandShapeOpFolder::matchAndRewrite(
    VectorTransferOpInterface op, PatternRewriter &rewriter) const {
  auto expand = op.getBase().getDefiningOp<memref::ExpandShapeOp>();
  if (!expand)
    return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");

  if (op.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(op, "out of bounds dimension");

  int64_t srcRank = expand.getSrc().getType().getRank();
  int64_t vecRank = op.getVectorType().getRank();
  if (srcRank < vecRank)
    return rewriter.notifyMatchFailure(op,
                                       "source rank is less than vector rank");

  llvm::SmallDenseMap<int64_t, int64_t, 8> unstridedResDimToSrcDim;
  for (auto [srcIdx, reassoc] :
       llvm::enumerate(expand.getReassociationIndices())) {
    unstridedResDimToSrcDim.insert({reassoc.back(), srcIdx});
  }

  // If every dimension of the expanded shape that appears in the permutation
  // map is the final entry of its reassociation group (meaning that folding
  // away the expansion won't force us to stride the index), we can fold in
  // the expansion. (This doesn't currently account for a length-X dimension
  // being expanded to X-by-1, but it could in the future.)
  AffineMap permMap = op.getPermutationMap();
  SmallVector<AffineExpr> newPermMapResults;
  newPermMapResults.reserve(permMap.getNumResults());
  for (AffineExpr permRes : permMap.getResults()) {
    auto resDim = dyn_cast<AffineDimExpr>(permRes);
    if (!resDim)
      return rewriter.notifyMatchFailure(
          op, "has non-dim entry in permutation map");
    auto dimInSrc = unstridedResDimToSrcDim.find(resDim.getPosition());
    if (dimInSrc == unstridedResDimToSrcDim.end())
      return rewriter.notifyMatchFailure(op,
                                         "permutation map result would be made "
                                         "strided by expand_shape folding");
    newPermMapResults.push_back(rewriter.getAffineDimExpr(dimInSrc->second));
  }

  auto newPerm =
      AffineMap::get(srcRank, 0, newPermMapResults, op.getContext());

  if (failed(op.mayUpdateStartingPosition(expand.getSrc().getType(), newPerm)))
    return rewriter.notifyMatchFailure(op, "failed op-specific preconditions");

  SmallVector<Value> newIndices;
  // If the transfer is unmasked, the in-bounds requirement guarantees that
  // the starting position is in bounds, so we may use a disjoint
  // linearization.
  memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
                                          op.getIndices(), newIndices,
                                          /*startsInbounds=*/!op.getMask());

  op.updateStartingPosition(rewriter, expand.getViewSource(), newIndices,
                            AffineMapAttr::get(newPerm));
  return success();
}

LogicalResult TransferOpOfCollapseShapeOpFolder::matchAndRewrite(
    VectorTransferOpInterface op, PatternRewriter &rewriter) const {
  auto collapse = op.getBase().getDefiningOp<memref::CollapseShapeOp>();
  if (!collapse)
    return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");

  if (!op.getPermutationMap().isMinorIdentity())
    return rewriter.notifyMatchFailure(op,
                                       "non-minor identity permutation map");

  if (op.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(op, "out of bounds dimension");

  int64_t srcRank = collapse.getSrc().getType().getRank();
  int64_t vecRank = op.getVectorType().getRank();
  if (srcRank < vecRank)
    return rewriter.notifyMatchFailure(op,
                                       "source rank is less than vector rank");

  // Note: no `- 1` on the rank here. While we could treat the collapse of
  // [1, 1, N] into N as a special case, that is left as future work for those
  // who need such a pattern.
  SmallVector<ReassociationIndices> reassocs =
      collapse.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(reassocs, vecRank))
    return rewriter.notifyMatchFailure(
        op, "collapse_shape folding would split a transfer dimension");

  AffineMap newPerm =
      AffineMap::getMinorIdentityMap(srcRank, vecRank, op.getContext());
  if (failed(
          op.mayUpdateStartingPosition(collapse.getSrc().getType(), newPerm)))
    return rewriter.notifyMatchFailure(op, "failed op-specific preconditions");

  SmallVector<Value> newIndices;
  memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
                                            op.getIndices(), newIndices,
                                            /*startsInbounds=*/!op.getMask());

  op.updateStartingPosition(rewriter, collapse.getViewSource(), newIndices,
                            AffineMapAttr::get(newPerm));
  return success();
}

void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  patterns
      .add<AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
           AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
           IndexedMemCopyOpOfExpandShapeOpFolder,
           IndexedMemCopyOpOfCollapseShapeOpFolder, TransferOpOfSubViewOpFolder,
           TransferOpOfExpandShapeOpFolder, TransferOpOfCollapseShapeOpFolder,
           SubViewOfSubViewFolder>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldMemRefAliasOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldMemRefAliasOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}