26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallBitVector.h"
28#include "llvm/Support/Debug.h"
31#define DEBUG_TYPE "fold-memref-alias-ops"
32#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
36#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
37#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
54 if (n >
static_cast<int64_t>(reassocs.size()))
57 reassocs.take_back(n),
65 if (n >
static_cast<int64_t>(strides.size()))
67 return llvm::all_of(strides.take_back(n), [](
int64_t s) { return s == 1; });
78 using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;
80 LogicalResult matchAndRewrite(memref::SubViewOp subView,
81 PatternRewriter &rewriter)
const override {
82 auto srcSubView = subView.getSource().getDefiningOp<memref::SubViewOp>();
86 SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
88 rewriter, subView.getLoc(), srcSubView, subView,
89 srcSubView.getDroppedDims(), newOffsets, newSizes, newStrides)))
94 subView, subView.getType(), srcSubView.getSource(), newOffsets,
95 newSizes, newStrides);
103struct AccessOpOfSubViewOpFolder final
107 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
108 PatternRewriter &rewriter)
const override;
117struct AccessOpOfExpandShapeOpFolder final
121 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
122 PatternRewriter &rewriter)
const override;
133struct AccessOpOfCollapseShapeOpFolder final
137 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
138 PatternRewriter &rewriter)
const override;
146struct IndexedMemCopyOpOfSubViewOpFolder final
150 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
151 PatternRewriter &rewriter)
const override;
157struct IndexedMemCopyOpOfExpandShapeOpFolder final
161 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
162 PatternRewriter &rewriter)
const override;
168struct IndexedMemCopyOpOfCollapseShapeOpFolder final
172 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
173 PatternRewriter &rewriter)
const override;
182struct TransferOpOfSubViewOpFolder final
186 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
187 PatternRewriter &rewriter)
const override;
197struct TransferOpOfExpandShapeOpFolder final
201 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
202 PatternRewriter &rewriter)
const override;
209struct TransferOpOfCollapseShapeOpFolder final
213 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
214 PatternRewriter &rewriter)
const override;
219AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
221 auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
225 SmallVector<int64_t> accessedShape = op.getAccessedShape();
230 int64_t accessedDims = accessedShape.size();
233 op,
"non-unit stride on accessed dimensions");
235 llvm::SmallBitVector droppedDims = subview.getDroppedDims();
236 int64_t sourceRank = subview.getSourceType().getRank();
241 int64_t secondAccessedDim = sourceRank - (accessedDims - 1);
242 if (secondAccessedDim < sourceRank) {
243 for (int64_t d : llvm::seq(secondAccessedDim, sourceRank)) {
244 if (droppedDims.test(d))
246 op,
"reintroducing dropped dimension " + Twine(d) +
247 " would break access op semantics");
251 SmallVector<Value> sourceIndices;
253 rewriter, op.getLoc(), subview.getMixedOffsets(),
254 subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);
256 std::optional<SmallVector<Value>> newValues =
257 op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
263LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
264 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter)
const {
265 auto expand = op.getAccessedMemref().getDefiningOp<memref::ExpandShapeOp>();
269 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
270 ArrayRef<int64_t> accessedShape = rawAccessedShape;
271 if (expand.getSrcType().getRank() <
272 static_cast<int64_t
>(accessedShape.size()))
274 op,
"expand_shape source rank is too small for the accessed shape");
278 if (!accessedShape.empty())
279 accessedShape = accessedShape.drop_front();
281 SmallVector<ReassociationIndices, 4> reassocs =
282 expand.getReassociationIndices();
286 "expand_shape folding would merge semantically important dimensions");
288 SmallVector<Value> sourceIndices;
290 op.getIndices(), sourceIndices,
291 op.hasInboundsIndices());
293 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
294 rewriter, expand.getViewSource(), sourceIndices);
300LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
301 memref::IndexedAccessOpInterface op, PatternRewriter &rewriter)
const {
303 op.getAccessedMemref().getDefiningOp<memref::CollapseShapeOp>();
307 SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
308 ArrayRef<int64_t> accessedShape = rawAccessedShape;
309 if (collapse.getSrcType().getRank() <
310 static_cast<int64_t
>(accessedShape.size()))
312 op,
"collapse_shape source rank is too small for the accessed shape");
317 if (!accessedShape.empty())
318 accessedShape = accessedShape.drop_front();
320 SmallVector<ReassociationIndices, 4> reassocs =
321 collapse.getReassociationIndices();
324 "semantically important dimensions");
326 SmallVector<Value> sourceIndices;
328 op.getIndices(), sourceIndices,
329 op.hasInboundsIndices());
331 std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
332 rewriter, collapse.getViewSource(), sourceIndices);
338LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
339 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
340 auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>();
341 auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>();
342 if (!srcSubview && !dstSubview)
344 op,
"no subviews found on indexed copy inputs");
346 Value newSrc = op.getSrc();
347 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
348 Value newDst = op.getDst();
349 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
351 newSrc = srcSubview.getSource();
352 newSrcIndices.clear();
354 rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
355 srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
356 op.getSrcIndices(), newSrcIndices);
359 newDst = dstSubview.getSource();
360 newDstIndices.clear();
362 rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
363 dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
364 op.getDstIndices(), newDstIndices);
366 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
371LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
372 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
373 auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>();
374 auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>();
375 if (!srcExpand && !dstExpand)
377 op,
"no expand_shapes found on indexed copy inputs");
379 Value newSrc = op.getSrc();
380 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
381 Value newDst = op.getDst();
382 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
384 newSrc = srcExpand.getViewSource();
385 newSrcIndices.clear();
387 op.getSrcIndices(), newSrcIndices,
391 newDst = dstExpand.getViewSource();
392 newDstIndices.clear();
394 op.getDstIndices(), newDstIndices,
397 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
402LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
403 memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter)
const {
404 auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>();
405 auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>();
406 if (!srcCollapse && !dstCollapse)
408 op,
"no collapse_shapes found on indexed copy inputs");
410 Value newSrc = op.getSrc();
411 SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
412 Value newDst = op.getDst();
413 SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
415 newSrc = srcCollapse.getViewSource();
416 newSrcIndices.clear();
418 op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices,
422 newDst = dstCollapse.getViewSource();
423 newDstIndices.clear();
425 op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices,
428 op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
434TransferOpOfSubViewOpFolder::matchAndRewrite(VectorTransferOpInterface op,
436 auto subview = op.getBase().getDefiningOp<memref::SubViewOp>();
444 if (op.hasOutOfBoundsDim())
446 VectorType vecTy = op.getVectorType();
451 Twine(vecTy.getRank()) +
455 subview.getDroppedDims());
457 if (failed(op.mayUpdateStartingPosition(subview.getSourceType(), newPerm)))
459 "failed op-specific preconditions");
463 rewriter, op.getLoc(), subview.getMixedOffsets(),
464 subview.getMixedStrides(), subview.getDroppedDims(), op.getIndices(),
466 op.updateStartingPosition(rewriter, subview.getSource(), newIndices,
467 AffineMapAttr::get(newPerm));
471LogicalResult TransferOpOfExpandShapeOpFolder::matchAndRewrite(
473 auto expand = op.getBase().getDefiningOp<memref::ExpandShapeOp>();
477 if (op.hasOutOfBoundsDim())
480 int64_t srcRank = expand.getSrc().getType().getRank();
481 int64_t vecRank = op.getVectorType().getRank();
482 if (srcRank < vecRank)
484 "source rank is less than vector rank");
486 llvm::SmallDenseMap<int64_t, int64_t, 8> unstridedResDimToSrcDim;
487 for (
auto [srcIdx, reassoc] :
488 llvm::enumerate(expand.getReassociationIndices())) {
489 unstridedResDimToSrcDim.insert({reassoc.back(), srcIdx});
500 auto resDim = dyn_cast<AffineDimExpr>(permRes);
503 op,
"has non-dim entry in permutation map");
504 auto dimInSrc = unstridedResDimToSrcDim.find(resDim.getPosition());
505 if (dimInSrc == unstridedResDimToSrcDim.end())
507 "permutation map result would be made "
508 "strided by expand_shape folding");
512 auto newPerm =
AffineMap::get(srcRank, 0, newPermMapResults, op.getContext());
514 if (failed(op.mayUpdateStartingPosition(expand.getSrc().getType(), newPerm)))
521 op.getIndices(), newIndices,
524 op.updateStartingPosition(rewriter, expand.getViewSource(), newIndices,
525 AffineMapAttr::get(newPerm));
529LogicalResult TransferOpOfCollapseShapeOpFolder::matchAndRewrite(
530 VectorTransferOpInterface op, PatternRewriter &rewriter)
const {
531 auto collapse = op.getBase().getDefiningOp<memref::CollapseShapeOp>();
535 if (!op.getPermutationMap().isMinorIdentity())
537 "non-minor identity permutation map");
539 if (op.hasOutOfBoundsDim())
542 int64_t srcRank = collapse.getSrc().getType().getRank();
543 int64_t vecRank = op.getVectorType().getRank();
544 if (srcRank < vecRank)
546 "source rank is less than vector rank");
551 SmallVector<ReassociationIndices> reassocs =
552 collapse.getReassociationIndices();
555 op,
"collapse_shape folding would split a transfer dimension");
560 op.mayUpdateStartingPosition(collapse.getSrc().getType(), newPerm)))
563 SmallVector<Value> newIndices;
565 op.getIndices(), newIndices,
568 op.updateStartingPosition(rewriter, collapse.getViewSource(), newIndices,
569 AffineMapAttr::get(newPerm));
575 .
add<AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
576 AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
577 IndexedMemCopyOpOfExpandShapeOpFolder,
578 IndexedMemCopyOpOfCollapseShapeOpFolder, TransferOpOfSubViewOpFolder,
579 TransferOpOfExpandShapeOpFolder, TransferOpOfCollapseShapeOpFolder,
580 SubViewOfSubViewFolder>(patterns.
getContext());
589struct FoldMemRefAliasOpsPass final
591 void runOnOperation()
override;
596void FoldMemRefAliasOpsPass::runOnOperation() {
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineExpr getAffineDimExpr(unsigned position)
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
LogicalResult mergeOffsetsSizesAndStrides(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > producerOffsets, ArrayRef< OpFoldResult > producerSizes, ArrayRef< OpFoldResult > producerStrides, const llvm::SmallBitVector &droppedProducerDims, ArrayRef< OpFoldResult > consumerOffsets, ArrayRef< OpFoldResult > consumerSizes, ArrayRef< OpFoldResult > consumerStrides, SmallVector< OpFoldResult > &combinedOffsets, SmallVector< OpFoldResult > &combinedSizes, SmallVector< OpFoldResult > &combinedStrides)
Fills the combinedOffsets, combinedSizes and combinedStrides to use when combining a producer slice i...
void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding memref aliasing ops into consumer load/store ops into patterns.
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of an expand_shape op,...
Include the generated interface declarations.
AffineMap expandDimsToRank(AffineMap map, int64_t rank, const llvm::SmallBitVector &projectedDimensions)
Expand map to operate on rank dims while projecting out the dims in projectedDimensions.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< int64_t, 2 > ReassociationIndices
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...