//===- DecomposeMemRefs.cpp - Decompose memrefs pass implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the decompose memrefs pass.
//
//===----------------------------------------------------------------------===//
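
// The pass rewrites memref.load/memref.store/memref.subview ops that appear
// inside a gpu.launch body so that all index arithmetic is folded into a
// single linear offset on the underlying buffer: dynamic metadata is
// recovered with memref.extract_strided_metadata, the offset is computed
// with affine ops, and the access is rebuilt via memref.reinterpret_cast.
//
// A rough sketch of the rewrite (the IR below is illustrative, not taken
// verbatim from a test):
//
//   %v = memref.load %mem[%i, %j] : memref<?x?xf32>
//
// becomes approximately:
//
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %mem
//   %flatOffset = <affine.apply computing %offset + %i * %strides#0
//                                                 + %j * %strides#1>
//   %flat = memref.reinterpret_cast %base to offset: [%flatOffset],
//               sizes: [], strides: []
//   %v = memref.load %flat[]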

#include "mlir/Dialect/GPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

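// Infers the result type of the flattening memref.reinterpret_cast: a rank-0
// memref with `source`'s element type and memory space, whose layout is a
// StridedLayoutAttr carrying `offset` (the constant value if `offset` is an
// attribute, ShapedType::kDynamic otherwise).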
static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets;
  SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  auto stridedLayout =
      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
                         sourceType.getMemorySpace());
}

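// Moves the insertion point to the earliest position where `val` is
// available: just after its defining op, or to the start of its block when
// `val` is a block argument.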
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

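// The patterns below only fire on ops nested inside a gpu.launch body.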
static bool isInsideLaunch(Operation *op) {
  return op->getParentOfType<gpu::LaunchOp>();
}

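// Returns the base buffer of `source` plus the linearized offset and composed
// strides of a view described by `subOffsets`/`subStrides`:
//   finalOffset = srcOffset + sum_i(subOffsets[i] * srcStrides[i])
//   strides[i]  = subStrides[i] * srcStrides[i]  (only when subStrides given)
// Static offsets/strides are taken from the source type; dynamic values come
// from a memref.extract_strided_metadata op inserted right where `source`
// becomes available, so its results dominate all uses.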
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = {}) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
  }

  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

  auto &&[expr, values] =
      computeLinearIndex(origOffset, origStrides, subOffsets);
  OpFoldResult finalOffset =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

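// Materializes a rank-0 view of `source` whose layout offset addresses the
// element at `offsets`, via a memref.reinterpret_cast of the base buffer.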
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
  MemRefType retType = inferCastResultType(base, offset);
  // The result is rank-0, so the cast carries no sizes or strides.
  return memref::ReinterpretCastOp::create(rewriter, loc, retType, base, offset,
                                           /*sizes=*/ArrayRef<OpFoldResult>{},
                                           /*strides=*/ArrayRef<OpFoldResult>{});
}

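// Rank-0 memrefs are already flat; nothing to rewrite.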
static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}

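// The flattening logic only handles identity and explicitly strided layouts.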
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
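// Replaces a memref.load inside gpu.launch with a load from a flattened
// rank-0 view of the same buffer.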
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
    return success();
  }
};

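// Same rewrite as FlattenLoad, applied to memref.store.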
struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    Value value = op.getValue();
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
    return success();
  }
};

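// Replaces a memref.subview inside gpu.launch with a memref.reinterpret_cast
// that linearizes the subview offsets into a single offset and composes the
// subview strides with the source strides, skipping dimensions dropped by
// rank reduction.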
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getSource();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

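// Pass entry point: greedily applies the decomposition patterns to the root
// operation.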
struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    populateGpuDecomposeMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}