MLIR 23.0.0git
FlattenMemRefs.cpp
Go to the documentation of this file.
1//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains patterns for flattening multi-rank memref-related
10// ops into 1-d memref ops.
11//
12//===----------------------------------------------------------------------===//
13
24#include "mlir/IR/Attributes.h"
25#include "mlir/IR/Builders.h"
30#include "llvm/ADT/STLExtras.h"
31#include <algorithm>
32
33namespace mlir {
34namespace memref {
35#define GEN_PASS_DEF_FLATTENMEMREFSPASS
36#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
37} // namespace memref
38} // namespace mlir
39
40using namespace mlir;
41
43 OpFoldResult in) {
44 if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
46 rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
47 }
48 return cast<Value>(in);
49}
50
51/// Returns a collapsed memref and the linearized index to access the element
52/// at the specified indices.
53static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
54 Location loc,
55 Value source,
57 int64_t sourceOffset;
58 SmallVector<int64_t, 4> sourceStrides;
59 auto sourceType = cast<MemRefType>(source.getType());
60 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
61 assert(false);
62 }
63
64 memref::ExtractStridedMetadataOp stridedMetadata =
65 memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
66
67 auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
68 OpFoldResult linearizedIndices;
69 memref::LinearizedMemRefInfo linearizedInfo;
70 std::tie(linearizedInfo, linearizedIndices) =
72 rewriter, loc, typeBit, typeBit,
73 stridedMetadata.getConstifiedMixedOffset(),
74 stridedMetadata.getConstifiedMixedSizes(),
75 stridedMetadata.getConstifiedMixedStrides(),
77
78 return std::make_pair(
79 memref::ReinterpretCastOp::create(
80 rewriter, loc, source,
81 /*offset=*/linearizedInfo.linearizedOffset,
82 /*sizes=*/
83 ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
84 /*strides=*/
85 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
86 getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
87}
88
89static bool needFlattening(Value val) {
90 auto type = cast<MemRefType>(val.getType());
91 return type.getRank() > 1;
92}
93
94static bool checkLayout(Value val) {
95 auto type = cast<MemRefType>(val.getType());
96 return type.getLayout().isIdentity() ||
97 isa<StridedLayoutAttr>(type.getLayout());
98}
99
100namespace {
101static bool hasSupportedElementType(Value memref) {
102 auto type = cast<MemRefType>(memref.getType());
103 return type.getElementType().isIntOrFloat();
104}
105
106/// Compute the type that will be used to linearize the memref.
107/// Used so we don't create IR like `getLinearizedMemRefOffsetAndSize` would.
108static FailureOr<MemRefType> getFlattenedMemRefType(MemRefType sourceType) {
109 int64_t sourceOffset;
110 SmallVector<int64_t> sourceStrides;
111 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)))
112 return failure();
113
114 auto flatDimSize = SaturatedInteger::wrap(0);
115 for (auto [size, stride] :
116 llvm::zip_equal(sourceType.getShape(), sourceStrides)) {
117 auto dimSize =
119 flatDimSize = flatDimSize.smax(dimSize);
120 if (flatDimSize.isSaturated())
121 break;
122 }
123
124 if (sourceType.getLayout().isIdentity())
125 return MemRefType::get(
126 {flatDimSize.asInteger()}, sourceType.getElementType(),
127 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
128
129 return MemRefType::get(
130 {flatDimSize.asInteger()}, sourceType.getElementType(),
131 StridedLayoutAttr::get(sourceType.getContext(), sourceOffset, {1}),
132 sourceType.getMemorySpace());
133}
134
135/// Return whether `memref` has the basic properties needed for linearizing it
136/// into a 1-D reinterpret_cast.
137static LogicalResult checkFlattenableMemref(Operation *op, Value memref,
138 PatternRewriter &rewriter) {
140 return rewriter.notifyMatchFailure(op, "memref does not need flattening");
141 if (!checkLayout(memref))
142 return rewriter.notifyMatchFailure(op, "unsupported memref layout");
143 if (!hasSupportedElementType(memref))
144 return rewriter.notifyMatchFailure(op, "unsupported element type");
145 return success();
146}
147
148/// Wrapper around checking whether the last memref dimension is contiguous
149/// that provides a descriptive match-failure message.
150static LogicalResult hasUnitTrailingStride(Operation *op,
152 PatternRewriter &rewriter) {
153 if (!memref.getType().areTrailingDimsContiguous(1))
154 return rewriter.notifyMatchFailure(
155 op, "cannot preserve non-unit trailing access stride");
156
157 return success();
158}
159
160static LogicalResult
161canLinearizeAccessedShape(memref::IndexedAccessOpInterface op,
163 PatternRewriter &rewriter) {
164 SmallVector<int64_t> accessedShape = op.getAccessedShape();
165 if (accessedShape.empty())
166 return success();
167 if (accessedShape.size() > 1)
168 return rewriter.notifyMatchFailure(
169 op, "cannot preserve multi-dimensional accessed shape");
170
171 return hasUnitTrailingStride(op, memref, rewriter);
172}
173
174static LogicalResult canFlattenTransferOp(VectorTransferOpInterface op,
176 PatternRewriter &rewriter) {
177 // For vector.transfer_read/write, must make sure:
178 // 1. all accesses are inbounds,
179 // 2. has a minor identity permutation map, and
180 // 3. has at most one transfer dimension.
181 AffineMap permutationMap = op.getPermutationMap();
182 if (!permutationMap.isMinorIdentity())
183 return rewriter.notifyMatchFailure(
184 op, "only identity or minor identity permutation map is supported");
185
186 if (op.hasOutOfBoundsDim())
187 return rewriter.notifyMatchFailure(op, "only inbounds are supported");
188
189 if (op.getTransferRank() > 1)
190 return rewriter.notifyMatchFailure(
191 op, "cannot flatten multi-dimensional vector transfer");
192
193 if (op.getTransferRank() > 0 &&
194 failed(hasUnitTrailingStride(op, memref, rewriter)))
195 return failure();
196
197 return success();
198}
199
200// Pattern for memref::AllocOp and memref::AllocaOp.
201//
202// The "source" memref for these ops is the op's own result, so the generic
203// indexed access pattern cannot be used: getFlattenMemrefAndOffset would
204// insert ExtractStridedMetadataOp and ReinterpretCastOp that use op.result
205// before this op in the block. After replaceOpWithNewOp the original result is
206// RAUW'd to the new ReinterpretCastOp, leaving the earlier ops with forward
207// references (domination violations) caught by
208// MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.
209//
210// Instead, sizes and strides are computed from the op's operands and type
211// (which all dominate the op), avoiding any reference to op.result until the
212// final replaceOpWithNewOp.
213template <typename AllocLikeOp>
214struct AllocLikeFlattenPattern final : public OpRewritePattern<AllocLikeOp> {
215 using Base = OpRewritePattern<AllocLikeOp>;
216 using Base::Base;
217
218 LogicalResult matchAndRewrite(AllocLikeOp op,
219 PatternRewriter &rewriter) const override {
220 if (!needFlattening(op.getMemref()) || !checkLayout(op.getMemref()))
221 return failure();
222
223 Location loc = op->getLoc();
224 auto memrefType = cast<MemRefType>(op.getType());
225 auto elemType = memrefType.getElementType();
226 if (!elemType.isIntOrFloat())
227 return failure();
228 unsigned elemBitWidth = elemType.getIntOrFloatBitWidth();
229
230 SmallVector<OpFoldResult> sizes = op.getMixedSizes();
231
232 int64_t staticOffset;
233 SmallVector<int64_t> staticStrides;
234 if (failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
235 return failure();
236 if (staticOffset == ShapedType::kDynamic)
237 return rewriter.notifyMatchFailure(op, "dynamic offset not supported");
238 SmallVector<OpFoldResult> strides;
239 strides.reserve(staticStrides.size());
240 for (int64_t stride : staticStrides) {
241 if (stride == ShapedType::kDynamic)
242 return rewriter.notifyMatchFailure(op,
243 "dynamic stride cannot be computed");
244 strides.push_back(rewriter.getIndexAttr(stride));
245 }
246
247 // Compute the linearized flat extent from sizes and strides (no SSA ops
248 // referencing op.result are created here).
249 memref::LinearizedMemRefInfo linearizedInfo;
250 OpFoldResult linearizedOffset;
251 std::tie(linearizedInfo, linearizedOffset) =
253 rewriter, loc, elemBitWidth, elemBitWidth, rewriter.getIndexAttr(0),
254 sizes, strides);
255 (void)linearizedOffset;
256
257 // The total allocation must cover [0, staticOffset + linearizedExtent).
258 // When the offset is non-zero, add it to the computed extent so that the
259 // buffer is large enough for elements accessed at positions
260 // [staticOffset, staticOffset + linearizedExtent).
261 OpFoldResult flatSizeOfr = linearizedInfo.linearizedSize;
262 if (staticOffset != 0) {
263 AffineExpr s0;
264 bindSymbols(rewriter.getContext(), s0);
265 flatSizeOfr = affine::makeComposedFoldedAffineApply(
266 rewriter, loc, s0 + staticOffset, {flatSizeOfr});
267 }
268
269 // Build the flat 1-D MemRefType. The linearized size may be static or
270 // dynamic (OpFoldResult of either IntegerAttr or a Value).
271 int64_t flatDimSize = ShapedType::kDynamic;
272 if (auto attr = dyn_cast<Attribute>(flatSizeOfr))
273 if (auto intAttr = dyn_cast<IntegerAttr>(attr))
274 flatDimSize = intAttr.getInt();
275
276 auto flatMemrefType =
277 MemRefType::get({flatDimSize}, memrefType.getElementType(),
278 StridedLayoutAttr::get(rewriter.getContext(), 0, {1}),
279 memrefType.getMemorySpace());
280
281 // Collect the flat dynamic-size operand (empty for fully-static case).
282 SmallVector<Value, 1> dynSizes;
283 if (flatDimSize == ShapedType::kDynamic)
284 dynSizes.push_back(getValueFromOpFoldResult(rewriter, loc, flatSizeOfr));
285
286 auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes,
287 op.getAlignmentAttr());
288 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
289 op, cast<MemRefType>(op.getType()), newOp,
290 rewriter.getIndexAttr(staticOffset), sizes, strides);
291 return success();
292 }
293};
294
295/// Pattern that flattens any IndexedAccessOpInterface op.
296struct IndexedAccessOpFlattenPattern final
297 : public OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
298 using Base::Base;
299
300 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
301 PatternRewriter &rewriter) const override {
302 TypedValue<MemRefType> memref = op.getAccessedMemref();
303 if (!memref)
304 return rewriter.notifyMatchFailure(op, "not accessing a memref");
305 if (failed(checkFlattenableMemref(op, memref, rewriter)))
306 return failure();
307 if (failed(canLinearizeAccessedShape(op, memref, rewriter)))
308 return failure();
309
310 auto [flatMemref, offset] = getFlattenMemrefAndOffset(
311 rewriter, op->getLoc(), memref, op.getIndices());
312 std::optional<SmallVector<Value>> replacementValues =
313 op.updateMemrefAndIndices(rewriter, flatMemref, ValueRange{offset});
314 if (replacementValues)
315 rewriter.replaceOp(op, *replacementValues);
316 return success();
317 }
318};
319
320/// Flatten operations that use VectorTransferOpInterface. Transfer ops have
321/// permutation-map and in_bounds semantics that are separate from
322/// IndexedAccessOpInterface, so use updateStartingPosition directly.
323struct VectorTransferOpFlattenPattern final
324 : public OpInterfaceRewritePattern<VectorTransferOpInterface> {
325 using Base::Base;
326
327 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
328 PatternRewriter &rewriter) const override {
329 auto memref = dyn_cast<TypedValue<MemRefType>>(op.getBase());
330 if (!memref)
331 return rewriter.notifyMatchFailure(op, "not accessing a memref");
332 if (failed(checkFlattenableMemref(op, memref, rewriter)))
333 return failure();
334 if (failed(canFlattenTransferOp(op, memref, rewriter)))
335 return failure();
336
337 FailureOr<MemRefType> flatMemrefType =
338 getFlattenedMemRefType(memref.getType());
339 if (failed(flatMemrefType))
340 return failure();
341 AffineMap newPermutationMap = AffineMap::getMinorIdentityMap(
342 /*dims=*/1, op.getTransferRank(), op.getContext());
343 if (failed(
344 op.mayUpdateStartingPosition(*flatMemrefType, newPermutationMap)))
345 return rewriter.notifyMatchFailure(op,
346 "failed op-specific preconditions");
347
348 auto [flatMemref, offset] = getFlattenMemrefAndOffset(
349 rewriter, op->getLoc(), memref, op.getIndices());
350 op.updateStartingPosition(rewriter, flatMemref, ValueRange{offset},
351 AffineMapAttr::get(newPermutationMap));
352 return success();
353 }
354};
355
356/// Flatten the source and destination memref/index pairs of indexed memcpy-like
357/// operations such as memref.dma_start.
358struct FlattenedMemrefAccess {
359 Value memref;
360 Value index;
361};
362
363/// Flatten all IndexedMemCopyOpInterface operations.
364struct IndexedMemCopyOpFlattenPattern final
365 : public OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
366 using Base::Base;
367
368 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
369 PatternRewriter &rewriter) const override {
370 TypedValue<MemRefType> src = op.getSrc();
371 TypedValue<MemRefType> dst = op.getDst();
372 if (!src && !dst)
373 return rewriter.notifyMatchFailure(op, "not copying between memrefs");
374
375 auto tryFlatten =
376 [&](TypedValue<MemRefType> memref,
377 ValueRange indices) -> std::optional<FlattenedMemrefAccess> {
378 if (!memref || !needFlattening(memref))
379 return std::nullopt;
380 if (failed(checkFlattenableMemref(op, memref, rewriter)))
381 return std::nullopt;
382
383 auto [flatMemref, offset] =
384 getFlattenMemrefAndOffset(rewriter, op->getLoc(), memref, indices);
385 return FlattenedMemrefAccess{flatMemref, offset};
386 };
387
388 std::optional<FlattenedMemrefAccess> newSrc =
389 tryFlatten(src, op.getSrcIndices());
390 std::optional<FlattenedMemrefAccess> newDst =
391 tryFlatten(dst, op.getDstIndices());
392 if (!newSrc && !newDst)
393 return rewriter.notifyMatchFailure(
394 op, "no source or destination memref needed flattening");
395
396 Value srcMemref = src;
397 ValueRange srcIndices = op.getSrcIndices();
398 if (newSrc) {
399 srcMemref = newSrc->memref;
400 srcIndices = ValueRange(newSrc->index);
401 }
402
403 Value dstMemref = dst;
404 ValueRange dstIndices = op.getDstIndices();
405 if (newDst) {
406 dstMemref = newDst->memref;
407 dstIndices = ValueRange(newDst->index);
408 }
409
410 op.setMemrefsAndIndices(rewriter, srcMemref, srcIndices, dstMemref,
411 dstIndices);
412 return success();
413 }
414};
415
416struct FlattenMemrefsPass
417 : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
418 using Base::Base;
419
420 void getDependentDialects(DialectRegistry &registry) const override {
421 registry.insert<affine::AffineDialect, arith::ArithDialect,
422 memref::MemRefDialect>();
423 }
424
425 void runOnOperation() override {
426 RewritePatternSet patterns(&getContext());
427
429
430 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
431 return signalPassFailure();
432 }
433};
434
435} // namespace
436
438 patterns.insert<IndexedAccessOpFlattenPattern, IndexedMemCopyOpFlattenPattern,
439 VectorTransferOpFlattenPattern,
440 AllocLikeFlattenPattern<memref::AllocOp>,
441 AllocLikeFlattenPattern<memref::AllocaOp>>(
442 patterns.getContext());
443}
return success()
static bool checkLayout(Value val)
static bool needFlattening(Value val)
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in)
static std::pair< Value, Value > getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source, ValueRange indices)
Returns a collapsed memref and the linearized index to access the element at the specified indices.
b getContext())
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={}, LinearizedDivKind sizeDivKind=LinearizedDivKind::Floor)
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
Patterns for flattening all supported multi-dimensional memref operations into one-dimensional memref...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition AffineExpr.h:325
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
static SaturatedInteger wrap(int64_t v)
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
Definition MemRefUtils.h:64