MLIR 23.0.0git
FoldMemRefsOps.cpp
Go to the documentation of this file.
1//===- FoldMemRefsOps.cpp - AMDGPU fold memref ops ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "llvm/ADT/TypeSwitch.h"
17
18namespace mlir::amdgpu {
19#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
20#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
21
23 : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase<AmdgpuFoldMemRefOpsPass> {
  // Builds a RewritePatternSet for this context, fills it with the fold
  // patterns (presumably via populateAmdgpuFoldMemRefOpsPatterns -- confirm),
  // and applies it to the pass root with walkAndApplyPatterns. That driver is
  // a single walk-based driver rather than a greedy fixpoint driver.
24 void runOnOperation() override {
25 RewritePatternSet patterns(&getContext());
27 walkAndApplyPatterns(getOperation(), std::move(patterns));
28 }
29};
30
/// Looks through the view-like op that defines `view`, so that an access
/// through `view` can be redirected to that op's source memref.
///
/// Supported producers are memref.subview, memref.expand_shape and
/// memref.collapse_shape. On success:
///   - `resolvedIndices` receives `indices` rewritten to address the
///     producer's source. The resolve helpers take `rewriter` and `loc`, so
///     new index computations may be created.
///   - `memrefBase` is set to the producer's source memref.
///
/// Returns failure, without writing either out-parameter, in two cases:
///   - `view` has no defining op (e.g. it is a block argument).
///   - The producer is any other kind of op. In this case a match-failure
///     reason that names `role` (e.g. "source", "lds") is reported against
///     the producer.
31static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
33 SmallVectorImpl<Value> &resolvedIndices,
34 Value &memrefBase, StringRef role) {
35 Operation *defOp = view.getDefiningOp();
  // No producing op (e.g. a block argument): nothing to look through.
36 if (!defOp) {
37 return failure();
38 }
40 .Case([&](memref::SubViewOp subviewOp) {
42 rewriter, loc, subviewOp.getMixedOffsets(),
43 subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
44 resolvedIndices);
45 memrefBase = subviewOp.getSource();
46 return success();
47 })
48 .Case([&](memref::ExpandShapeOp expandShapeOp) {
  // The trailing `false` is the helper's `startsInbounds` argument.
50 loc, rewriter, expandShapeOp, indices, resolvedIndices, false);
51 memrefBase = expandShapeOp.getViewSource();
52 return success();
53 })
54 .Case([&](memref::CollapseShapeOp collapseShapeOp) {
56 loc, rewriter, collapseShapeOp, indices, resolvedIndices);
57 memrefBase = collapseShapeOp.getViewSource();
58 return success();
59 })
  // Any other producer is unsupported. Record why, tagged with `role`, so
  // pattern-debugging output identifies which operand failed to fold.
