MLIR 23.0.0git
FlattenMemRefs.cpp
Go to the documentation of this file.
1//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns for flattening multi-rank memref-related
10// ops into 1-d memref ops.
11//
12//===----------------------------------------------------------------------===//
13
23#include "mlir/IR/Attributes.h"
24#include "mlir/IR/Builders.h"
29#include "llvm/ADT/TypeSwitch.h"
30
31namespace mlir {
32namespace memref {
33#define GEN_PASS_DEF_FLATTENMEMREFSPASS
34#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
35} // namespace memref
36} // namespace mlir
37
38using namespace mlir;
39
41 OpFoldResult in) {
42 if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
44 rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
45 }
46 return cast<Value>(in);
47}
48
49/// Returns a collapsed memref and the linearized index to access the element
50/// at the specified indices.
51static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
52 Location loc,
53 Value source,
55 int64_t sourceOffset;
56 SmallVector<int64_t, 4> sourceStrides;
57 auto sourceType = cast<MemRefType>(source.getType());
58 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
59 assert(false);
60 }
61
62 memref::ExtractStridedMetadataOp stridedMetadata =
63 memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
64
65 auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
66 OpFoldResult linearizedIndices;
67 memref::LinearizedMemRefInfo linearizedInfo;
68 std::tie(linearizedInfo, linearizedIndices) =
70 rewriter, loc, typeBit, typeBit,
71 stridedMetadata.getConstifiedMixedOffset(),
72 stridedMetadata.getConstifiedMixedSizes(),
73 stridedMetadata.getConstifiedMixedStrides(),
75
76 return std::make_pair(
77 memref::ReinterpretCastOp::create(
78 rewriter, loc, source,
79 /* offset = */ linearizedInfo.linearizedOffset,
80 /* shapes = */
81 ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
82 /* strides = */
83 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
84 getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
85}
86
87static bool needFlattening(Value val) {
88 auto type = cast<MemRefType>(val.getType());
89 return type.getRank() > 1;
90}
91
92static bool checkLayout(Value val) {
93 auto type = cast<MemRefType>(val.getType());
94 return type.getLayout().isIdentity() ||
95 isa<StridedLayoutAttr>(type.getLayout());
96}
97
98namespace {
99static Value getTargetMemref(Operation *op) {
101 .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
102 memref::AllocOp>([](auto op) { return op.getMemref(); })
103 .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
104 vector::MaskedStoreOp, vector::TransferReadOp,
105 vector::TransferWriteOp>(
106 [](auto op) { return op.getBase(); })
107 .Default(nullptr);
108}
109
110template <typename T>
111static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
112 Value offset) {
113 Location loc = op->getLoc();
114 llvm::TypeSwitch<Operation *>(op.getOperation())
115 .Case([&](memref::LoadOp op) {
116 auto newLoad =
117 memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
118 flatMemref, ValueRange{offset});
119 newLoad->setAttrs(op->getAttrs());
120 rewriter.replaceOp(op, newLoad.getResult());
121 })
122 .Case([&](memref::StoreOp op) {
123 auto newStore =
124 memref::StoreOp::create(rewriter, loc, op->getOperands().front(),
125 flatMemref, ValueRange{offset});
126 newStore->setAttrs(op->getAttrs());
127 rewriter.replaceOp(op, newStore);
128 })
129 .Case([&](vector::LoadOp op) {
130 auto newLoad =
131 vector::LoadOp::create(rewriter, loc, op->getResultTypes(),
132 flatMemref, ValueRange{offset});
133 newLoad->setAttrs(op->getAttrs());
134 rewriter.replaceOp(op, newLoad.getResult());
135 })
136 .Case([&](vector::StoreOp op) {
137 auto newStore =
138 vector::StoreOp::create(rewriter, loc, op->getOperands().front(),
139 flatMemref, ValueRange{offset});
140 newStore->setAttrs(op->getAttrs());
141 rewriter.replaceOp(op, newStore);
142 })
143 .Case([&](vector::MaskedLoadOp op) {
144 auto newMaskedLoad = vector::MaskedLoadOp::create(
145 rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
146 op.getMask(), op.getPassThru());
147 newMaskedLoad->setAttrs(op->getAttrs());
148 rewriter.replaceOp(op, newMaskedLoad.getResult());
149 })
150 .Case([&](vector::MaskedStoreOp op) {
151 auto newMaskedStore = vector::MaskedStoreOp::create(
152 rewriter, loc, flatMemref, ValueRange{offset}, op.getMask(),
153 op.getValueToStore());
154 newMaskedStore->setAttrs(op->getAttrs());
155 rewriter.replaceOp(op, newMaskedStore);
156 })
157 .Case([&](vector::TransferReadOp op) {
158 auto newTransferRead = vector::TransferReadOp::create(
159 rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
160 op.getPadding());
161 rewriter.replaceOp(op, newTransferRead.getResult());
162 })
163 .Case([&](vector::TransferWriteOp op) {
164 auto newTransferWrite = vector::TransferWriteOp::create(
165 rewriter, loc, op.getVector(), flatMemref, ValueRange{offset});
166 rewriter.replaceOp(op, newTransferWrite);
167 })
168 .Default([&](auto op) {
169 op->emitOpError("unimplemented: do not know how to replace op.");
170 });
171}
172
173template <typename T>
174static ValueRange getIndices(T op) {
175 return op.getIndices();
176}
177
178template <typename T>
179static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
180 return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
181 .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
182 [&](auto oper) {
183 // For vector.transfer_read/write, must make sure:
184 // 1. all accesses are inbound, and
185 // 2. has an identity or minor identity permutation map.
186 auto permutationMap = oper.getPermutationMap();
187 if (!permutationMap.isIdentity() &&
188 !permutationMap.isMinorIdentity()) {
189 return rewriter.notifyMatchFailure(
190 oper, "only identity permutation map is supported");
191 }
192 mlir::ArrayAttr inbounds = oper.getInBounds();
193 if (llvm::any_of(inbounds, [](Attribute attr) {
194 return !cast<BoolAttr>(attr).getValue();
195 })) {
196 return rewriter.notifyMatchFailure(oper,
197 "only inbounds are supported");
198 }
199 return success();
200 })
201 .Default([&](auto op) { return success(); });
202}
203
204// Pattern for memref::AllocOp and memref::AllocaOp.
205//
206// The "source" memref for these ops IS the op's own result, so the generic
207// MemRefRewritePattern cannot be used: getFlattenMemrefAndOffset would insert
208// ExtractStridedMetadataOp and ReinterpretCastOp that use op.result BEFORE op
209// in the block. After replaceOpWithNewOp the original result is RAUW'd to the
210// new ReinterpretCastOp, leaving the earlier ops with forward references
211// (domination violations) caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.
212//
213// Instead, sizes and strides are computed from the op's operands and type
214// (which all dominate the op), avoiding any reference to op.result until the
215// final replaceOpWithNewOp.
216template <typename AllocLikeOp>
217struct AllocLikeFlattenPattern : public OpRewritePattern<AllocLikeOp> {
218 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
219 LogicalResult matchAndRewrite(AllocLikeOp op,
220 PatternRewriter &rewriter) const override {
221 if (!needFlattening(op.getMemref()) || !checkLayout(op.getMemref()))
222 return failure();
223
224 Location loc = op->getLoc();
225 auto memrefType = cast<MemRefType>(op.getType());
226 auto elemType = memrefType.getElementType();
227 if (!elemType.isIntOrFloat())
228 return failure();
229 unsigned elemBitWidth = elemType.getIntOrFloatBitWidth();
230
231 SmallVector<OpFoldResult> sizes = op.getMixedSizes();
232
233 int64_t staticOffset;
234 SmallVector<int64_t> staticStrides;
235 if (failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
236 return failure();
237 if (staticOffset == ShapedType::kDynamic)
238 return rewriter.notifyMatchFailure(op, "dynamic offset not supported");
239 SmallVector<OpFoldResult> strides;
240 strides.reserve(staticStrides.size());
241 for (int64_t stride : staticStrides) {
242 if (stride == ShapedType::kDynamic)
243 return rewriter.notifyMatchFailure(op,
244 "dynamic stride cannot be computed");
245 strides.push_back(rewriter.getIndexAttr(stride));
246 }
247
248 // Compute the linearized flat extent from sizes and strides (no SSA ops
249 // referencing op.result are created here).
250 memref::LinearizedMemRefInfo linearizedInfo;
251 OpFoldResult linearizedOffset;
252 std::tie(linearizedInfo, linearizedOffset) =
254 rewriter, loc, elemBitWidth, elemBitWidth, rewriter.getIndexAttr(0),
255 sizes, strides);
256 (void)linearizedOffset;
257
258 // The total allocation must cover [0, staticOffset + linearizedExtent).
259 // When the offset is non-zero, add it to the computed extent so that the
260 // buffer is large enough for elements accessed at positions
261 // [staticOffset, staticOffset + linearizedExtent).
262 OpFoldResult flatSizeOfr = linearizedInfo.linearizedSize;
263 if (staticOffset != 0) {
264 AffineExpr s0;
265 bindSymbols(rewriter.getContext(), s0);
267 rewriter, loc, s0 + staticOffset, {flatSizeOfr});
268 }
269
270 // Build the flat 1-D MemRefType. The linearized size may be static or
271 // dynamic (OpFoldResult of either IntegerAttr or a Value).
272 int64_t flatDimSize = ShapedType::kDynamic;
273 if (auto attr = dyn_cast<Attribute>(flatSizeOfr))
274 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
275 flatDimSize = intAttr.getInt();
276
277 auto flatMemrefType =
278 MemRefType::get({flatDimSize}, memrefType.getElementType(),
279 StridedLayoutAttr::get(rewriter.getContext(), 0, {1}),
280 memrefType.getMemorySpace());
281
282 // Collect the flat dynamic-size operand (empty for fully-static case).
283 SmallVector<Value, 1> dynSizes;
284 if (flatDimSize == ShapedType::kDynamic)
285 dynSizes.push_back(getValueFromOpFoldResult(rewriter, loc, flatSizeOfr));
286
287 auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes,
288 op.getAlignmentAttr());
289 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
290 op, cast<MemRefType>(op.getType()), newOp,
291 rewriter.getIndexAttr(staticOffset), sizes, strides);
292 return success();
293 }
294};
295
296template <typename T>
297struct MemRefRewritePattern : public OpRewritePattern<T> {
298 using OpRewritePattern<T>::OpRewritePattern;
299 LogicalResult matchAndRewrite(T op,
300 PatternRewriter &rewriter) const override {
301 LogicalResult canFlatten = canBeFlattened(op, rewriter);
302 if (failed(canFlatten))
303 return canFlatten;
304
305 Value memref = getTargetMemref(op);
306 if (!needFlattening(memref) || !checkLayout(memref))
307 return failure();
308
309 auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
310 rewriter, op->getLoc(), memref, getIndices<T>(op));
311 replaceOp<T>(op, rewriter, flatMemref, offset);
312 return success();
313 }
314};
315
316struct FlattenMemrefsPass
317 : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
318 using Base::Base;
319
320 void getDependentDialects(DialectRegistry &registry) const override {
321 registry.insert<affine::AffineDialect, arith::ArithDialect,
322 memref::MemRefDialect, vector::VectorDialect>();
323 }
324
325 void runOnOperation() override {
326 RewritePatternSet patterns(&getContext());
327
329
330 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
331 return signalPassFailure();
332 }
333};
334
335} // namespace
336
338 RewritePatternSet &patterns) {
339 patterns.insert<MemRefRewritePattern<vector::LoadOp>,
340 MemRefRewritePattern<vector::StoreOp>,
341 MemRefRewritePattern<vector::TransferReadOp>,
342 MemRefRewritePattern<vector::TransferWriteOp>,
343 MemRefRewritePattern<vector::MaskedLoadOp>,
344 MemRefRewritePattern<vector::MaskedStoreOp>>(
345 patterns.getContext());
346}
347
349 patterns.insert<MemRefRewritePattern<memref::LoadOp>,
350 MemRefRewritePattern<memref::StoreOp>,
351 AllocLikeFlattenPattern<memref::AllocOp>,
352 AllocLikeFlattenPattern<memref::AllocaOp>>(
353 patterns.getContext());
354}
355
return success()
static bool checkLayout(Value val)
static bool needFlattening(Value val)
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in)
static std::pair< Value, Value > getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source, ValueRange indices)
Returns a collapsed memref and the linearized index to access the element at the specified indices.
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={})
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
void populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns)
void populateFlattenVectorOpsOnMemrefPatterns(RewritePatternSet &patterns)
Patterns for flattening multi-dimensional memref operations into one-dimensional memref operations.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
Definition MemRefUtils.h:50