//===- DecomposeMemRefs.cpp - Decompose memrefs pass implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the decompose-memrefs pass: inside gpu.launch bodies,
// memref.load and memref.store are rewritten to access flat zero-rank views,
// and memref.subview is rewritten to memref.reinterpret_cast, with all offset
// arithmetic materialized explicitly.
//
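// For example (illustrative IR, not verbatim pass output), a load inside a
// gpu.launch region such as
//
//   %v = memref.load %mem[%i, %j] : memref<?x?xf32>
//
// is decomposed so that the linear offset is computed explicitly and the
// access goes through a flat zero-rank view:
//
//   %base, %offset, %sizes:2, %strides:2 =
//     memref.extract_strided_metadata %mem
//   %idx = affine.apply
//     affine_map<()[s0, s1, s2, s3, s4] -> (s0 + s1 * s2 + s3 * s4)>
//     ()[%offset, %i, %strides#0, %j, %strides#1]
//   %flat = memref.reinterpret_cast %base to offset: [%idx], sizes: [],
//     strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//   %v = memref.load %flat[] : memref<f32, strided<[], offset: ?>>
//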
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

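// Returns the zero-rank memref type used for flattened accesses: same element
// type and memory space as `source`, with a strided layout that carries only
// the (possibly dynamic) offset.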
static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets;
  SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  auto stridedLayout =
      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
                         sourceType.getMemorySpace());
}

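// Sets the insertion point as early as `val` is available: right after its
// defining op, or at the start of its block when it is a block argument.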
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

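// Returns true if `op` is nested inside a gpu.launch body.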
static bool isInsideLaunch(Operation *op) {
  return op->getParentOfType<gpu::LaunchOp>();
}

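// Extracts the strided metadata of `source` and folds `subOffsets` (and, when
// provided, `subStrides`) into it. Returns the base buffer, the linearized
// offset (offset + sum(subOffsets[i] * strides[i]), built with composed
// affine.apply ops), and the per-dimension combined strides
// (subStrides[i] * strides[i]; empty when `subStrides` is empty).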
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
  }

  auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

  auto &&[expr, values] =
      computeLinearIndex(origOffset, origStrides, subOffsets);
  OpFoldResult finalOffset =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

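// Builds a zero-rank view of `source` whose layout carries the linearized
// offset for `offsets`, via memref.reinterpret_cast.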
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
  MemRefType retType = inferCastResultType(base, offset);
  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                    std::nullopt, std::nullopt);
}

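// Zero-rank memrefs are already flat; everything else needs flattening.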
static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}

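// Only identity and explicitly strided layouts are supported.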
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
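// Rewrites memref.load inside gpu.launch into a load from a flat zero-rank
// view with an explicitly computed linear offset.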
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
    return success();
  }
};

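// Rewrites memref.store inside gpu.launch into a store to a flat zero-rank
// view with an explicitly computed linear offset.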
struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    Value value = op.getValue();
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
    return success();
  }
};

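// Rewrites memref.subview inside gpu.launch into memref.reinterpret_cast with
// explicitly computed offset and strides, dropping rank-reduced dimensions
// from the result sizes and strides.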
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getSource();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

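// Pass that greedily applies the decomposition patterns to memref ops inside
// gpu.launch bodies.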
struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    populateGpuDecomposeMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}

std::unique_ptr<Pass> mlir::createGpuDecomposeMemrefsPass() {
  return std::make_unique<GpuDecomposeMemrefsPass>();
}
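
// A minimal usage sketch, assuming a PassManager driving a module (the
// context/module handling below is illustrative, not part of this file):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createGpuDecomposeMemrefsPass());
//   if (failed(pm.run(module)))
//     return failure();
//
// The patterns can also be applied standalone by calling
// populateGpuDecomposeMemrefsPatterns() on a RewritePatternSet and running
// applyPatternsGreedily() on the target op.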