MLIR  20.0.0git
DecomposeMemRefs.cpp
Go to the documentation of this file.
1 //===- DecomposeMemRefs.cpp - Decompose memrefs pass implementation -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements decompose memrefs pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
31 
32 static void setInsertionPointToStart(OpBuilder &builder, Value val) {
33  if (auto *parentOp = val.getDefiningOp()) {
34  builder.setInsertionPointAfter(parentOp);
35  } else {
37  }
38 }
39 
40 static bool isInsideLaunch(Operation *op) {
41  return op->getParentOfType<gpu::LaunchOp>();
42 }
43 
44 static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
46  ArrayRef<OpFoldResult> subOffsets,
47  ArrayRef<OpFoldResult> subStrides = std::nullopt) {
48  auto sourceType = cast<MemRefType>(source.getType());
49  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
50 
51  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
52  {
53  OpBuilder::InsertionGuard g(rewriter);
54  setInsertionPointToStart(rewriter, source);
55  newExtractStridedMetadata =
56  rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
57  }
58 
59  auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
60 
61  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
62  return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
63  : rewriter.getIndexAttr(dim);
64  };
65 
66  OpFoldResult origOffset =
67  getDim(sourceOffset, newExtractStridedMetadata.getOffset());
68  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
69 
70  SmallVector<OpFoldResult> origStrides;
71  origStrides.reserve(sourceRank);
72 
74  strides.reserve(sourceRank);
75 
76  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
77  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
78  for (auto i : llvm::seq(0u, sourceRank)) {
79  OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
80 
81  if (!subStrides.empty()) {
82  strides.push_back(affine::makeComposedFoldedAffineApply(
83  rewriter, loc, s0 * s1, {subStrides[i], origStride}));
84  }
85 
86  origStrides.emplace_back(origStride);
87  }
88 
89  auto &&[expr, values] =
90  computeLinearIndex(origOffset, origStrides, subOffsets);
91  OpFoldResult finalOffset =
92  affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
93  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
94 }
95 
96 static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
97  ValueRange offsets) {
98  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
99  auto &&[base, offset, ignore] =
100  getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
101  auto retType = cast<MemRefType>(base.getType());
102  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
103  std::nullopt, std::nullopt);
104 }
105 
106 static bool needFlatten(Value val) {
107  auto type = cast<MemRefType>(val.getType());
108  return type.getRank() != 0;
109 }
110 
111 static bool checkLayout(Value val) {
112  auto type = cast<MemRefType>(val.getType());
113  return type.getLayout().isIdentity() ||
114  isa<StridedLayoutAttr>(type.getLayout());
115 }
116 
117 namespace {
118 struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
120 
121  LogicalResult matchAndRewrite(memref::LoadOp op,
122  PatternRewriter &rewriter) const override {
123  if (!isInsideLaunch(op))
124  return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
125 
126  Value memref = op.getMemref();
127  if (!needFlatten(memref))
128  return rewriter.notifyMatchFailure(op, "nothing to do");
129 
130  if (!checkLayout(memref))
131  return rewriter.notifyMatchFailure(op, "unsupported layout");
132 
133  Location loc = op.getLoc();
134  Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
135  rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
136  return success();
137  }
138 };
139 
140 struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
142 
143  LogicalResult matchAndRewrite(memref::StoreOp op,
144  PatternRewriter &rewriter) const override {
145  if (!isInsideLaunch(op))
146  return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
147 
148  Value memref = op.getMemref();
149  if (!needFlatten(memref))
150  return rewriter.notifyMatchFailure(op, "nothing to do");
151 
152  if (!checkLayout(memref))
153  return rewriter.notifyMatchFailure(op, "unsupported layout");
154 
155  Location loc = op.getLoc();
156  Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
157  Value value = op.getValue();
158  rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
159  return success();
160  }
161 };
162 
163 struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
165 
166  LogicalResult matchAndRewrite(memref::SubViewOp op,
167  PatternRewriter &rewriter) const override {
168  if (!isInsideLaunch(op))
169  return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
170 
171  Value memref = op.getSource();
172  if (!needFlatten(memref))
173  return rewriter.notifyMatchFailure(op, "nothing to do");
174 
175  if (!checkLayout(memref))
176  return rewriter.notifyMatchFailure(op, "unsupported layout");
177 
178  Location loc = op.getLoc();
179  SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
180  SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
181  SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
182  auto &&[base, finalOffset, strides] =
183  getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
184 
185  auto srcType = cast<MemRefType>(memref.getType());
186  auto resultType = cast<MemRefType>(op.getType());
187  unsigned subRank = static_cast<unsigned>(resultType.getRank());
188 
189  llvm::SmallBitVector droppedDims = op.getDroppedDims();
190 
191  SmallVector<OpFoldResult> finalSizes;
192  finalSizes.reserve(subRank);
193 
194  SmallVector<OpFoldResult> finalStrides;
195  finalStrides.reserve(subRank);
196 
197  for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
198  if (droppedDims.test(i))
199  continue;
200 
201  finalSizes.push_back(subSizes[i]);
202  finalStrides.push_back(strides[i]);
203  }
204 
205  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
206  op, resultType, base, finalOffset, finalSizes, finalStrides);
207  return success();
208  }
209 };
210 
211 struct GpuDecomposeMemrefsPass
212  : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {
213 
214  void runOnOperation() override {
215  RewritePatternSet patterns(&getContext());
216 
218 
219  if (failed(
220  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
221  return signalPassFailure();
222  }
223 };
224 
225 } // namespace
226 
228  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
229  patterns.getContext());
230 }
231 
232 std::unique_ptr<Pass> mlir::createGpuDecomposeMemrefsPass() {
233  return std::make_unique<GpuDecomposeMemrefsPass>();
234 }
static bool isInsideLaunch(Operation *op)
static bool needFlatten(Value val)
static bool checkLayout(Value val)
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source, ValueRange offsets)
static void setInsertionPointToStart(OpBuilder &builder, Value val)
static std::tuple< Value, OpFoldResult, SmallVector< OpFoldResult > > getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, ArrayRef< OpFoldResult > subOffsets, ArrayRef< OpFoldResult > subStrides=std::nullopt)
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
Definition: AffineExpr.h:68
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:132
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:383
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:352
This class helps build Operations.
Definition: Builders.h:211
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:435
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:472
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:416
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:931
MLIRContext * getContext() const
Definition: PatternMatch.h:823
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:48
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
Include the generated interface declarations.
std::pair< AffineExpr, SmallVector< OpFoldResult > > computeLinearIndex(OpFoldResult sourceOffset, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices)
Compute linear index from provided strides and indices, assuming strided layout.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
std::unique_ptr< Pass > createGpuDecomposeMemrefsPass()
Pass decomposes memref ops inside gpu.launch body.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns)
Collect a set of patterns to decompose memrefs ops.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362