//===- DecomposeMemRefs.cpp - Decompose memrefs pass implementation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the decompose memrefs pass.
//
//===----------------------------------------------------------------------===//
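//
// Illustrative example (a sketch, not verbatim pass output; the exact IR
// depends on folding and on the source layout): inside a gpu.launch region,
//
//   %v = memref.load %mem[%i, %j] : memref<4x8xf32>
//
// is decomposed roughly into
//
//   %base, %off, %szs:2, %strds:2 = memref.extract_strided_metadata %mem
//   %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
//   %flat = memref.reinterpret_cast %base to offset: [%idx], sizes: [],
//       strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//   %v = memref.load %flat[] : memref<f32, strided<[], offset: ?>>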

#include "mlir/Dialect/GPU/Transforms/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

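/// Infers the result type of the flattening cast: a 0-d memref of the source
/// element type whose strided layout carries only `offset`.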
static MemRefType inferCastResultType(Value source, OpFoldResult offset) {
  auto sourceType = cast<BaseMemRefType>(source.getType());
  SmallVector<int64_t> staticOffsets;
  SmallVector<Value> dynamicOffsets;
  dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
  auto stridedLayout =
      StridedLayoutAttr::get(source.getContext(), staticOffsets.front(), {});
  return MemRefType::get({}, sourceType.getElementType(), stridedLayout,
                         sourceType.getMemorySpace());
}

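/// Sets the insertion point just past the definition of `val`: after its
/// defining op when there is one, otherwise at the start of the block that
/// owns it (the block-argument case).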
static void setInsertionPointToStart(OpBuilder &builder, Value val) {
  if (auto *parentOp = val.getDefiningOp()) {
    builder.setInsertionPointAfter(parentOp);
  } else {
    builder.setInsertionPointToStart(val.getParentBlock());
  }
}

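/// Returns true if `op` is nested inside a gpu.launch region.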
static bool isInsideLaunch(Operation *op) {
  return op->getParentOfType<gpu::LaunchOp>();
}

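/// Computes the base buffer, the linearized offset and, when `subStrides` is
/// provided, the composed strides of an access into `source`:
///   finalOffset = sourceOffset + sum_i(subOffsets[i] * sourceStrides[i])
///   strides[i]  = subStrides[i] * sourceStrides[i]
/// Dynamic components are taken from a memref.extract_strided_metadata op
/// materialized right after the definition of `source`.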
static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
                        ArrayRef<OpFoldResult> subOffsets,
                        ArrayRef<OpFoldResult> subStrides = {}) {
  auto sourceType = cast<MemRefType>(source.getType());
  auto sourceRank = static_cast<unsigned>(sourceType.getRank());

  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
  {
    OpBuilder::InsertionGuard g(rewriter);
    setInsertionPointToStart(rewriter, source);
    newExtractStridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
  }

  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();

  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
                                      : rewriter.getIndexAttr(dim);
  };

  OpFoldResult origOffset =
      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();

  SmallVector<OpFoldResult> origStrides;
  origStrides.reserve(sourceRank);

  SmallVector<OpFoldResult> strides;
  strides.reserve(sourceRank);

  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
  for (auto i : llvm::seq(0u, sourceRank)) {
    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);

    if (!subStrides.empty()) {
      strides.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
    }

    origStrides.emplace_back(origStride);
  }

  auto &&[expr, values] =
      computeLinearIndex(origOffset, origStrides, subOffsets);
  OpFoldResult finalOffset =
      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
}

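/// Builds a 0-d, offset-only view of `source` whose offset linearizes
/// `offsets`, materialized as a memref.reinterpret_cast with empty sizes and
/// strides.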
static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
                           ValueRange offsets) {
  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
  auto &&[base, offset, ignore] =
      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
  MemRefType retType = inferCastResultType(base, offset);
  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
                                                    ArrayRef<OpFoldResult>(),
                                                    ArrayRef<OpFoldResult>());
}

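/// Rank-0 memrefs are already flat; anything of higher rank needs flattening.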
static bool needFlatten(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getRank() != 0;
}

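/// Only identity and explicitly strided layouts can be decomposed here.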
static bool checkLayout(Value val) {
  auto type = cast<MemRefType>(val.getType());
  return type.getLayout().isIdentity() ||
         isa<StridedLayoutAttr>(type.getLayout());
}

namespace {
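/// Rewrites a ranked memref.load inside gpu.launch into a load from a
/// flattened 0-d view of the same buffer, schematically:
///   memref.load %m[%i, %j]  ->  memref.load %flat[]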
struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::LoadOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
    return success();
  }
};

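/// Same as FlattenLoad, but for memref.store.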
struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::StoreOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getMemref();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
    Value value = op.getValue();
    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
    return success();
  }
};

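/// Rewrites memref.subview inside gpu.launch into a memref.reinterpret_cast
/// with explicitly composed offset, sizes and strides, skipping any
/// dimensions the subview rank-reduces.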
struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp op,
                                PatternRewriter &rewriter) const override {
    if (!isInsideLaunch(op))
      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");

    Value memref = op.getSource();
    if (!needFlatten(memref))
      return rewriter.notifyMatchFailure(op, "nothing to do");

    if (!checkLayout(memref))
      return rewriter.notifyMatchFailure(op, "unsupported layout");

    Location loc = op.getLoc();
    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
    auto &&[base, finalOffset, strides] =
        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);

    auto srcType = cast<MemRefType>(memref.getType());
    auto resultType = cast<MemRefType>(op.getType());
    unsigned subRank = static_cast<unsigned>(resultType.getRank());

    llvm::SmallBitVector droppedDims = op.getDroppedDims();

    SmallVector<OpFoldResult> finalSizes;
    finalSizes.reserve(subRank);

    SmallVector<OpFoldResult> finalStrides;
    finalStrides.reserve(subRank);

    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
      if (droppedDims.test(i))
        continue;

      finalSizes.push_back(subSizes[i]);
      finalStrides.push_back(strides[i]);
    }

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, resultType, base, finalOffset, finalSizes, finalStrides);
    return success();
  }
};

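/// Pass entry point: greedily applies the decomposition patterns to the
/// operation the pass is run on.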
struct GpuDecomposeMemrefsPass
    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());

    populateGpuDecomposeMemrefsPatterns(patterns);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
      patterns.getContext());
}