MLIR 23.0.0git
ElideReinterpretCast.cpp
Go to the documentation of this file.
1//===-ElideReinterpretCast.cpp - Expansion patterns for MemRef operations-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/Matchers.h"
17#include "llvm/ADT/Repeated.h"
18#include <cassert>
19#include <optional>
20
21namespace mlir {
22namespace memref {
23#define GEN_PASS_DEF_ELIDEREINTERPRETCASTPASS
24#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
25} // namespace memref
26} // namespace mlir
27
28using namespace mlir;
30namespace {
31
32/// Returns true if `rc` represents a scalar view (all sizes == 1)
33/// into a memref that has exactly one non-unit dimension located at
34/// either the first or last position (i.e. a "row" or "column").
35///
36/// Examples that return true:
37///
38/// // Row-major slice (last dim is non-unit)
39/// memref.reinterpret_cast %buff to offset: [%off],
40/// sizes: [1, 1, 1], strides: [1, 1, 1]
41/// : memref<1x1x8xi32> to memref<1x1x1xi32>
42///
43/// // Column-major slice (first dim is non-unit)
44/// memref.reinterpret_cast %buff to offset: [%off],
45/// sizes: [1, 1], strides: [1, 1]
46/// : memref<2x1xf32> to memref<1x1xf32>
47///
48/// // Random strides
49/// memref.reinterpret_cast %buff to offset: [%off],
50/// sizes: [1, 1], strides: [10, 100]
51/// : memref<2x1xf32, strided<[10, 100]>>
52/// to memref<1x1xf32>
53///
54/// // Rank-1 case
55/// memref.reinterpret_cast %buf to offset: [%off],
56/// sizes: [1], strides: [1]
57/// : memref<8xi32> to memref<1xi32>
58///
59/// Examples that return false:
60///
61/// // More non-unit dims
62/// memref.reinterpret_cast %buff to offset: [%off],
63/// sizes: [1, 1, 1], strides: [1, 1, 1]
64/// : memref<1x2x8xi32> to memref<1x1x1xi32>
65///
66/// // View is not scalar (size != 1)
67/// memref.reinterpret_cast %buff to offset: [%off],
68/// sizes: [2, 1], strides: [1, 1]
69/// : memref<1x2xf32> to memref<2x1xf32>
70///
71/// // Base has non-identity layout
72/// %buff = memref.alloc() : memref<1x2xf32, strided<[1, 3]>>
73/// memref.reinterpret_cast %buff to offset: [%off],
74/// sizes: [1, 1], strides: [1, 1]
75/// : memref<1x2xf32, strided<[1, 3]>> to memref<1x1xf32>
76static bool isScalarSlice(memref::ReinterpretCastOp rc) {
77 auto rcInputTy = dyn_cast<MemRefType>(rc.getSource().getType());
78 auto rcOutputTy = dyn_cast<MemRefType>(rc.getType());
79
80 // Reject strided base - logic for computing linear idx is TODO
81 if (!rcInputTy.getLayout().isIdentity())
82 return false;
84 // Reject non-matching ranks
85 unsigned srcRank = rcInputTy.getRank();
86 if (srcRank != rcOutputTy.getRank())
87 return false;
88
89 ArrayRef<int64_t> sizes = rc.getStaticSizes();
90
91 // View must be scalar: memref<1x...x1>
92 if (!llvm::all_of(rcOutputTy.getShape(),
93 [](int64_t dim) { return dim == 1; }))
94 return false;
95
96 // Sizes must all be statically 1
97 if (!llvm::all_of(sizes, [](int64_t size) {
98 return !ShapedType::isDynamic(size) && size == 1;
99 }))
100 return false;
101
102 // Rank-1 special case
103 if (srcRank == 1) {
104 // Reject non-scalar output
105 if (rcOutputTy.getDimSize(0) > 1)
106 return false;
107 }
108
109 int nonUnitCount =
110 std::count_if(rcInputTy.getShape().begin(), rcInputTy.getShape().end(),
111 [](int dim) { return dim != 1; });
112 return nonUnitCount == 1;
113}
114
115/// Rewrites `memref.copy` of a 1-element MemRef as a scalar load-store pair
116///
117/// The pattern matches a reinterpret_cast that creates a scalar view
118/// (`sizes = [1, ..., 1]`) into a memref with a single non-unit dimension.
119/// Since the view contains only one element, the accessed address is
120/// determined solely by the base pointer and the offset.
121///
122/// Two layouts are supported:
123/// * row-major slice (stride pattern [N, ..., 1])
124/// * column-major slice (stride pattern [1, ..., N])
125///
126/// BEFORE (row-major slice)
127/// %view = memref.reinterpret_cast %base
128/// to offset: [%off], sizes: [1, ..., 1], strides: [N, ..., 1]
129/// : memref<1x...xNxf32>
130/// to memref<1x...x1xf32, strided<[N, ..., 1], offset: ?>>
131/// memref.copy %src, %view
132/// : memref<1x...x1xf32>
133/// to memref<1x...x1xf32, strided<[N, ..., 1], offset: ?>>
134///
135/// AFTER
136/// %c0 = arith.constant 0 : index
137/// %v = memref.load %src[%c0, ..., %c0] : memref<1x...x1xf32>
138/// memref.store %v, %base[%c0, ..., %off] : memref<1x...xNxf32>
139///
140/// BEFORE (column-major slice)
141/// %view = memref.reinterpret_cast %base
142/// to offset: [%off], sizes: [1, ..., 1], strides: [1, ..., N]
143/// : memref<Nx...x1xf32>
144/// to memref<1x...x1xf32, strided<[1, ..., N], offset: ?>>
145/// memref.copy %src, %view
146/// : memref<1x...x1xf32>
147/// to memref<1x...x1xf32, strided<[1, ..., N], offset: ?>>
148///
149/// AFTER
150/// %c0 = arith.constant 0 : index
151/// %v = memref.load %src[%c0, ..., %c0] : memref<1x...x1xf32>
152/// memref.store %v, %base[%off, ..., %c0] : memref<Nx...x1xf32>
153struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
154public:
156
157 LogicalResult matchAndRewrite(memref::CopyOp op,
158 PatternRewriter &rewriter) const final {
159 Value rcOutput = op.getTarget();
160 auto rc = rcOutput.getDefiningOp<memref::ReinterpretCastOp>();
161 if (!rc)
162 return rewriter.notifyMatchFailure(
163 op, "target is not a memref.reinterpret_cast");
164
165 if (!isScalarSlice(rc))
166 return rewriter.notifyMatchFailure(
167 op, "reinterpret_cast does not match scalar slice");
168
169 Location loc = op.getLoc();
170
171 Value src = op.getSource();
172 Value dst = rc.getSource();
173
174 auto dstType = cast<MemRefType>(dst.getType());
175 unsigned dstRank = dstType.getRank();
176
177 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
178
179 auto srcType = cast<MemRefType>(src.getType());
180 Repeated<Value> loadIndices(srcType.getRank(), zero);
181 auto offsets = rc.getMixedOffsets();
182 assert(offsets.size() == 1 && "Expecting single offset");
183 OpFoldResult offset = offsets[0];
184 Value storeOffset = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
185 unsigned offsetDim = dstType.getDimSize(0) == 1 ? dstRank - 1 : 0;
186 SmallVector<Value> storeIndices(dstRank, zero);
187 storeIndices[offsetDim] = storeOffset;
188 // If the only user of `rc` is the current Op (which is about to be erased),
189 // we can safely erase it.
190 if (rcOutput.hasOneUse())
191 rewriter.eraseOp(rc);
192
193 Value val = memref::LoadOp::create(rewriter, loc, src, loadIndices);
194 memref::StoreOp::create(rewriter, loc, val, dst, storeIndices);
195
196 rewriter.eraseOp(op);
197 return success();
198 }
199};
200
/// Captures info about MemRefs that are effectively 1D (the leading or
/// trailing dims are all 1). The only accepted non-unit dim is either the
/// leading or the trailing dim.
///
/// Examples:
///   memref<1x1x4xf32>, memref<4x1x1xf32>, memref<1x1x1xf32>
///
struct ShapeInfoFor1DMemRef {
  /// True when every dim == 1; `false` means exactly one dim != 1.
  bool allOnes = true;
  /// When a non-unit boundary dim exists: true for the leading dim, false
  /// for the trailing dim.
  bool isLeadingDimNonUnit = false;
};
214
215/// Returns information about a MemRef if it contains at most one non-unit
216/// dimension.
217///
218/// The single non-unit dimension, if present, must be on the left or right
219/// boundary. Rank-1 non-unit MemRefs are treated as being on both boundaries.
220static std::optional<ShapeInfoFor1DMemRef>
221getShapeInfoFor1DMemRef(MemRefType type) {
222 ArrayRef<int64_t> shape = type.getShape();
223 int64_t nonUnitCount =
224 llvm::count_if(shape, [](int64_t dim) { return dim != 1; });
225 // Return default values if missing non-unit dimension (all-ones MemRef).
226 if (nonUnitCount == 0)
227 return ShapeInfoFor1DMemRef{};
228 // Return no info if MemRef has more non-unit dimensions.
229 if (nonUnitCount > 1)
230 return std::nullopt;
231 // Return no info if MemRef has non-unit dimension in non-boundary positions.
232 if (shape.front() == 1 && shape.back() == 1)
233 return std::nullopt;
234
235 return ShapeInfoFor1DMemRef{/*allOnes=*/false,
236 /*isLeadingDimNonUnit=*/shape.front() != 1};
237}
238
239static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
240 ArrayRef<int64_t> offsets = rc.getStaticOffsets();
241 // FIXME: Despite what `getStaticOffsets` implies, `reinterpret_cast` takes
242 // only a single offset. That should be fixed at the op definition level.
243 assert(offsets.size() == 1 && "Expecting single offset");
244 return !ShapedType::isDynamic(offsets[0]) && offsets[0] == 0;
245}
246
247static std::optional<int64_t> getConstantIndex(Value v) {
248 if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
249 return cst.value();
250 // Non-constant and dynamic indices
251 return std::nullopt;
252}
253
254/// Return true if input index is in bounds, i.e. `0 <= idx < upperBound`.
255/// Fully dynamic index values (i.e. non-constant) that cannot be analysed are
256/// treated as in-bounds.
257static bool isConstantIndexExplicitlyOutOfBounds(Value idx,
258 int64_t upperBound) {
259 // Only statically known `arith.constant` indices are checked here.
260 std::optional<int64_t> idxVal = getConstantIndex(idx);
261 return idxVal && (*idxVal < 0 || *idxVal >= upperBound);
262}
263
264/// Examples accepted by this shape restriction:
265/// memref<999xf32> <-> memref<1x1x999xf32>
266/// memref<1x108xf32> <-> memref<1x1x1x108xf32>
267/// memref<100x1xf32> <-> memref<100x1x1xf32>
268/// memref<1> <-> memref<1x1x1>
269///
270/// General reinterpret_casts are intentionally rejected.
271static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
272 auto inputTy = cast<MemRefType>(rc.getSource().getType());
273 auto outputTy = cast<MemRefType>(rc.getResult().getType());
274
275 // Only zero, statically known offsets are accepted. Non-zero or dynamic
276 // offsets would require reasoning about storage shifts in the underlying
277 // reinterpret_cast, which this helper does not model.
278 if (!hasStaticZeroOffset(rc))
279 return false;
280
281 // Dynamic sizes/strides prevent precise reasoning about the underlying
282 // reinterpret_cast, so only fully static shape metadata is accepted.
283 if (llvm::any_of(rc.getStaticSizes(), ShapedType::isDynamic) ||
284 llvm::any_of(rc.getStaticStrides(), ShapedType::isDynamic))
285 return false;
286
287 // Only shapes with at most one non-unit dimension are accepted. This rules
288 // out more general multi-dimensional reinterpret_casts and restricts the
289 // helper to unit-dim insertion/removal around a single logical dimension.
290 std::optional<ShapeInfoFor1DMemRef> inputNonUnitDim =
291 getShapeInfoFor1DMemRef(inputTy);
292 std::optional<ShapeInfoFor1DMemRef> outputNonUnitDim =
293 getShapeInfoFor1DMemRef(outputTy);
294 // Bail out if either type does not satisfy the single-boundary-non-unit-dim
295 // restriction described above.
296 if (!inputNonUnitDim || !outputNonUnitDim)
297 return false;
298
299 // The source and result must either both have a single non-unit dimension
300 // or both be all-ones.
301 if (inputNonUnitDim->allOnes != outputNonUnitDim->allOnes)
302 return false;
303 if (inputNonUnitDim->allOnes)
304 return true;
305
306 // The preserved non-unit dimension must have the same size.
307 if (inputTy.getDimSize(
308 inputNonUnitDim->isLeadingDimNonUnit ? 0 : inputTy.getRank() - 1) !=
309 outputTy.getDimSize(
310 outputNonUnitDim->isLeadingDimNonUnit ? 0 : outputTy.getRank() - 1))
311 return false;
312
313 // If both sides have rank > 1, the non-unit dimension must be on the same
314 // boundary. Rank-1 MemRefs are accepted against either boundary.
315 if (inputTy.getRank() != 1 && outputTy.getRank() != 1 &&
316 inputNonUnitDim->isLeadingDimNonUnit !=
317 outputNonUnitDim->isLeadingDimNonUnit)
318 return false;
319
320 return true;
321}
322
323/// Checks statically known and constant indices accessed by a load from a pure
324/// rank expansion/collapsing to ensure in-bounds only access. Fully dynamic
325/// indices are skipped (there is no way to verify them).
326[[maybe_unused]] static bool areIndicesInBounds(memref::LoadOp load) {
327 auto rc = load.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
328 auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
329
330 for (auto [pos, idx] : llvm::enumerate(load.getIndices())) {
331 // FIXME: This should be ensured by the memref.load semantics.
332 // In the long term, this sanity-check may live in the same debug-only
333 // checks as `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS`. This rejects
334 // only explicit constant OOB indices. Dynamic/non-constant indices are not
335 // filtered here.
336 if (isConstantIndexExplicitlyOutOfBounds(idx, rcOutputTy.getDimSize(pos)))
337 return false;
338 }
339 return true;
340}
341
342/// Rewrites `memref.load` through a pure rank-only `reinterpret_cast` by
343/// mapping the load indices directly onto the source MemRef.
344
345/// Shape restriction gated by isPureRankExpansionOrCollapsingRC().
346///
347/// BEFORE (rank expansion)
348/// %view = memref.reinterpret_cast %src
349/// : memref<Nxf32> to memref<1x1xNxf32>
350/// %v = memref.load %view[%c0, %c0, %i] : memref<1x1xNxf32>
351///
352/// AFTER
353/// %v = memref.load %src[%i] : memref<Nxf32>
354///
355/// BEFORE (rank collapsing)
356/// %view = memref.reinterpret_cast %src
357/// : memref<1x1xNxf32> to memref<Nxf32>
358/// %v = memref.load %view[%i] : memref<Nxf32>
359///
360/// AFTER
361/// %c0 = arith.constant 0 : index
362/// %v = memref.load %src[%c0, %c0, %i] : memref<1x1xNxf32>
363struct RewriteLoadFromReinterpretCast
364 : public OpRewritePattern<memref::LoadOp> {
365public:
367
368 LogicalResult matchAndRewrite(memref::LoadOp op,
369 PatternRewriter &rewriter) const override {
370 auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
371 if (!rc)
372 return rewriter.notifyMatchFailure(
373 op, "target is not a memref.reinterpret_cast");
374 if (!isPureRankExpansionOrCollapsingRC(rc))
375 return rewriter.notifyMatchFailure(
376 op, "reinterpret_cast is not a pure rank expansion or collapsing of "
377 "a single dimension");
378
379 assert(areIndicesInBounds(op) &&
380 "load from reinterpret_cast indexes out of bounds!");
381
382 auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
383 auto rcInputTy = cast<MemRefType>(rc.getSource().getType());
384
385 int64_t rcOutputRank = rcOutputTy.getRank();
386 int64_t rcInputRank = rcInputTy.getRank();
387
388 SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
389 SmallVector<Value> rcInputIdxs;
390 rcInputIdxs.reserve(rcInputRank);
391
392 // The rewrite only supports reinterpret_casts with at most one non-unit
393 // dimension, located at the left or right boundary.
394 //
395 // The higher-rank side tells which side the reinterpret_cast has
396 // expanded/collapsed.
397 //
398 // expansion: rcOutput has the higher rank
399 // collapsing : rcInput has the higher rank
400 //
401 // Example:
402 // memref<999> -> memref<1x1x999> : leading extra dims
403 // memref<999x1x1> -> memref<999> : trailing extra dims
404 MemRefType expandedTy =
405 rcOutputRank >= rcInputRank ? rcOutputTy : rcInputTy;
406 std::optional<ShapeInfoFor1DMemRef> expandedNonUnitDim =
407 getShapeInfoFor1DMemRef(expandedTy);
408 assert(expandedNonUnitDim && "expected a single boundary non-unit dim");
409 bool keepLeadingIndices = expandedNonUnitDim->isLeadingDimNonUnit;
410
411 if (rcOutputRank >= rcInputRank) {
412 // Rank expansion:
413 // memref<N> -> memref<1x1xN> : keep the last rcInputRank indices
414 // memref<N> -> memref<Nx1x1> : keep the first rcInputRank indices
415 // memref<1> -> memref<1x1x1> : all indices are zero
416 //
417 // Any discarded indices are known to be zero from
418 // areIndicesInBounds().
419 int64_t firstKeptPos =
420 keepLeadingIndices ? 0 : rcOutputRank - rcInputRank;
421 rcInputIdxs.append(idxs.begin() + firstKeptPos,
422 idxs.begin() + firstKeptPos + rcInputRank);
423 } else {
424 // Rank collapsing:
425 // memref<1x1xN> -> memref<N> : reinsert leading zeros
426 // memref<Nx1x1> -> memref<N> : reinsert trailing zeros
427 // memref<1x1x1> -> memref<1> : all indices are zero
428 //
429 // The collapsed-away dimensions are unit dims, so re-adding them with
430 // zero indices preserves semantics.
431 Value c0 = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
432 int64_t rankDiff = rcInputRank - rcOutputRank;
433
434 if (keepLeadingIndices) {
435 rcInputIdxs.append(idxs.begin(), idxs.end());
436 rcInputIdxs.append(rankDiff, c0);
437 } else {
438 rcInputIdxs.append(rankDiff, c0);
439 rcInputIdxs.append(idxs.begin(), idxs.end());
440 }
441 }
442
443 assert(rcInputIdxs.size() == static_cast<size_t>(rcInputRank) &&
444 "Incorrect number of indices!");
445
446 auto rcInput = rc.getSource();
447 // If the only user of rc is the current Op (which is about to be erased),
448 // we can safely erase it.
449 if (rc.getResult().hasOneUse())
450 rewriter.eraseOp(rc);
451 rewriter.replaceOpWithNewOp<memref::LoadOp>(op, rcInput, rcInputIdxs);
452 return success();
453 }
454};
455
456struct ElideReinterpretCastPass
458 ElideReinterpretCastPass> {
459 void runOnOperation() override {
460 MLIRContext &ctx = getContext();
461
462 RewritePatternSet patterns(&ctx);
464 ConversionTarget target(ctx);
465 target.addDynamicallyLegalOp<memref::CopyOp>([](memref::CopyOp op) {
466 auto rc = op.getTarget().getDefiningOp<memref::ReinterpretCastOp>();
467 if (!rc)
468 return true;
469 return !isScalarSlice(rc);
470 });
471 target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp op) {
472 auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
473 if (!rc)
474 return true;
475 return !isPureRankExpansionOrCollapsingRC(rc);
476 });
477 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
478 if (failed(applyPartialConversion(getOperation(), target,
479 std::move(patterns))))
480 signalPassFailure();
481 }
482};
483
484} // namespace
485
487 RewritePatternSet &patterns) {
488 patterns.add<CopyToScalarLoadAndStore, RewriteLoadFromReinterpretCast>(
489 patterns.getContext());
490}
return success()
b getContext())
auto load
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition Value.h:197
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
void populateElideReinterpretCastPatterns(RewritePatternSet &patterns)
Collects a set of patterns that bypass memref.reinterpret_cast Ops.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...