MLIR  21.0.0git
FlattenMemRefs.cpp
Go to the documentation of this file.
1 //===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains patterns for flattening multi-rank memref-related
// ops into 1-D memref ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <numeric>
35 
36 namespace mlir {
37 namespace memref {
38 #define GEN_PASS_DEF_FLATTENMEMREFSPASS
39 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
40 } // namespace memref
41 } // namespace mlir
42 
43 using namespace mlir;
44 
46  OpFoldResult in) {
47  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
48  return rewriter.create<arith::ConstantIndexOp>(
49  loc, cast<IntegerAttr>(offsetAttr).getInt());
50  }
51  return cast<Value>(in);
52 }
53 
54 /// Returns a collapsed memref and the linearized index to access the element
55 /// at the specified indices.
56 static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
57  Location loc,
58  Value source,
59  ValueRange indices) {
60  int64_t sourceOffset;
61  SmallVector<int64_t, 4> sourceStrides;
62  auto sourceType = cast<MemRefType>(source.getType());
63  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
64  assert(false);
65  }
66 
67  memref::ExtractStridedMetadataOp stridedMetadata =
68  rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
69 
70  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
71  OpFoldResult linearizedIndices;
72  memref::LinearizedMemRefInfo linearizedInfo;
73  std::tie(linearizedInfo, linearizedIndices) =
75  rewriter, loc, typeBit, typeBit,
76  stridedMetadata.getConstifiedMixedOffset(),
77  stridedMetadata.getConstifiedMixedSizes(),
78  stridedMetadata.getConstifiedMixedStrides(),
79  getAsOpFoldResult(indices));
80 
81  return std::make_pair(
82  rewriter.create<memref::ReinterpretCastOp>(
83  loc, source,
84  /* offset = */ linearizedInfo.linearizedOffset,
85  /* shapes = */
86  ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
87  /* strides = */
88  ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
89  getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
90 }
91 
92 static bool needFlattening(Value val) {
93  auto type = cast<MemRefType>(val.getType());
94  return type.getRank() > 1;
95 }
96 
97 static bool checkLayout(Value val) {
98  auto type = cast<MemRefType>(val.getType());
99  return type.getLayout().isIdentity() ||
100  isa<StridedLayoutAttr>(type.getLayout());
101 }
102 
103 namespace {
104 static Value getTargetMemref(Operation *op) {
106  .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
107  memref::AllocOp>([](auto op) { return op.getMemref(); })
108  .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
109  vector::MaskedStoreOp, vector::TransferReadOp,
110  vector::TransferWriteOp>(
111  [](auto op) { return op.getBase(); })
112  .Default([](auto) { return Value{}; });
113 }
114 
115 template <typename T>
116 static void castAllocResult(T oper, T newOper, Location loc,
117  PatternRewriter &rewriter) {
118  memref::ExtractStridedMetadataOp stridedMetadata =
119  rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
120  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
121  oper, cast<MemRefType>(oper.getType()), newOper,
122  /*offset=*/rewriter.getIndexAttr(0),
123  stridedMetadata.getConstifiedMixedSizes(),
124  stridedMetadata.getConstifiedMixedStrides());
125 }
126 
/// Rebuilds `op` so it operates on the flattened 1-D memref `flatMemref` at
/// the linearized index `offset`, then replaces the original op. Allocation
/// ops are additionally cast back to their original multi-rank type so that
/// existing users remain valid.
template <typename T>
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                      Value offset) {
  Location loc = op->getLoc();
  llvm::TypeSwitch<Operation *>(op.getOperation())
      // Allocations take no indices; only the result type changes.
      .template Case<memref::AllocOp>([&](auto oper) {
        auto newAlloc = rewriter.create<memref::AllocOp>(
            loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloc, loc, rewriter);
      })
      .template Case<memref::AllocaOp>([&](auto oper) {
        auto newAlloca = rewriter.create<memref::AllocaOp>(
            loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloca, loc, rewriter);
      })
      // Loads/stores: same op on the flat memref with the single linearized
      // index; discrete attributes are carried over via setAttrs.
      .template Case<memref::LoadOp>([&](auto op) {
        auto newLoad = rewriter.create<memref::LoadOp>(
            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<memref::StoreOp>([&](auto op) {
        // Operand 0 is the value being stored.
        auto newStore = rewriter.create<memref::StoreOp>(
            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::LoadOp>([&](auto op) {
        auto newLoad = rewriter.create<vector::LoadOp>(
            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<vector::StoreOp>([&](auto op) {
        auto newStore = rewriter.create<vector::StoreOp>(
            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::MaskedLoadOp>([&](auto op) {
        auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
            loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
            op.getPassThru());
        newMaskedLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedLoad.getResult());
      })
      .template Case<vector::MaskedStoreOp>([&](auto op) {
        auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
            loc, flatMemref, ValueRange{offset}, op.getMask(),
            op.getValueToStore());
        newMaskedStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedStore);
      })
      // canBeFlattened already guaranteed a (minor) identity permutation map
      // and all-in-bounds accesses for the transfer ops.
      // NOTE(review): unlike the cases above, attributes are not copied onto
      // the new transfer ops here — confirm that is intentional.
      .template Case<vector::TransferReadOp>([&](auto op) {
        auto newTransferRead = rewriter.create<vector::TransferReadOp>(
            loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
        rewriter.replaceOp(op, newTransferRead.getResult());
      })
      .template Case<vector::TransferWriteOp>([&](auto op) {
        auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
            loc, op.getVector(), flatMemref, ValueRange{offset});
        rewriter.replaceOp(op, newTransferWrite);
      })
      // Reaching here means getTargetMemref and this switch are out of sync.
      .Default([&](auto op) {
        op->emitOpError("unimplemented: do not know how to replace op.");
      });
}
196 
197 template <typename T>
198 static ValueRange getIndices(T op) {
199  if constexpr (std::is_same_v<T, memref::AllocaOp> ||
200  std::is_same_v<T, memref::AllocOp>) {
201  return ValueRange{};
202  } else {
203  return op.getIndices();
204  }
205 }
206 
207 template <typename T>
208 static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
209  return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
210  .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
211  [&](auto oper) {
212  // For vector.transfer_read/write, must make sure:
213  // 1. all accesses are inbound, and
214  // 2. has an identity or minor identity permutation map.
215  auto permutationMap = oper.getPermutationMap();
216  if (!permutationMap.isIdentity() &&
217  !permutationMap.isMinorIdentity()) {
218  return rewriter.notifyMatchFailure(
219  oper, "only identity permutation map is supported");
220  }
221  mlir::ArrayAttr inbounds = oper.getInBounds();
222  if (llvm::any_of(inbounds, [](Attribute attr) {
223  return !cast<BoolAttr>(attr).getValue();
224  })) {
225  return rewriter.notifyMatchFailure(oper,
226  "only inbounds are supported");
227  }
228  return success();
229  })
230  .Default([&](auto op) { return success(); });
231 }
232 
233 template <typename T>
234 struct MemRefRewritePattern : public OpRewritePattern<T> {
236  LogicalResult matchAndRewrite(T op,
237  PatternRewriter &rewriter) const override {
238  LogicalResult canFlatten = canBeFlattened(op, rewriter);
239  if (failed(canFlatten)) {
240  return canFlatten;
241  }
242 
243  Value memref = getTargetMemref(op);
244  if (!needFlattening(memref) || !checkLayout(memref))
245  return failure();
246  auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
247  rewriter, op->getLoc(), memref, getIndices<T>(op));
248  replaceOp<T>(op, rewriter, flatMemref, offset);
249  return success();
250  }
251 };
252 
253 struct FlattenMemrefsPass
254  : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
255  using Base::Base;
256 
257  void getDependentDialects(DialectRegistry &registry) const override {
258  registry.insert<affine::AffineDialect, arith::ArithDialect,
259  memref::MemRefDialect, vector::VectorDialect>();
260  }
261 
262  void runOnOperation() override {
264 
266 
267  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
268  return signalPassFailure();
269  }
270 };
271 
272 } // namespace
273 
275  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
276  MemRefRewritePattern<memref::StoreOp>,
277  MemRefRewritePattern<memref::AllocOp>,
278  MemRefRewritePattern<memref::AllocaOp>,
279  MemRefRewritePattern<vector::LoadOp>,
280  MemRefRewritePattern<vector::StoreOp>,
281  MemRefRewritePattern<vector::TransferReadOp>,
282  MemRefRewritePattern<vector::TransferWriteOp>,
283  MemRefRewritePattern<vector::MaskedLoadOp>,
284  MemRefRewritePattern<vector::MaskedStoreOp>>(
285  patterns.getContext());
286 }
static std::pair< Value, Value > getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source, ValueRange indices)
Returns a collapsed memref and the linearized index to access the element at the specified indices.
static bool checkLayout(Value val)
static bool needFlattening(Value val)
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in)
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:204
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Definition: MemRefUtils.cpp:52
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
Definition: MemRefUtils.h:50