1 //===- DecomposeMemRefs.cpp - Decompose memrefs pass implementation -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the decompose memrefs pass.
10 //
11 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25 namespace mlir {
27 #include "mlir/Dialect/GPU/Transforms/"
28 } // namespace mlir
30 using namespace mlir;
32 static void setInsertionPointToStart(OpBuilder &builder, Value val) {
33  if (auto *parentOp = val.getDefiningOp()) {
34  builder.setInsertionPointAfter(parentOp);
35  } else {
37  }
38 }
40 static bool isInsideLaunch(Operation *op) {
41  return op->getParentOfType<gpu::LaunchOp>();
42 }
44 static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
46  ArrayRef<OpFoldResult> subOffsets,
47  ArrayRef<OpFoldResult> subStrides = std::nullopt) {
48  auto sourceType = cast<MemRefType>(source.getType());
49  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
51  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
52  {
53  OpBuilder::InsertionGuard g(rewriter);
54  setInsertionPointToStart(rewriter, source);
55  newExtractStridedMetadata =
56  rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
57  }
59  auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
61  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
62  return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
63  : rewriter.getIndexAttr(dim);
64  };
66  OpFoldResult origOffset =
67  getDim(sourceOffset, newExtractStridedMetadata.getOffset());
68  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
70  SmallVector<OpFoldResult> origStrides;
71  origStrides.reserve(sourceRank);
74  strides.reserve(sourceRank);
76  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
77  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
78  for (auto i : llvm::seq(0u, sourceRank)) {
79  OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
81  if (!subStrides.empty()) {
82  strides.push_back(affine::makeComposedFoldedAffineApply(
83  rewriter, loc, s0 * s1, {subStrides[i], origStride}));
84  }
86  origStrides.emplace_back(origStride);
87  }
89  auto &&[expr, values] =
90  computeLinearIndex(origOffset, origStrides, subOffsets);
91  OpFoldResult finalOffset =
92  affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
93  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
94 }
96 static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
97  ValueRange offsets) {
98  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
99  auto &&[base, offset, ignore] =
100  getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
101  auto retType = cast<MemRefType>(base.getType());
102  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
103  std::nullopt, std::nullopt);
104 }
106 static bool needFlatten(Value val) {
107  auto type = cast<MemRefType>(val.getType());
108  return type.getRank() != 0;
109 }
111 static bool checkLayout(Value val) {
112  auto type = cast<MemRefType>(val.getType());
113  return type.getLayout().isIdentity() ||
114  isa<StridedLayoutAttr>(type.getLayout());
115 }
117 namespace {
118 struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
121  LogicalResult matchAndRewrite(memref::LoadOp op,
122  PatternRewriter &rewriter) const override {
123  if (!isInsideLaunch(op))
124  return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
126  Value memref = op.getMemref();
127  if (!needFlatten(memref))
128  return rewriter.notifyMatchFailure(op, "nothing to do");
130  if (!checkLayout(memref))
131  return rewriter.notifyMatchFailure(op, "unsupported layout");
133  Location loc = op.getLoc();
134  Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
135  rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
136  return success();
137  }
138 };
140 struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
143  LogicalResult matchAndRewrite(memref::StoreOp op,
144  PatternRewriter &rewriter) const override {
145  if (!isInsideLaunch(op))
146  return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
148  Value memref = op.getMemref();
149  if (!needFlatten(memref))
150  return rewriter.notifyMatchFailure(op, "nothing to do");
152  if (!checkLayout(memref))
153  return rewriter.notifyMatchFailure(op, "unsupported layout");
155  Location loc = op.getLoc();
156  Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
157  Value value = op.getValue();
158  rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
159  return success();
160  }
161 };
163 struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
166  LogicalResult matchAndRewrite(memref::SubViewOp op,
167  PatternRewriter &rewriter) const override {
168  if (!isInsideLaunch(op))
169  return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
171  Value memref = op.getSource();
172  if (!needFlatten(memref))
173  return rewriter.notifyMatchFailure(op, "nothing to do");
175  if (!checkLayout(memref))
176  return rewriter.notifyMatchFailure(op, "unsupported layout");
178  Location loc = op.getLoc();
179  SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
180  SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
181  SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
182  auto &&[base, finalOffset, strides] =
183  getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
185  auto srcType = cast<MemRefType>(memref.getType());
186  auto resultType = cast<MemRefType>(op.getType());
187  unsigned subRank = static_cast<unsigned>(resultType.getRank());
189  llvm::SmallBitVector droppedDims = op.getDroppedDims();
191  SmallVector<OpFoldResult> finalSizes;
192  finalSizes.reserve(subRank);
194  SmallVector<OpFoldResult> finalStrides;
195  finalStrides.reserve(subRank);
197  for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
198  if (droppedDims.test(i))
199  continue;
201  finalSizes.push_back(subSizes[i]);
202  finalStrides.push_back(strides[i]);
203  }
205  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
206  op, resultType, base, finalOffset, finalSizes, finalStrides);
207  return success();
208  }
209 };
211 struct GpuDecomposeMemrefsPass
212  : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {
214  void runOnOperation() override {
215  RewritePatternSet patterns(&getContext());
219  if (failed(
220  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
221  return signalPassFailure();
222  }
223 };
225 } // namespace
228  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
229  patterns.getContext());
230 }
232 std::unique_ptr<Pass> mlir::createGpuDecomposeMemrefsPass() {
233  return std::make_unique<GpuDecomposeMemrefsPass>();
234 }
