MLIR 23.0.0git
ElideReinterpretCast.cpp
Go to the documentation of this file.
1//===-ElideReinterpretCast.cpp - Expansion patterns for MemRef operations-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "llvm/ADT/Repeated.h"
17#include <cassert>
18
19namespace mlir {
20namespace memref {
21#define GEN_PASS_DEF_ELIDEREINTERPRETCASTPASS
22#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
23} // namespace memref
24} // namespace mlir
25
26using namespace mlir;
27
28namespace {
29
30/// Returns true if `rc` represents a scalar view (all sizes == 1)
31/// into a memref that has exactly one non-unit dimension located at
32/// either the first or last position (i.e. a "row" or "column").
33///
34/// Examples that return true:
35///
36/// // Row-major slice (last dim is non-unit)
37/// memref.reinterpret_cast %buff to offset: [%off],
38/// sizes: [1, 1, 1], strides: [1, 1, 1]
39/// : memref<1x1x8xi32> to memref<1x1x1xi32>
40///
41/// // Column-major slice (first dim is non-unit)
42/// memref.reinterpret_cast %buff to offset: [%off],
43/// sizes: [1, 1], strides: [1, 1]
44/// : memref<2x1xf32> to memref<1x1xf32>
45///
46/// // Random strides
47/// memref.reinterpret_cast %buff to offset: [%off],
48/// sizes: [1, 1], strides: [10, 100]
49/// : memref<2x1xf32, strided<[10, 100]>>
50/// to memref<1x1xf32>
51///
52/// // Rank-1 case
53/// memref.reinterpret_cast %buf to offset: [%off],
54/// sizes: [1], strides: [1]
55/// : memref<8xi32> to memref<1xi32>
56///
57/// Examples that return false:
58///
59/// // More non-unit dims
60/// memref.reinterpret_cast %buff to offset: [%off],
61/// sizes: [1, 1, 1], strides: [1, 1, 1]
62/// : memref<1x2x8xi32> to memref<1x1x1xi32>
63///
64/// // View is not scalar (size != 1)
65/// memref.reinterpret_cast %buff to offset: [%off],
66/// sizes: [2, 1], strides: [1, 1]
67/// : memref<1x2xf32> to memref<2x1xf32>
68///
69/// // Base has non-identity layout
70/// %buff = memref.alloc() : memref<1x2xf32, strided<[1, 3]>>
71/// memref.reinterpret_cast %buff to offset: [%off],
72/// sizes: [1, 1], strides: [1, 1]
73/// : memref<1x2xf32, strided<[1, 3]>> to memref<1x1xf32>
74static bool isScalarSlice(memref::ReinterpretCastOp rc) {
75 auto rcInputTy = dyn_cast<MemRefType>(rc.getSource().getType());
76 auto rcOutputTy = dyn_cast<MemRefType>(rc.getType());
77
78 // Reject strided base - logic for computing linear idx is TODO
79 if (!rcInputTy.getLayout().isIdentity())
80 return false;
81
82 // Reject non-matching ranks
83 unsigned srcRank = rcInputTy.getRank();
84 if (srcRank != rcOutputTy.getRank())
85 return false;
86
87 ArrayRef<int64_t> sizes = rc.getStaticSizes();
88
89 // View must be scalar: memref<1x...x1>
90 if (!llvm::all_of(rcOutputTy.getShape(),
91 [](int64_t dim) { return dim == 1; }))
92 return false;
93
94 // Sizes must all be statically 1
95 if (!llvm::all_of(sizes, [](int64_t size) {
96 return !ShapedType::isDynamic(size) && size == 1;
97 }))
98 return false;
99
100 // Rank-1 special case
101 if (srcRank == 1) {
102 // Reject non-scalar output
103 if (rcOutputTy.getDimSize(0) > 1)
104 return false;
105 }
106
107 int nonUnitCount =
108 std::count_if(rcInputTy.getShape().begin(), rcInputTy.getShape().end(),
109 [](int dim) { return dim != 1; });
110 return nonUnitCount == 1;
111}
112
/// Rewrites `memref.copy` of a 1-element MemRef as a scalar load-store pair
///
/// The pattern matches a reinterpret_cast that creates a scalar view
/// (`sizes = [1, ..., 1]`) into a memref with a single non-unit dimension.
/// Since the view contains only one element, the accessed address is
/// determined solely by the base pointer and the offset.
///
/// Two layouts are supported:
///   * row-major slice (stride pattern [N, ..., 1])
///   * column-major slice (stride pattern [1, ..., N])
///
/// BEFORE (row-major slice)
///   %view = memref.reinterpret_cast %base
///       to offset: [%off], sizes: [1, ..., 1], strides: [N, ..., 1]
///       : memref<1x...xNxf32>
///       to memref<1x...x1xf32, strided<[N, ..., 1], offset: ?>>
///   memref.copy %src, %view
///       : memref<1x...x1xf32>
///       to memref<1x...x1xf32, strided<[N, ..., 1], offset: ?>>
///
/// AFTER
///   %c0 = arith.constant 0 : index
///   %v = memref.load %src[%c0, ..., %c0] : memref<1x...x1xf32>
///   memref.store %v, %base[%c0, ..., %off] : memref<1x...xNxf32>
///
/// BEFORE (column-major slice)
///   %view = memref.reinterpret_cast %base
///       to offset: [%off], sizes: [1, ..., 1], strides: [1, ..., N]
///       : memref<Nx...x1xf32>
///       to memref<1x...x1xf32, strided<[1, ..., N], offset: ?>>
///   memref.copy %src, %view
///       : memref<1x...x1xf32>
///       to memref<1x...x1xf32, strided<[1, ..., N], offset: ?>>
///
/// AFTER
///   %c0 = arith.constant 0 : index
///   %v = memref.load %src[%c0, ..., %c0] : memref<1x...x1xf32>
///   memref.store %v, %base[%off, ..., %c0] : memref<Nx...x1xf32>
struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
public:
  // NOTE(review): presumably relies on the inherited OpRewritePattern
  // constructor (`using OpRewritePattern::OpRewritePattern;`) — confirm.

  LogicalResult matchAndRewrite(memref::CopyOp op,
                                PatternRewriter &rewriter) const final {
    // Only copies whose destination is produced by a reinterpret_cast that
    // forms a scalar slice are candidates.
    Value rcOutput = op.getTarget();
    auto rc = rcOutput.getDefiningOp<memref::ReinterpretCastOp>();
    if (!rc)
      return rewriter.notifyMatchFailure(
          op, "target is not a memref.reinterpret_cast");

    if (!isScalarSlice(rc))
      return rewriter.notifyMatchFailure(
          op, "reinterpret_cast does not match scalar slice");

    Location loc = op.getLoc();

    Value src = op.getSource();
    // Store directly into the cast's base memref, bypassing the view.
    Value dst = rc.getSource();

    auto dstType = cast<MemRefType>(dst.getType());
    unsigned dstRank = dstType.getRank();

    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

    // The copy source is an all-unit-sized memref, so every load index is 0.
    auto srcType = cast<MemRefType>(src.getType());
    Repeated<Value> loadIndices(srcType.getRank(), zero);
    auto offsets = rc.getMixedOffsets();
    assert(offsets.size() == 1 && "Expecting single offset");
    OpFoldResult offset = offsets[0];
    Value storeOffset = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
    // Place the offset on the base's single non-unit dimension: the last dim
    // when the first dim is unit (row-major slice), otherwise the first dim
    // (column-major slice). All other store indices are 0.
    unsigned offsetDim = dstType.getDimSize(0) == 1 ? dstRank - 1 : 0;
    SmallVector<Value> storeIndices(dstRank, zero);
    storeIndices[offsetDim] = storeOffset;
    // If the only user of `rc` is the current Op (which is about to be erased),
    // we can safely erase it.
    // NOTE(review): at this point `op` still uses the cast result;
    // RewriterBase::eraseOp expects the erased op to have no remaining uses —
    // verify whether `rc` should instead be erased after `op` below.
    if (rcOutput.hasOneUse())
      rewriter.eraseOp(rc);

    Value val = memref::LoadOp::create(rewriter, loc, src, loadIndices);
    memref::StoreOp::create(rewriter, loc, val, dst, storeIndices);

    rewriter.eraseOp(op);
    return success();
  }
};
198
/// Pass that applies the reinterpret_cast-eliding patterns via a partial
/// dialect conversion.
struct ElideReinterpretCastPass
    : public memref::impl::ElideReinterpretCastPassBase<
          ElideReinterpretCastPass> {
  void runOnOperation() override {
    MLIRContext &ctx = getContext();

    RewritePatternSet patterns(&ctx);
    // NOTE(review): no pattern is visibly added to `patterns` here — confirm
    // that populateElideReinterpretCastPatterns(patterns) (or equivalent) is
    // called before the conversion is applied.
    ConversionTarget target(ctx);
    // A memref.copy is illegal (i.e. must be rewritten) only when its target
    // is a reinterpret_cast forming a scalar slice; all other copies stay.
    target.addDynamicallyLegalOp<memref::CopyOp>([](memref::CopyOp op) {
      auto rc = op.getTarget().getDefiningOp<memref::ReinterpretCastOp>();
      if (!rc)
        return true;
      return !isScalarSlice(rc);
    });
    target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
220
221} // namespace
222
224 RewritePatternSet &patterns) {
225 patterns.add<CopyToScalarLoadAndStore>(patterns.getContext());
226}
return success()
b getContext())
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type getType() const
Return the type of this value.
Definition Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
void populateElideReinterpretCastPatterns(RewritePatternSet &patterns)
Collects a set of patterns that bypass memref.reinterpret_cast Ops.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...