60 .Default([&](Operation *op) {
61 return rewriter.notifyMatchFailure(
62 op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
63 "CollapseShapeOp")
64 .str());
65 });
66}
67
68struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
69 using Base::Base;
70 LogicalResult matchAndRewrite(GatherToLDSOp op,
71 PatternRewriter &rewriter) const override {
72 Location loc = op.getLoc();
73
74 SmallVector<Value> sourceIndices, destIndices;
75 Value memrefSource, memrefDest;
76
77 auto foldSrcResult =
78 foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
79 sourceIndices, memrefSource, "source");
80
81 if (failed(foldSrcResult)) {
82 memrefSource = op.getSrc();
83 sourceIndices = op.getSrcIndices();
84 }
85
86 auto foldDstResult =
87 foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
88 destIndices, memrefDest, "destination");
89
90 if (failed(foldDstResult)) {
91 memrefDest = op.getDst();
92 destIndices = op.getDstIndices();
93 }
94
95 if (failed(foldSrcResult) && failed(foldDstResult))
96 return rewriter.notifyMatchFailure(op, "no fold found");
97
98 rewriter.replaceOpWithNewOp<GatherToLDSOp>(
99 op, memrefSource, sourceIndices, memrefDest, destIndices,
100 op.getTransferType(), op.getAsync());
101
102 return success();
103 }
104};
105
/// Pattern, templated on the op type `OpTy`, that folds view-like memref
/// producers of an op's `global` and `lds` memref operands into the op.
///
/// `OpTy` must provide the accessors used below: getGlobal/getGlobalIndices,
/// getLds/getLdsIndices and getBase. Each operand is folded independently
/// through foldMemrefViewOp. An operand that cannot be folded is kept with its
/// original indices. The op is then rebuilt with `op.getBase().getType()` as
/// its result type.
106template <typename OpTy>
109 LogicalResult matchAndRewrite(OpTy op,
110 PatternRewriter &rewriter) const override {
111 Location loc = op.getLoc();
112
113 SmallVector<Value> globalIndices, ldsIndices;
114 Value globalBase, ldsBase;
115
116 LogicalResult didFoldGlobal =
117 foldMemrefViewOp(rewriter, loc, op.getGlobal(), op.getGlobalIndices(),
118 globalIndices, globalBase, "global");
  // Not foldable: keep the original global operand and its indices.
119 if (failed(didFoldGlobal)) {
120 globalBase = op.getGlobal();
121 globalIndices = op.getGlobalIndices();
122 }
123
124 LogicalResult didFoldLds =
125 foldMemrefViewOp(rewriter, loc, op.getLds(), op.getLdsIndices(),
126 ldsIndices, ldsBase, "lds");
  // Not foldable: keep the original LDS operand and its indices.
127 if (failed(didFoldLds)) {
128 ldsBase = op.getLds();
129 ldsIndices = op.getLdsIndices();
130 }
131
  // Rewrite only when at least one of the two operands was folded.
132 if (failed(didFoldGlobal) && failed(didFoldLds))
133 return rewriter.notifyMatchFailure(op, "no fold found");
134
135 rewriter.replaceOpWithNewOp<OpTy>(op, op.getBase().getType(), globalBase,
136 globalIndices, ldsBase, ldsIndices);
137 return success();
138 }
139};
140
142 : OpRewritePattern<TransposeLoadOp> {
143 using Base::Base;
  /// Folds a view-like producer of the `src` memref into a TransposeLoadOp.
  /// The op is rebuilt on the producer's source memref, with rewritten indices
  /// and the original result type.
  /// Only one memref operand is folded here, so an unfoldable source makes
  /// the whole pattern fail.
144 LogicalResult matchAndRewrite(TransposeLoadOp op,
145 PatternRewriter &rewriter) const override {
146 SmallVector<Value> sourceIndices;
147 Value memrefSource;
148
149 if (failed(foldMemrefViewOp(rewriter, op.getLoc(), op.getSrc(),
150 op.getSrcIndices(), sourceIndices, memrefSource,
151 "source")))
  // For an unsupported producer, foldMemrefViewOp has already recorded the
  // match-failure reason.
152 return failure();
153
154 rewriter.replaceOpWithNewOp<TransposeLoadOp>(op, op.getResult().getType(),
155 memrefSource, sourceIndices);
156 return success();
157 }
158};
159
168} // namespace mlir::amdgpu
return success()
b getContext())
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void resolveIndicesIntoOpWithOffsetsAndStrides(RewriterBase &rewriter, Location loc, ArrayRef< OpFoldResult > mixedSourceOffsets, ArrayRef< OpFoldResult > mixedSourceStrides, const llvm::SmallBitVector &rankReducedDims, ArrayRef< OpFoldResult > consumerIndices, SmallVectorImpl< Value > &resolvedIndices)
Given the 'consumerIndices' of a load/store operation operating on an op with offsets and strides,...
void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc, Value view, mlir::OperandRange indices, SmallVectorImpl< Value > &resolvedIndices, Value &memrefBase, StringRef role)
void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, memref::CollapseShapeOp collapseShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices)
Given the 'indices' of a load/store operation where the memref is a result of a collapse_shape op,...
void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter, memref::ExpandShapeOp expandShapeOp, ValueRange indices, SmallVectorImpl< Value > &sourceIndices, bool startsInbounds)
Given the 'indices' of a load/store operation where the memref is a result of a expand_shape op,...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(GatherToLDSOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(TransposeLoadOp op, PatternRewriter &rewriter) const override