26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallBitVector.h"
28#include "llvm/Support/Debug.h"
31#define DEBUG_TYPE "fold-memref-alias-ops"
32#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
36#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
37#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
54 if (n >
static_cast<int64_t>(reassocs.size()))
57 reassocs.take_back(n),
65 if (n >
static_cast<int64_t>(strides.size()))
67 return llvm::all_of(strides.take_back(n), [](
int64_t s) { return s == 1; });
78 using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
80 LogicalResult matchAndRewrite(memref::SubViewOp subView,
81 PatternRewriter &rewriter)
const override {
82 auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
86 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
87 if (
failed(affine::mergeOffsetsSizesAndStrides(
88 rewriter, subView.getLoc(), srcSubView, subView,
89 srcSubView.getDroppedDims(), newOffsets, newSizes, newStrides)))
94 subView, subView.getType(), srcSubView.getSource(), newOffsets,
95 newSizes, newStrides);
103struct AccessOpOfSubViewOpFolder final
107 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
108 PatternRewriter &rewriter)
const override;
117struct AccessOpOfExpandShapeOpFolder final
121 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
122 PatternRewriter &rewriter)
const override;
133struct AccessOpOfCollapseShapeOpFolder final
137 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
138 PatternRewriter &rewriter)
const override;
146struct IndexedMemCopyOpOfSubViewOpFolder final
150 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
151 PatternRewriter &rewriter)
const override;
157struct IndexedMemCopyOpOfExpandShapeOpFolder final
161 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
162 PatternRewriter &rewriter)
const override;
168struct IndexedMemCopyOpOfCollapseShapeOpFolder final
172 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
173 PatternRewriter &rewriter)
const override;
182struct TransferOpOfSubViewOpFolder final
186 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
187 PatternRewriter &rewriter)
const override;
197struct TransferOpOfExpandShapeOpFolder final
201 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
202 PatternRewriter &rewriter)
const override;
209struct TransferOpOfCollapseShapeOpFolder final
213 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
214 PatternRewriter &rewriter)
const override;
219AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
225 auto subview = accessedMemref.getDefiningOp<memref::SubViewOp>();
229 SmallVector<int64_t> accessedShape = op.getAccessedShape();
234 int64_t accessedDims = accessedShape.size();
237 op,
"non-unit stride on accessed dimensions");
239 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
240 int64_t sourceRank = subview.getSourceType().getRank();
245 int64_t secondAccessedDim = sourceRank - (accessedDims - 1);
246 if (secondAccessedDim < sourceRank) {
247 for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) {
248 if (droppedDims.test(d))
250 op,
"reintroducing dropped dimension " + Twine(d) +
251 " would break access op semantics");
255 SmallVector<Value> sourceIndices;
256 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
257 rewriter, op.getLoc(), subview.getMixedOffsets(),
258 subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);
260 std::optional<SmallVector<Value>> newValues =
261 op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
267LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
268 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter)
const {
273 auto expand = accessedMemref.getDefiningOp<memref::ExpandShapeOp>();
277 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
278 ArrayRef<int64_t> accessedShape = rawAccessedShape;
279 if (expand.getSrcType().getRank() <
280 static_cast<int64_t
>(accessedShape.size()))
282 op,
"expand_shape source rank is too small for the accessed shape");
286 if (!accessedShape.empty())
287 accessedShape = accessedShape.drop_front();
289 SmallVector<ReassociationIndices, 4> reassocs =
290 expand.getReassociationIndices();
294 "expand_shape folding would merge semantically important dimensions");
296 SmallVector<Value> sourceIndices;
298 op.getIndices(), sourceIndices,
299 op.hasInboundsIndices());
301 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
302 rewriter, expand.getViewSource(), sourceIndices);
308LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
309 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter)
const {
314 auto collapse = accessedMemref.getDefiningOp<memref::CollapseShapeOp>();
318 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
319 ArrayRef<int64_t> accessedShape = rawAccessedShape;
320 if (collapse.getSrcType().getRank() <
321 static_cast<int64_t
>(accessedShape.size()))
323 op,
"collapse_shape source rank is too small for the accessed shape");
328 if (!accessedShape.empty())
329 accessedShape = accessedShape.drop_front();
331 SmallVector<ReassociationIndices, 4> reassocs =
332 collapse.getReassociationIndices();
335 "semantically important dimensions");
337 SmallVector<Value> sourceIndices;
339 op.getIndices(), sourceIndices,
340 op.hasInboundsIndices());
342 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
343 rewriter, collapse.getViewSource(), sourceIndices);
349LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
350 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
353 auto srcSubview = src ? src.getDefiningOp<memref::SubViewOp>() :
nullptr;
354 auto dstSubview = dst ? dst.getDefiningOp<memref::SubViewOp>() :
nullptr;
355 if (!srcSubview && !dstSubview)
357 op,
"no subviews found on indexed copy inputs");
360 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
362 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
364 newSrc = srcSubview.getSource();
365 newSrcIndices.clear();
366 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
367 rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
368 srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
369 op.getSrcIndices(), newSrcIndices);
372 newDst = dstSubview.getSource();
373 newDstIndices.clear();
374 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
375 rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
376 dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
377 op.getDstIndices(), newDstIndices);
379 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
384LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
385 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
388 auto srcExpand = src ? src.getDefiningOp<memref::ExpandShapeOp>() :
nullptr;
389 auto dstExpand = dst ? dst.getDefiningOp<memref::ExpandShapeOp>() :
nullptr;
390 if (!srcExpand && !dstExpand)
392 op,
"no expand_shapes found on indexed copy inputs");
395 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
397 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
399 newSrc = srcExpand.getViewSource();
400 newSrcIndices.clear();
402 op.getSrcIndices(), newSrcIndices,
403 op.hasInboundsSrcIndices());
406 newDst = dstExpand.getViewSource();
407 newDstIndices.clear();
409 op.getDstIndices(), newDstIndices,
410 op.hasInboundsDstIndices());
412 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
417LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
418 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
422 src ? src.getDefiningOp<memref::CollapseShapeOp>() :
nullptr;
424 dst ? dst.getDefiningOp<memref::CollapseShapeOp>() :
nullptr;
425 if (!srcCollapse && !dstCollapse)
427 op,
"no collapse_shapes found on indexed copy inputs");
430 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
432 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
434 newSrc = srcCollapse.getViewSource();
435 newSrcIndices.clear();
437 op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices,
438 op.hasInboundsSrcIndices());
441 newDst = dstCollapse.getViewSource();
442 newDstIndices.clear();
444 op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices,
445 op.hasInboundsDstIndices());
447 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
453TransferOpOfSubViewOpFolder::matchAndRewrite(VectorTransferOpInterface op,
454 PatternRewriter &rewriter)
const {
455 auto subview = op.getBase().getDefiningOp<memref::SubViewOp>();
459 AffineMap perm = op.getPermutationMap();
463 if (op.hasOutOfBoundsDim())
465 VectorType vecTy = op.getVectorType();
470 Twine(vecTy.getRank()) +
473 AffineMap newPerm =
expandDimsToRank(perm, subview.getSourceType().getRank(),
474 subview.getDroppedDims());
476 if (
failed(op.mayUpdateStartingPosition(subview.getSourceType(), newPerm)))
478 "failed op-specific preconditions");
480 SmallVector<Value> newIndices;
481 affine::resolveIndicesIntoOpWithOffsetsAndStrides(
482 rewriter, op.getLoc(), subview.getMixedOffsets(),
483 subview.getMixedStrides(), subview.getDroppedDims(), op.getIndices(),
485 op.updateStartingPosition(rewriter, subview.getSource(), newIndices,
486 AffineMapAttr::get(newPerm));
490LogicalResult TransferOpOfExpandShapeOpFolder::matchAndRewrite(
491 VectorTransferOpInterface op, PatternRewriter &rewriter)
const {
492 auto expand = op.getBase().getDefiningOp<memref::ExpandShapeOp>();
496 if (op.hasOutOfBoundsDim())
499 int64_t srcRank = expand.getSrc().getType().getRank();
500 int64_t vecRank = op.getVectorType().getRank();
501 if (srcRank < vecRank)
503 "source rank is less than vector rank");
505 llvm::SmallDenseMap<int64_t, int64_t, 8> unstridedResDimToSrcDim;
506 for (
auto [srcIdx, reassoc] :
507 llvm::enumerate(expand.getReassociationIndices())) {
508 unstridedResDimToSrcDim.insert({reassoc.back(), srcIdx});
515 AffineMap permMap = op.getPermutationMap();
516 SmallVector<AffineExpr> newPermMapResults;
518 for (AffineExpr permRes : permMap.
getResults()) {
519 auto resDim = dyn_cast<AffineDimExpr>(permRes);
522 op,
"has non-dim entry in permutation map");
523 auto dimInSrc = unstridedResDimToSrcDim.find(resDim.getPosition());
524 if (dimInSrc == unstridedResDimToSrcDim.end())
526 "permutation map result would be made "
527 "strided by expand_shape folding");
531 auto newPerm =
AffineMap::get(srcRank, 0, newPermMapResults, op.getContext());
533 if (
failed(op.mayUpdateStartingPosition(expand.getSrc().getType(), newPerm)))
536 SmallVector<Value> newIndices;
540 op.getIndices(), newIndices,
543 op.updateStartingPosition(rewriter, expand.getViewSource(), newIndices,
544 AffineMapAttr::get(newPerm));
548LogicalResult TransferOpOfCollapseShapeOpFolder::matchAndRewrite(
549 VectorTransferOpInterface op, PatternRewriter &rewriter)
const {
550 auto collapse = op.getBase().getDefiningOp<memref::CollapseShapeOp>();
554 if (!op.getPermutationMap().isMinorIdentity())
556 "non-minor identity permutation map");
558 if (op.hasOutOfBoundsDim())
561 int64_t srcRank = collapse.getSrc().getType().getRank();
562 int64_t vecRank = op.getVectorType().getRank();
563 if (srcRank < vecRank)
565 "source rank is less than vector rank");
570 SmallVector<ReassociationIndices> reassocs =
571 collapse.getReassociationIndices();
574 op,
"collapse_shape folding would split a transfer dimension");
579 op.mayUpdateStartingPosition(collapse.getSrc().getType(), newPerm)))
582 SmallVector<Value> newIndices;
584 op.getIndices(), newIndices,
587 op.updateStartingPosition(rewriter, collapse.getViewSource(), newIndices,
588 AffineMapAttr::get(newPerm));
594 .
add<AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
595 AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
596 IndexedMemCopyOpOfExpandShapeOpFolder,
597 IndexedMemCopyOpOfCollapseShapeOpFolder, TransferOpOfSubViewOpFolder,
598 TransferOpOfExpandShapeOpFolder, TransferOpOfCollapseShapeOpFolder,
599 SubViewOfSubViewFolder>(patterns.
getContext());
608struct FoldMemRefAliasOpsPass final
609 :
public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
610 void runOnOperation()
override;
615void FoldMemRefAliasOpsPass::runOnOperation() {
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getAffineDimExpr(unsigned position)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< int64_t, 2 > ReassociationIndices
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...