MLIR  22.0.0git
FoldMemRefsOps.cpp
Go to the documentation of this file.
1 //===- FoldMemRefsOps.cpp - AMDGPU fold memref view ops -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "llvm/ADT/TypeSwitch.h"
17 
18 namespace mlir::amdgpu {
19 #define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
20 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
21 
23  : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase<AmdgpuFoldMemRefOpsPass> {
// Pass entry point: applies the rewrite patterns collected in `patterns` to
// the pass's root operation using the walk-based driver (walkAndApplyPatterns).
24  void runOnOperation() override {
// NOTE(review): the statements that construct and fill `patterns` are not
// visible in this rendering; presumably they call
// populateAmdgpuFoldMemRefOpsPatterns — confirm against the source.
27  walkAndApplyPatterns(getOperation(), std::move(patterns));
28  }
29 };
30 
// Attempts to fold the view-like op that produces `view` into an access on the
// memref it was derived from.
//
// On success, `memrefBase` is set to the producer's source memref and the
// indices for that memref (translated from `indices`) are written to
// `resolvedIndices`. Supported producers:
//   - memref.subview: indices are resolved through the subview's mixed offsets
//     and strides, taking its dropped (rank-reduced) dims into account;
//   - memref.expand_shape / memref.collapse_shape: indices are resolved by
//     resolveSourceIndicesExpandShape / resolveSourceIndicesCollapseShape.
//
// Returns failure if `view` has no defining op (e.g. it is a block argument),
// if the producer is any other op, or if index resolution fails. `role`
// (e.g. "source" / "destination") only labels the match-failure message.
// On failure `memrefBase` is not written; `resolvedIndices` may already hold
// entries produced by a helper before it failed, so callers should not rely on
// its contents.
31 static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
32  Value view, mlir::OperandRange indices,
33  SmallVectorImpl<Value> &resolvedIndices,
34  Value &memrefBase, StringRef role) {
35  Operation *defOp = view.getDefiningOp();
// A value without a defining op (e.g. a block argument) has no producer to
// fold through.
36  if (!defOp) {
37  return failure();
38  }
// Dispatch on the producer's op kind. NOTE(review): the head of this dispatch
// chain (presumably an llvm::TypeSwitch over `defOp`) is missing from this
// rendering — confirm against the source.
//
// memref.subview: the index-resolution helper returns void, so this case
// always succeeds; the fold target is the subview's source.
40  .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
42  rewriter, loc, subviewOp.getMixedOffsets(),
43  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
44  resolvedIndices);
45  memrefBase = subviewOp.getSource();
46  return success();
47  })
// memref.expand_shape: the trailing `false` is the `startsInbounds` argument
// of resolveSourceIndicesExpandShape.
48  .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
50  loc, rewriter, expandShapeOp, indices, resolvedIndices,
51  false))) {
52  return failure();
53  }
54  memrefBase = expandShapeOp.getViewSource();
55  return success();
56  })
// memref.collapse_shape: indices are resolved by
// resolveSourceIndicesCollapseShape; the fold target is the view source.
57  .Case<memref::CollapseShapeOp>(
58  [&](memref::CollapseShapeOp collapseShapeOp) {
60  loc, rewriter, collapseShapeOp, indices,
61  resolvedIndices))) {
62  return failure();
63  }
64  memrefBase = collapseShapeOp.getViewSource();
65  return success();
66  })
// Any other producer cannot be folded; record why, tagged with `role`.
67  .Default([&](Operation *op) {
68  return rewriter.notifyMatchFailure(
69  op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
70  "CollapseShapeOp")
71  .str());
72  });
73 }
74 
75 struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
77  LogicalResult matchAndRewrite(GatherToLDSOp op,
78  PatternRewriter &rewriter) const override {
79  Location loc = op.getLoc();
80 
81  SmallVector<Value> sourceIndices, destIndices;
82  Value memrefSource, memrefDest;
83 
84  auto foldSrcResult =
85  foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
86  sourceIndices, memrefSource, "source");
87 
88  if (failed(foldSrcResult)) {
89  memrefSource = op.getSrc();
90  sourceIndices = op.getSrcIndices();
91  }
92 
93  auto foldDstResult =
94  foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
95  destIndices, memrefDest, "destination");
96 
97  if (failed(foldDstResult)) {
98  memrefDest = op.getDst();
99  destIndices = op.getDstIndices();
100  }
101 
102  rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
103  memrefDest, destIndices,
104  op.getTransferType());
105 
106  return success();
107  }
108 };
109 
111  PatternBenefit benefit) {
// Registers FoldMemRefOpsIntoGatherToLDSOp in `patterns` with the given
// benefit (the declaration defaults `benefit` to 1).
112  patterns.add<FoldMemRefOpsIntoGatherToLDSOp>(patterns.getContext(), benefit);
113 }
114 } // namespace mlir::amdgpu
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc, Value view, mlir::OperandRange indices, SmallVectorImpl< Value > &resolvedIndices, Value &memrefBase, StringRef role)
LogicalResult resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
LogicalResult resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
LogicalResult matchAndRewrite(GatherToLDSOp op, PatternRewriter &rewriter) const override