//===- DecomposeMemRefs.cpp - Decompose memrefs pass implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the decompose memrefs pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

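// Infers the rank-0 strided memref type produced by flattening `source` to a
// single element at `offset`; the computed offset is carried by the strided
// layout attribute of the result type.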
static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets;
  SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  auto stridedLayout =
      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
                         sourceType.getMemorySpace());
}

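// Positions `builder` immediately after the op defining `val`, or at the start
// of the owning block when `val` is a block argument.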
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

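// Returns true if `op` is nested inside a gpu.launch region.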
static bool isInsideLaunch(Operation *op) {
  return op->getParentOfType<gpu::LaunchOp>();
}

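// Computes the flat base buffer, linearized offset, and combined strides for a
// view of `source` at `subOffsets` (with optional `subStrides`):
//   finalOffset = origOffset + sum_i(subOffsets[i] * origStrides[i])
//   strides[i]  = subStrides[i] * origStrides[i]
// Static offsets/strides are folded as index attributes; dynamic values are
// taken from a memref.extract_strided_metadata op inserted next to the source
// definition.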
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = {}) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
  }

  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.emplace_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

  auto &&[expr, values] =
      computeLinearIndex(origOffset, origStrides, subOffsets);
  OpFoldResult finalOffset =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

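// Materializes a rank-0 view of `source` at the linearized position given by
// `offsets` via memref.reinterpret_cast.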
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
  MemRefType retType = inferCastResultType(base, offset);
  return memref::ReinterpretCastOp::create(rewriter, loc, retType, base, offset,
                                           ArrayRef<OpFoldResult>(),
                                           ArrayRef<OpFoldResult>());
}

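// Only memrefs of non-zero rank need flattening; rank-0 accesses are already
// scalar.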
static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}

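// Only identity and strided layouts are supported; other layout attributes
// (e.g. generic affine maps) are rejected.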
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
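// Rewrites a multi-dimensional memref.load inside a gpu.launch into a load
// from a rank-0 flattened view. Illustrative example (not taken from a test):
//   %v = memref.load %src[%i, %j] : memref<?x?xf32>
// becomes a memref.reinterpret_cast of the base buffer to
// memref<f32, strided<[], offset: ?>> followed by a 0-d memref.load.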
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
    return success();
  }
};

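// Same as FlattenLoad, but for memref.store: the stored value is written
// through the rank-0 flattened view.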
struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    Value value = op.getValue();
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
    return success();
  }
};

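// Rewrites memref.subview inside a gpu.launch into memref.reinterpret_cast:
// the subview offsets are folded into a single linear offset, its strides are
// multiplied into the source strides, and rank-reduced dimensions are dropped.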
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getSource();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

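// Pass wrapper that greedily applies the decomposition patterns to the
// operation it runs on.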
struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    populateGpuDecomposeMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace
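
// Public hook that collects the set of patterns to decompose memref ops; also
// used by the pass above.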
void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}