30#include "llvm/ADT/STLExtras.h"
35#define GEN_PASS_DEF_FLATTENMEMREFSPASS
36#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
44 if (
Attribute offsetAttr = dyn_cast<Attribute>(in)) {
46 rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt());
48 return cast<Value>(in);
59 auto sourceType = cast<MemRefType>(source.
getType());
60 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
64 memref::ExtractStridedMetadataOp stridedMetadata =
65 memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
67 auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
70 std::tie(linearizedInfo, linearizedIndices) =
72 rewriter, loc, typeBit, typeBit,
73 stridedMetadata.getConstifiedMixedOffset(),
74 stridedMetadata.getConstifiedMixedSizes(),
75 stridedMetadata.getConstifiedMixedStrides(),
78 return std::make_pair(
79 memref::ReinterpretCastOp::create(
80 rewriter, loc, source,
90 auto type = cast<MemRefType>(val.
getType());
91 return type.getRank() > 1;
95 auto type = cast<MemRefType>(val.
getType());
96 return type.getLayout().isIdentity() ||
97 isa<StridedLayoutAttr>(type.getLayout());
102 auto type = cast<MemRefType>(
memref.getType());
103 return type.getElementType().isIntOrFloat();
108static FailureOr<MemRefType> getFlattenedMemRefType(MemRefType sourceType) {
111 if (
failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)))
115 for (
auto [size, stride] :
116 llvm::zip_equal(sourceType.getShape(), sourceStrides)) {
119 flatDimSize = flatDimSize.smax(dimSize);
120 if (flatDimSize.isSaturated())
124 if (sourceType.getLayout().isIdentity())
125 return MemRefType::get(
126 {flatDimSize.asInteger()}, sourceType.getElementType(),
127 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
129 return MemRefType::get(
130 {flatDimSize.asInteger()}, sourceType.getElementType(),
131 StridedLayoutAttr::get(sourceType.getContext(), sourceOffset, {1}),
132 sourceType.getMemorySpace());
143 if (!hasSupportedElementType(
memref))
150static LogicalResult hasUnitTrailingStride(
Operation *op,
153 if (!
memref.getType().areTrailingDimsContiguous(1))
155 op,
"cannot preserve non-unit trailing access stride");
161canLinearizeAccessedShape(memref::IndexedAccessOpInterface op,
165 if (accessedShape.empty())
167 if (accessedShape.size() > 1)
169 op,
"cannot preserve multi-dimensional accessed shape");
171 return hasUnitTrailingStride(op,
memref, rewriter);
174static LogicalResult canFlattenTransferOp(VectorTransferOpInterface op,
184 op,
"only identity or minor identity permutation map is supported");
186 if (op.hasOutOfBoundsDim())
189 if (op.getTransferRank() > 1)
191 op,
"cannot flatten multi-dimensional vector transfer");
193 if (op.getTransferRank() > 0 &&
213template <
typename AllocLikeOp>
214struct AllocLikeFlattenPattern final :
public OpRewritePattern<AllocLikeOp> {
215 using Base = OpRewritePattern<AllocLikeOp>;
218 LogicalResult matchAndRewrite(AllocLikeOp op,
219 PatternRewriter &rewriter)
const override {
223 Location loc = op->getLoc();
224 auto memrefType = cast<MemRefType>(op.getType());
225 auto elemType = memrefType.getElementType();
226 if (!elemType.isIntOrFloat())
228 unsigned elemBitWidth = elemType.getIntOrFloatBitWidth();
230 SmallVector<OpFoldResult> sizes = op.getMixedSizes();
232 int64_t staticOffset;
233 SmallVector<int64_t> staticStrides;
234 if (
failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
236 if (staticOffset == ShapedType::kDynamic)
238 SmallVector<OpFoldResult> strides;
239 strides.reserve(staticStrides.size());
240 for (int64_t stride : staticStrides) {
241 if (stride == ShapedType::kDynamic)
243 "dynamic stride cannot be computed");
249 memref::LinearizedMemRefInfo linearizedInfo;
250 OpFoldResult linearizedOffset;
251 std::tie(linearizedInfo, linearizedOffset) =
253 rewriter, loc, elemBitWidth, elemBitWidth, rewriter.
getIndexAttr(0),
255 (void)linearizedOffset;
262 if (staticOffset != 0) {
265 flatSizeOfr = affine::makeComposedFoldedAffineApply(
266 rewriter, loc, s0 + staticOffset, {flatSizeOfr});
271 int64_t flatDimSize = ShapedType::kDynamic;
272 if (
auto attr = dyn_cast<Attribute>(flatSizeOfr))
273 if (
auto intAttr = dyn_cast<IntegerAttr>(attr))
274 flatDimSize = intAttr.getInt();
276 auto flatMemrefType =
277 MemRefType::get({flatDimSize}, memrefType.getElementType(),
278 StridedLayoutAttr::get(rewriter.
getContext(), 0, {1}),
279 memrefType.getMemorySpace());
282 SmallVector<Value, 1> dynSizes;
283 if (flatDimSize == ShapedType::kDynamic)
286 auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes,
287 op.getAlignmentAttr());
289 op, cast<MemRefType>(op.getType()), newOp,
296struct IndexedAccessOpFlattenPattern final
300 LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
301 PatternRewriter &rewriter)
const override {
305 if (
failed(checkFlattenableMemref(op, memref, rewriter)))
307 if (
failed(canLinearizeAccessedShape(op, memref, rewriter)))
311 rewriter, op->getLoc(), memref, op.getIndices());
312 std::optional<SmallVector<Value>> replacementValues =
313 op.updateMemrefAndIndices(rewriter, flatMemref,
ValueRange{offset});
314 if (replacementValues)
315 rewriter.
replaceOp(op, *replacementValues);
323struct VectorTransferOpFlattenPattern final
327 LogicalResult matchAndRewrite(VectorTransferOpInterface op,
328 PatternRewriter &rewriter)
const override {
329 auto memref = dyn_cast<TypedValue<MemRefType>>(op.getBase());
332 if (
failed(checkFlattenableMemref(op, memref, rewriter)))
334 if (
failed(canFlattenTransferOp(op, memref, rewriter)))
337 FailureOr<MemRefType> flatMemrefType =
338 getFlattenedMemRefType(memref.getType());
339 if (
failed(flatMemrefType))
342 1, op.getTransferRank(), op.getContext());
344 op.mayUpdateStartingPosition(*flatMemrefType, newPermutationMap)))
346 "failed op-specific preconditions");
349 rewriter, op->getLoc(), memref, op.getIndices());
350 op.updateStartingPosition(rewriter, flatMemref,
ValueRange{offset},
351 AffineMapAttr::get(newPermutationMap));
358struct FlattenedMemrefAccess {
364struct IndexedMemCopyOpFlattenPattern final
368 LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
369 PatternRewriter &rewriter)
const override {
380 if (
failed(checkFlattenableMemref(op, memref, rewriter)))
383 auto [flatMemref, offset] =
385 return FlattenedMemrefAccess{flatMemref, offset};
388 std::optional<FlattenedMemrefAccess> newSrc =
389 tryFlatten(src, op.getSrcIndices());
390 std::optional<FlattenedMemrefAccess> newDst =
391 tryFlatten(dst, op.getDstIndices());
392 if (!newSrc && !newDst)
394 op,
"no source or destination memref needed flattening");
396 Value srcMemref = src;
399 srcMemref = newSrc->memref;
403 Value dstMemref = dst;
406 dstMemref = newDst->memref;
410 op.setMemrefsAndIndices(rewriter, srcMemref, srcIndices, dstMemref,
416struct FlattenMemrefsPass
417 :
public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
420 void getDependentDialects(DialectRegistry ®istry)
const override {
421 registry.
insert<affine::AffineDialect, arith::ArithDialect,
422 memref::MemRefDialect>();
425 void runOnOperation()
override {
431 return signalPassFailure();
438 patterns.
insert<IndexedAccessOpFlattenPattern, IndexedMemCopyOpFlattenPattern,
439 VectorTransferOpFlattenPattern,
440 AllocLikeFlattenPattern<memref::AllocOp>,
441 AllocLikeFlattenPattern<memref::AllocaOp>>(
static bool checkLayout(Value val)
static bool needFlattening(Value val)
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in)
static std::pair< Value, Value > getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source, ValueRange indices)
Returns a collapsed memref and the linearized index to access the element at the specified indices.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
static AffineMap getMinorIdentityMap(unsigned dims, unsigned results, MLIRContext *context)
Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e. an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions.
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
std::pair< LinearizedMemRefInfo, OpFoldResult > getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits, int dstBits, OpFoldResult offset, ArrayRef< OpFoldResult > sizes, ArrayRef< OpFoldResult > strides, ArrayRef< OpFoldResult > indices={}, LinearizedDivKind sizeDivKind=LinearizedDivKind::Floor)
void populateFlattenMemrefsPatterns(RewritePatternSet &patterns)
Patterns for flattening all supported multi-dimensional memref operations into one-dimensional memref...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
static SaturatedInteger wrap(int64_t v)
For a memref with offset, sizes and strides, returns the offset, size, and potentially the size padde...
OpFoldResult linearizedSize
OpFoldResult linearizedOffset