MLIR  22.0.0git
FlattenMemRefs.cpp
Go to the documentation of this file.
1 //===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains patterns for flattening multi-rank memref-related
10 // ops into 1-d memref ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
30 
31 namespace mlir {
32 namespace memref {
33 #define GEN_PASS_DEF_FLATTENMEMREFSPASS
34 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
35 } // namespace memref
36 } // namespace mlir
37 
38 using namespace mlir;
39 
41  OpFoldResult in) {
42  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
44  rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
45  }
46  return cast<Value>(in);
47 }
48 
49 /// Returns a collapsed memref and the linearized index to access the element
50 /// at the specified indices.
51 static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
52  Location loc,
53  Value source,
54  ValueRange indices) {
55  int64_t sourceOffset;
56  SmallVector<int64_t, 4> sourceStrides;
57  auto sourceType = cast<MemRefType>(source.getType());
58  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
59  assert(false);
60  }
61 
62  memref::ExtractStridedMetadataOp stridedMetadata =
63  memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
64 
65  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
66  OpFoldResult linearizedIndices;
67  memref::LinearizedMemRefInfo linearizedInfo;
68  std::tie(linearizedInfo, linearizedIndices) =
70  rewriter, loc, typeBit, typeBit,
71  stridedMetadata.getConstifiedMixedOffset(),
72  stridedMetadata.getConstifiedMixedSizes(),
73  stridedMetadata.getConstifiedMixedStrides(),
74  getAsOpFoldResult(indices));
75 
76  return std::make_pair(
77  memref::ReinterpretCastOp::create(
78  rewriter, loc, source,
79  /* offset = */ linearizedInfo.linearizedOffset,
80  /* shapes = */
81  ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
82  /* strides = */
83  ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
84  getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
85 }
86 
87 static bool needFlattening(Value val) {
88  auto type = cast<MemRefType>(val.getType());
89  return type.getRank() > 1;
90 }
91 
92 static bool checkLayout(Value val) {
93  auto type = cast<MemRefType>(val.getType());
94  return type.getLayout().isIdentity() ||
95  isa<StridedLayoutAttr>(type.getLayout());
96 }
97 
98 namespace {
99 static Value getTargetMemref(Operation *op) {
101  .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
102  memref::AllocOp>([](auto op) { return op.getMemref(); })
103  .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
104  vector::MaskedStoreOp, vector::TransferReadOp,
105  vector::TransferWriteOp>(
106  [](auto op) { return op.getBase(); })
107  .Default([](auto) { return Value{}; });
108 }
109 
110 template <typename T>
111 static void castAllocResult(T oper, T newOper, Location loc,
112  PatternRewriter &rewriter) {
113  memref::ExtractStridedMetadataOp stridedMetadata =
114  memref::ExtractStridedMetadataOp::create(rewriter, loc, oper);
115  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
116  oper, cast<MemRefType>(oper.getType()), newOper,
117  /*offset=*/rewriter.getIndexAttr(0),
118  stridedMetadata.getConstifiedMixedSizes(),
119  stridedMetadata.getConstifiedMixedStrides());
120 }
121 
/// Rewrites `op` to an equivalent op operating on the 1-d `flatMemref` at the
/// linearized `offset`. Alloc-like ops are replaced via castAllocResult so
/// their users keep the original multi-rank type; access ops are recreated
/// with the flat memref and single index. Ops not handled below get an
/// emitOpError (no replacement is performed for them).
template <typename T>
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                      Value offset) {
  Location loc = op->getLoc();
  llvm::TypeSwitch<Operation *>(op.getOperation())
      .template Case<memref::AllocOp>([&](auto oper) {
        // Allocate with the flattened type, preserving alignment.
        auto newAlloc = memref::AllocOp::create(
            rewriter, loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloc, loc, rewriter);
      })
      .template Case<memref::AllocaOp>([&](auto oper) {
        auto newAlloca = memref::AllocaOp::create(
            rewriter, loc, cast<MemRefType>(flatMemref.getType()),
            oper.getAlignmentAttr());
        castAllocResult(oper, newAlloca, loc, rewriter);
      })
      .template Case<memref::LoadOp>([&](auto op) {
        // Multi-index load becomes a single-index load on the flat memref.
        auto newLoad =
            memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
                                   flatMemref, ValueRange{offset});
        // Carry over discardable attributes (e.g. nontemporal hints).
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<memref::StoreOp>([&](auto op) {
        // First operand of memref.store is the value being stored.
        auto newStore =
            memref::StoreOp::create(rewriter, loc, op->getOperands().front(),
                                    flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::LoadOp>([&](auto op) {
        auto newLoad =
            vector::LoadOp::create(rewriter, loc, op->getResultTypes(),
                                   flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .template Case<vector::StoreOp>([&](auto op) {
        auto newStore =
            vector::StoreOp::create(rewriter, loc, op->getOperands().front(),
                                    flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .template Case<vector::MaskedLoadOp>([&](auto op) {
        // Mask and pass-through vector are reused unchanged.
        auto newMaskedLoad = vector::MaskedLoadOp::create(
            rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
            op.getMask(), op.getPassThru());
        newMaskedLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedLoad.getResult());
      })
      .template Case<vector::MaskedStoreOp>([&](auto op) {
        auto newMaskedStore = vector::MaskedStoreOp::create(
            rewriter, loc, flatMemref, ValueRange{offset}, op.getMask(),
            op.getValueToStore());
        newMaskedStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedStore);
      })
      .template Case<vector::TransferReadOp>([&](auto op) {
        // canBeFlattened() guarantees an (minor) identity permutation map and
        // in-bounds access, so only padding needs to be carried over.
        auto newTransferRead = vector::TransferReadOp::create(
            rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
            op.getPadding());
        rewriter.replaceOp(op, newTransferRead.getResult());
      })
      .template Case<vector::TransferWriteOp>([&](auto op) {
        auto newTransferWrite = vector::TransferWriteOp::create(
            rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
        rewriter.replaceOp(op, newTransferWrite);
      })
      .Default([&](auto op) {
        op->emitOpError("unimplemented: do not know how to replace op.");
      });
}
196 
197 template <typename T>
198 static ValueRange getIndices(T op) {
199  if constexpr (std::is_same_v<T, memref::AllocaOp> ||
200  std::is_same_v<T, memref::AllocOp>) {
201  return ValueRange{};
202  } else {
203  return op.getIndices();
204  }
205 }
206 
207 template <typename T>
208 static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
209  return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
210  .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
211  [&](auto oper) {
212  // For vector.transfer_read/write, must make sure:
213  // 1. all accesses are inbound, and
214  // 2. has an identity or minor identity permutation map.
215  auto permutationMap = oper.getPermutationMap();
216  if (!permutationMap.isIdentity() &&
217  !permutationMap.isMinorIdentity()) {
218  return rewriter.notifyMatchFailure(
219  oper, "only identity permutation map is supported");
220  }
221  mlir::ArrayAttr inbounds = oper.getInBounds();
222  if (llvm::any_of(inbounds, [](Attribute attr) {
223  return !cast<BoolAttr>(attr).getValue();
224  })) {
225  return rewriter.notifyMatchFailure(oper,
226  "only inbounds are supported");
227  }
228  return success();
229  })
230  .Default([&](auto op) { return success(); });
231 }
232 
233 template <typename T>
234 struct MemRefRewritePattern : public OpRewritePattern<T> {
236  LogicalResult matchAndRewrite(T op,
237  PatternRewriter &rewriter) const override {
238  LogicalResult canFlatten = canBeFlattened(op, rewriter);
239  if (failed(canFlatten)) {
240  return canFlatten;
241  }
242 
243  Value memref = getTargetMemref(op);
244  if (!needFlattening(memref) || !checkLayout(memref))
245  return failure();
246  auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
247  rewriter, op->getLoc(), memref, getIndices<T>(op));
248  replaceOp<T>(op, rewriter, flatMemref, offset);
249  return success();
250  }
251 };
252 
253 struct FlattenMemrefsPass
254  : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
255  using Base::Base;
256 
257  void getDependentDialects(DialectRegistry &registry) const override {
258  registry.insert<affine::AffineDialect, arith::ArithDialect,
259  memref::MemRefDialect, vector::VectorDialect>();
260  }
261 
262  void runOnOperation() override {
264 
266 
267  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
268  return signalPassFailure();
269  }
270 };
271 
272 } // namespace
273 
275  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
276  MemRefRewritePattern<memref::StoreOp>,
277  MemRefRewritePattern<memref::AllocOp>,
278  MemRefRewritePattern<memref::AllocaOp>,
279  MemRefRewritePattern<vector::LoadOp>,
280  MemRefRewritePattern<vector::StoreOp>,
281  MemRefRewritePattern<vector::TransferReadOp>,
282  MemRefRewritePattern<vector::TransferWriteOp>,
283  MemRefRewritePattern<vector::MaskedLoadOp>,
284  MemRefRewritePattern<vector::MaskedStoreOp>>(
285  patterns.getContext());
286 }
static std::pair< Value, Value > getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source, ValueRange indices)
Returns a collapsed memref and the linearized index to access the element at the specified indices.
static bool checkLayout(Value val)
static bool needFlattening(Value val)
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in)
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
Definition: MemRefUtils.cpp:51
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
Definition: MemRefUtils.h:50