MLIR 23.0.0git
FoldMemRefAliasOps.cpp
Go to the documentation of this file.
1//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass folds loading/storing from/to subview ops into
10// loading/storing from/to the original memref.
11//
12//===----------------------------------------------------------------------===//
13
22#include "mlir/IR/AffineExpr.h"
23#include "mlir/IR/AffineMap.h"
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallBitVector.h"
28#include "llvm/Support/Debug.h"
29#include <cstdint>
30
31#define DEBUG_TYPE "fold-memref-alias-ops"
32#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
33
34namespace mlir {
35namespace memref {
36#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
37#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
38} // namespace memref
39} // namespace mlir
40
41using namespace mlir;
42
43//===----------------------------------------------------------------------===//
44// Utility functions
45//===----------------------------------------------------------------------===//
46
47/// Deterimine if the last N indices of `reassocitaion` are trivial - that is,
48/// check if they all contain exactly one dimension to collape/expand into.
49static bool
51 int64_t n) {
52 if (n <= 0)
53 return true;
54 if (n > static_cast<int64_t>(reassocs.size()))
55 return false;
56 return llvm::all_of(
57 reassocs.take_back(n),
58 [&](const ReassociationIndices &indices) { return indices.size() == 1; });
59}
60
61static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
62 if (n <= 0)
63 return true;
64 ArrayRef<int64_t> strides = subview.getStaticStrides();
65 if (n > static_cast<int64_t>(strides.size()))
66 return false;
67 return llvm::all_of(strides.take_back(n), [](int64_t s) { return s == 1; });
68}
69
70//===----------------------------------------------------------------------===//
71// Patterns
72//===----------------------------------------------------------------------===//
73
74namespace {
75/// Folds subview(subview(x)) to a single subview(x).
76class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
77public:
78 using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
79
80 LogicalResult matchAndRewrite(memref::SubViewOp subView,
81 PatternRewriter &rewriter) const override {
82 auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
83 if (!srcSubView)
84 return failure();
85
86 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
87 if (failed(affine::mergeOffsetsSizesAndStrides(
88 rewriter, subView.getLoc(), srcSubView, subView,
89 srcSubView.getDroppedDims(), newOffsets, newSizes, newStrides)))
90 return failure();
91
92 // Replace original op.
93 rewriter.replaceOpWithNewOp<memref::SubViewOp>(
94 subView, subView.getType(), srcSubView.getSource(), newOffsets,
95 newSizes, newStrides);
96 return success();
97 }
98};
99
100/// Merges subview operations with load/store like operations unless such a
101/// merger would cause the strides between dimensions accessed by that operaton
102/// to change.
103struct AccessOpOfSubViewOpFolder final
104 : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
105 using Base::Base;
106
107 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
108 PatternRewriter &rewriter) const override;
109};
110
111/// Merge a memref.expand_shape operation with an operation that accesses a
112/// memref by index unless that operation accesss more than one dimension of
113/// memory and any dimension other than the outermost dimension accessed this
114/// way would be merged. This prevents issuses from arising with, say, a
115/// vector.load of a 4x2 vector having the two trailing dimensions of the access
116/// get merged.
117struct AccessOpOfExpandShapeOpFolder final
118 : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
119 using Base::Base;
120
121 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
122 PatternRewriter &rewriter) const override;
123};
124
125/// Merges an operation that accesses a memref by index with a
126/// memref.collapse_shape, unless this would break apart a dimension other than
127/// the outermost one that an operation accesses. This prevents, for example,
128/// transforming a load of a 3x8 vector from a 6x8 memref into a load
129/// from a 3x4x2 memref (as this would require special handling and could lead
130/// to invalid IR if that higher-dimensional memref comes from a subview) but
131/// does permit turning a load of a length-8 vector from a 3x8 memref into a
132/// load from a 3x2x8 one.
133struct AccessOpOfCollapseShapeOpFolder final
134 : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
135 using Base::Base;
136
137 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
138 PatternRewriter &rewriter) const override;
139};
140
141/// Merges memref.subview operations present on the source or destination
142/// operands of indexed memory copy operations (DMA operations) into those
143/// operations. This is perfromed unconditionally, since folding in a subview
144/// cannot change the starting position of the copy, which is what the
145/// memref/index pair represent in DMA operations.
146struct IndexedMemCopyOpOfSubViewOpFolder final
147 : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
148 using Base::Base;
149
150 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
151 PatternRewriter &rewriter) const override;
152};
153
154/// Merges memref.expand_shape operations that are present on the source or
155/// destination of an indexed memory copy/DMA into the memref/index arguments of
156/// that DMA. As with subviews, this can be done unconditionally.
157struct IndexedMemCopyOpOfExpandShapeOpFolder final
158 : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
159 using Base::Base;
160
161 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
162 PatternRewriter &rewriter) const override;
163};
164
165/// Merges memref.collapse_shape operations that are present on the source or
166/// destination of an indexed memory copy/DMA into the memref/index arguments of
167/// that DMA. As with subviews, this can be done unconditionally.
168struct IndexedMemCopyOpOfCollapseShapeOpFolder final
169 : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
170 using Base::Base;
171
172 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
173 PatternRewriter &rewriter) const override;
174};
175
176/// Merges memref.subview ops on the base argument to vector transfer operations
177/// into the base and indices of that transfer if:
178/// - The subview has unit strides on transfer dimensions
179/// - All the transfer dimensions are in-bounds
180/// This will correctly update said permutation map to account for dropped
181/// dimensions in rank-reducing subviews.
182struct TransferOpOfSubViewOpFolder final
183 : OpInterfaceRewritePattern<VectorTransferOpInterface> {
184 using Base::Base;
185
186 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
187 PatternRewriter &rewriter) const override;
188};
189
190/// Merges memref.expand_shape ops that create the base of a vector transfer
191/// operation into the base and indices of that transfer. Does not act when the
192/// a dimension is potentially out of bounds, if one of the transfer dimensions
193/// would need to be strided because of the collapse, or if it would merge two
194/// dimensions that are both transfer dimensions.
195/// TODO: become more sophisticated about length-1 dimensions that are the
196/// result of an expansion becoming broadcasts.
197struct TransferOpOfExpandShapeOpFolder final
198 : OpInterfaceRewritePattern<VectorTransferOpInterface> {
199 using Base::Base;
200
201 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
202 PatternRewriter &rewriter) const override;
203};
204
205/// Merges memref.collapse_shape ops that create the base of a vector transfer
206/// operation into the base and indices of that transfer. Does not act when the
207/// permutation map is not trivial, a dimension could be performing out of
208/// bounds reads, or if it would break apart a transfer dimension.
209struct TransferOpOfCollapseShapeOpFolder final
210 : OpInterfaceRewritePattern<VectorTransferOpInterface> {
211 using Base::Base;
212
213 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
214 PatternRewriter &rewriter) const override;
215};
216} // namespace
217
218LogicalResult
219AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
220 PatternRewriter &rewriter) const {
221 TypedValue<MemRefType> accessedMemref = op.getAccessedMemref();
222 if (!accessedMemref)
223 return rewriter.notifyMatchFailure(op, "not accessing a memref");
224
225 auto subview = accessedMemref.getDefiningOp<memref::SubViewOp>();
226 if (!subview)
227 return rewriter.notifyMatchFailure(op, "not accessing a subview");
228
229 SmallVector<int64_t> accessedShape = op.getAccessedShape();
230 // Note the subtle difference between accessedShape = {1} and accessedShape =
231 // {} here. The former prevents us from folding in a subview that doesn't
232 // have a unit stride on the final dimension, while the latter does not (since
233 // it indexes scalar accesses).
234 int64_t accessedDims = accessedShape.size();
235 if (!hasTrailingUnitStrides(subview, accessedDims))
236 return rewriter.notifyMatchFailure(
237 op, "non-unit stride on accessed dimensions");
238
239 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
240 int64_t sourceRank = subview.getSourceType().getRank();
241
242 // Ignore outermost access dimension - we only care about dropped dimensions
243 // between the accessed op's results, as those could break the accessing op's
244 // semantics.
245 int64_t secondAccessedDim = sourceRank - (accessedDims - 1);
246 if (secondAccessedDim < sourceRank) {
247 for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) {
248 if (droppedDims.test(d))
249 return rewriter.notifyMatchFailure(
250 op, "reintroducing dropped dimension " + Twine(d) +
251 " would break access op semantics");
252 }
253 }
254
255 SmallVector<Value> sourceIndices;
256 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
257 rewriter, op.getLoc(), subview.getMixedOffsets(),
258 subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);
259
260 std::optional<SmallVector<Value>> newValues =
261 op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
262 if (newValues)
263 rewriter.replaceOp(op, *newValues);
264 return success();
265}
266
267LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
268 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
269 TypedValue<MemRefType> accessedMemref = op.getAccessedMemref();
270 if (!accessedMemref)
271 return rewriter.notifyMatchFailure(op, "not accessing a memref");
272
273 auto expand = accessedMemref.getDefiningOp<memref::ExpandShapeOp>();
274 if (!expand)
275 return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");
276
277 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
278 ArrayRef<int64_t> accessedShape = rawAccessedShape;
279 if (expand.getSrcType().getRank() <
280 static_cast<int64_t>(accessedShape.size()))
281 return rewriter.notifyMatchFailure(
282 op, "expand_shape source rank is too small for the accessed shape");
283
284 // Cut off the leading dimension, since we don't care about modifying its
285 // strides.
286 if (!accessedShape.empty())
287 accessedShape = accessedShape.drop_front();
288
289 SmallVector<ReassociationIndices, 4> reassocs =
290 expand.getReassociationIndices();
291 if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
292 return rewriter.notifyMatchFailure(
293 op,
294 "expand_shape folding would merge semantically important dimensions");
295
296 SmallVector<Value> sourceIndices;
297 memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
298 op.getIndices(), sourceIndices,
299 op.hasInboundsIndices());
300
301 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
302 rewriter, expand.getViewSource(), sourceIndices);
303 if (newValues)
304 rewriter.replaceOp(op, *newValues);
305 return success();
306}
307
308LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
309 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
310 TypedValue<MemRefType> accessedMemref = op.getAccessedMemref();
311 if (!accessedMemref)
312 return rewriter.notifyMatchFailure(op, "not accessing a memref");
313
314 auto collapse = accessedMemref.getDefiningOp<memref::CollapseShapeOp>();
315 if (!collapse)
316 return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");
317
318 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
319 ArrayRef<int64_t> accessedShape = rawAccessedShape;
320 if (collapse.getSrcType().getRank() <
321 static_cast<int64_t>(accessedShape.size()))
322 return rewriter.notifyMatchFailure(
323 op, "collapse_shape source rank is too small for the accessed shape");
324
325 // Cut off the leading dimension, since we don't care about its strides being
326 // modified and we know that the dimensions within its reassociation group, if
327 // it's non-trivial, must be contiguous.
328 if (!accessedShape.empty())
329 accessedShape = accessedShape.drop_front();
330
331 SmallVector<ReassociationIndices, 4> reassocs =
332 collapse.getReassociationIndices();
333 if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
334 return rewriter.notifyMatchFailure(op, "collapse_shape folding would merge "
335 "semantically important dimensions");
336
337 SmallVector<Value> sourceIndices;
338 memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
339 op.getIndices(), sourceIndices,
340 op.hasInboundsIndices());
341
342 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
343 rewriter, collapse.getViewSource(), sourceIndices);
344 if (newValues)
345 rewriter.replaceOp(op, *newValues);
346 return success();
347}
348
349LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
350 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
351 TypedValue<MemRefType> src = op.getSrc();
352 TypedValue<MemRefType> dst = op.getDst();
353 auto srcSubview = src ? src.getDefiningOp<memref::SubViewOp>() : nullptr;
354 auto dstSubview = dst ? dst.getDefiningOp<memref::SubViewOp>() : nullptr;
355 if (!srcSubview && !dstSubview)
356 return rewriter.notifyMatchFailure(
357 op, "no subviews found on indexed copy inputs");
358
359 Value newSrc = src;
360 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
361 Value newDst = dst;
362 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
363 if (srcSubview) {
364 newSrc = srcSubview.getSource();
365 newSrcIndices.clear();
366 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
367 rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
368 srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
369 op.getSrcIndices(), newSrcIndices);
370 }
371 if (dstSubview) {
372 newDst = dstSubview.getSource();
373 newDstIndices.clear();
374 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
375 rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
376 dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
377 op.getDstIndices(), newDstIndices);
378 }
379 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
380 newDstIndices);
381 return success();
382}
383
384LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
385 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
386 TypedValue<MemRefType> src = op.getSrc();
387 TypedValue<MemRefType> dst = op.getDst();
388 auto srcExpand = src ? src.getDefiningOp<memref::ExpandShapeOp>() : nullptr;
389 auto dstExpand = dst ? dst.getDefiningOp<memref::ExpandShapeOp>() : nullptr;
390 if (!srcExpand && !dstExpand)
391 return rewriter.notifyMatchFailure(
392 op, "no expand_shapes found on indexed copy inputs");
393
394 Value newSrc = src;
395 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
396 Value newDst = dst;
397 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
398 if (srcExpand) {
399 newSrc = srcExpand.getViewSource();
400 newSrcIndices.clear();
401 memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, srcExpand,
402 op.getSrcIndices(), newSrcIndices,
403 op.hasInboundsSrcIndices());
404 }
405 if (dstExpand) {
406 newDst = dstExpand.getViewSource();
407 newDstIndices.clear();
408 memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, dstExpand,
409 op.getDstIndices(), newDstIndices,
410 op.hasInboundsDstIndices());
411 }
412 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
413 newDstIndices);
414 return success();
415}
416
417LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
418 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
419 TypedValue<MemRefType> src = op.getSrc();
420 TypedValue<MemRefType> dst = op.getDst();
421 auto srcCollapse =
422 src ? src.getDefiningOp<memref::CollapseShapeOp>() : nullptr;
423 auto dstCollapse =
424 dst ? dst.getDefiningOp<memref::CollapseShapeOp>() : nullptr;
425 if (!srcCollapse && !dstCollapse)
426 return rewriter.notifyMatchFailure(
427 op, "no collapse_shapes found on indexed copy inputs");
428
429 Value newSrc = src;
430 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
431 Value newDst = dst;
432 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
433 if (srcCollapse) {
434 newSrc = srcCollapse.getViewSource();
435 newSrcIndices.clear();
437 op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices,
438 op.hasInboundsSrcIndices());
439 }
440 if (dstCollapse) {
441 newDst = dstCollapse.getViewSource();
442 newDstIndices.clear();
444 op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices,
445 op.hasInboundsDstIndices());
446 }
447 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
448 newDstIndices);
449 return success();
450}
451
452LogicalResult
453TransferOpOfSubViewOpFolder::matchAndRewrite(VectorTransferOpInterface op,
454 PatternRewriter &rewriter) const {
455 auto subview = op.getBase().getDefiningOp<memref::SubViewOp>();
456 if (!subview)
457 return rewriter.notifyMatchFailure(op, "not accessing a subview");
458
459 AffineMap perm = op.getPermutationMap();
460 // Note: no identity permutation check here, since subview folding can handle
461 // complex permutations because it doesn't merge or split any individual
462 // dimension.
463 if (op.hasOutOfBoundsDim())
464 return rewriter.notifyMatchFailure(op, "out of bounds dimension");
465 VectorType vecTy = op.getVectorType();
466 // Because we know the permutation map is a minor identity, we know that the
467 // last N dimensions must have unit stride, where N is the vector rank.
468 if (!hasTrailingUnitStrides(subview, vecTy.getRank()))
469 return rewriter.notifyMatchFailure(subview, "non-unit stride within last " +
470 Twine(vecTy.getRank()) +
471 " dimensions");
472
473 AffineMap newPerm = expandDimsToRank(perm, subview.getSourceType().getRank(),
474 subview.getDroppedDims());
475
476 if (failed(op.mayUpdateStartingPosition(subview.getSourceType(), newPerm)))
477 return rewriter.notifyMatchFailure(subview,
478 "failed op-specific preconditions");
479
480 SmallVector<Value> newIndices;
481 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
482 rewriter, op.getLoc(), subview.getMixedOffsets(),
483 subview.getMixedStrides(), subview.getDroppedDims(), op.getIndices(),
484 newIndices);
485 op.updateStartingPosition(rewriter, subview.getSource(), newIndices,
486 AffineMapAttr::get(newPerm));
487 return success();
488}
489
490LogicalResult TransferOpOfExpandShapeOpFolder::matchAndRewrite(
491 VectorTransferOpInterface op, PatternRewriter &rewriter) const {
492 auto expand = op.getBase().getDefiningOp<memref::ExpandShapeOp>();
493 if (!expand)
494 return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");
495
496 if (op.hasOutOfBoundsDim())
497 return rewriter.notifyMatchFailure(op, "out of bounds dimension");
498
499 int64_t srcRank = expand.getSrc().getType().getRank();
500 int64_t vecRank = op.getVectorType().getRank();
501 if (srcRank < vecRank)
502 return rewriter.notifyMatchFailure(op,
503 "source rank is less than vector rank");
504
505 llvm::SmallDenseMap<int64_t, int64_t, 8> unstridedResDimToSrcDim;
506 for (auto [srcIdx, reassoc] :
507 llvm::enumerate(expand.getReassociationIndices())) {
508 unstridedResDimToSrcDim.insert({reassoc.back(), srcIdx});
509 }
510 // If every dimension of the expanded shape that appears in the permutation
511 // map is also present in the final entry of the expansions (meaning that
512 // collapsing in more values won't cause us to need to stride the index), we
513 // can fold in the expansion. (This doesn't currently account for expanding
514 // length X to X by 1, but it could in the future).
515 AffineMap permMap = op.getPermutationMap();
516 SmallVector<AffineExpr> newPermMapResults;
517 newPermMapResults.reserve(permMap.getNumResults());
518 for (AffineExpr permRes : permMap.getResults()) {
519 auto resDim = dyn_cast<AffineDimExpr>(permRes);
520 if (!resDim)
521 return rewriter.notifyMatchFailure(
522 op, "has non-dim entry in permutation map");
523 auto dimInSrc = unstridedResDimToSrcDim.find(resDim.getPosition());
524 if (dimInSrc == unstridedResDimToSrcDim.end())
525 return rewriter.notifyMatchFailure(op,
526 "permutation map result would be made "
527 "strided by expand_shape folding");
528 newPermMapResults.push_back(rewriter.getAffineDimExpr(dimInSrc->second));
529 }
530
531 auto newPerm = AffineMap::get(srcRank, 0, newPermMapResults, op.getContext());
532
533 if (failed(op.mayUpdateStartingPosition(expand.getSrc().getType(), newPerm)))
534 return rewriter.notifyMatchFailure(op, "failed op-specific preconditions");
535
536 SmallVector<Value> newIndices;
537 // We can use a disjoint linearization if we aren't masking, because then all
538 // indicators show that the start position will be in bounds.
539 memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
540 op.getIndices(), newIndices,
541 /*startsInbounds=*/!op.getMask());
542
543 op.updateStartingPosition(rewriter, expand.getViewSource(), newIndices,
544 AffineMapAttr::get(newPerm));
545 return success();
546}
547
548LogicalResult TransferOpOfCollapseShapeOpFolder::matchAndRewrite(
549 VectorTransferOpInterface op, PatternRewriter &rewriter) const {
550 auto collapse = op.getBase().getDefiningOp<memref::CollapseShapeOp>();
551 if (!collapse)
552 return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");
553
554 if (!op.getPermutationMap().isMinorIdentity())
555 return rewriter.notifyMatchFailure(op,
556 "non-minor identity permutation map");
557
558 if (op.hasOutOfBoundsDim())
559 return rewriter.notifyMatchFailure(op, "out of bounds dimension");
560
561 int64_t srcRank = collapse.getSrc().getType().getRank();
562 int64_t vecRank = op.getVectorType().getRank();
563 if (srcRank < vecRank)
564 return rewriter.notifyMatchFailure(op,
565 "source rank is less than vector rank");
566
567 // Note: no - 1 on the rank here. While we could treat the collapse of [1, 1,
568 // N] into N as a special case, that is left as future work for those who need
569 // such a pattern.
570 SmallVector<ReassociationIndices> reassocs =
571 collapse.getReassociationIndices();
572 if (!hasTrivialReassociationSuffix(reassocs, vecRank))
573 return rewriter.notifyMatchFailure(
574 op, "collapse_shape folding would split a transfer dimension");
575
576 AffineMap newPerm =
577 AffineMap::getMinorIdentityMap(srcRank, vecRank, op.getContext());
578 if (failed(
579 op.mayUpdateStartingPosition(collapse.getSrc().getType(), newPerm)))
580 return rewriter.notifyMatchFailure(op, "failed op-specific preconditions");
581
582 SmallVector<Value> newIndices;
583 memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
584 op.getIndices(), newIndices,
585 /*startsInbounds=*/!op.getMask());
586
587 op.updateStartingPosition(rewriter, collapse.getViewSource(), newIndices,
588 AffineMapAttr::get(newPerm));
589 return success();
590}
591
593 patterns
594 .add<AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
595 AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
596 IndexedMemCopyOpOfExpandShapeOpFolder,
597 IndexedMemCopyOpOfCollapseShapeOpFolder, TransferOpOfSubViewOpFolder,
598 TransferOpOfExpandShapeOpFolder, TransferOpOfCollapseShapeOpFolder,
599 SubViewOfSubViewFolder>(patterns.getContext());
600}
601
602//===----------------------------------------------------------------------===//
603// Pass registration
604//===----------------------------------------------------------------------===//
605
606namespace {
607
608struct FoldMemRefAliasOpsPass final
609 : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
610 void runOnOperation() override;
611};
612
613} // namespace
614
615void FoldMemRefAliasOpsPass::runOnOperation() {
616 RewritePatternSet patterns(&getContext());
618 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
619}
return success()
b getContext())
static bool hasTrivialReassociationSuffix(ArrayRef< ReassociationIndices > reassocs, int64_t n)
Deterimine if the last N indices of reassocitaion are trivial - that is, check if they all contain ex...
static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n)
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getAffineDimExpr(unsigned position)
Definition Builders.cpp:369
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...