17#include "llvm/ADT/Repeated.h"
23#define GEN_PASS_DEF_ELIDEREINTERPRETCASTPASS
24#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
// Predicate: does this reinterpret_cast produce a view that selects a single
// scalar element of its source?
// NOTE(review): this chunk is an extraction with elided lines (the embedded
// original line numbers jump, e.g. 78 -> 81), so the early-return statements
// between the guards below are not visible here.
76static bool isScalarSlice(memref::ReinterpretCastOp rc) {
// Both source and result types are inspected as MemRefType; dyn_cast yields
// null for non-memref types — the corresponding null guards sit in elided
// lines (orig. 79-80).
77 auto rcInputTy = dyn_cast<MemRefType>(rc.getSource().getType());
78 auto rcOutputTy = dyn_cast<MemRefType>(rc.getType());
// Only sources with the trivial (identity) layout are handled.
81 if (!rcInputTy.getLayout().isIdentity())
// The cast must preserve rank.
85 unsigned srcRank = rcInputTy.getRank();
86 if (srcRank != rcOutputTy.getRank())
// Guard: every result dimension must be statically 1 (scalar-shaped view);
// the rejection branch is elided.
92 if (!llvm::all_of(rcOutputTy.getShape(),
93 [](
int64_t dim) { return dim == 1; }))
// Guard: every size of the cast must be the static constant 1.
// NOTE(review): `sizes` is declared in an elided line — presumably the op's
// (mixed/static) sizes; confirm against the full source.
97 if (!llvm::all_of(sizes, [](
int64_t size) {
98 return !ShapedType::isDynamic(size) && size == 1;
// (Orig. lines 99-104 elided; this check appears in a separate branch —
// do not assume it contradicts the all-ones guard above.)
105 if (rcOutputTy.getDimSize(0) > 1)
// Count the source's non-unit dimensions; the variable `nonUnitCount`
// receiving this count is assigned in an elided line (orig. 109).
110 std::count_if(rcInputTy.getShape().begin(), rcInputTy.getShape().end(),
111 [](
int dim) { return dim != 1; });
// Scalar slice requires exactly one non-unit source dimension.
112 return nonUnitCount == 1;
// Rewrites a memref.copy whose target is a scalar-slice reinterpret_cast into
// an explicit scalar memref.load from the copy source plus a memref.store
// into the cast's *source*, eliding the cast itself.
// NOTE(review): extraction gaps — the `if (!rc)` guard before the first
// notifyMatchFailure, and the definitions of `zero` and `storeOffset`, live
// in elided lines; do not assume their exact form from this view.
157 LogicalResult matchAndRewrite(memref::CopyOp op,
158 PatternRewriter &rewriter)
const final {
// The pattern only fires when the copy target is produced by a
// memref.reinterpret_cast.
159 Value rcOutput = op.getTarget();
160 auto rc = rcOutput.
getDefiningOp<memref::ReinterpretCastOp>();
162 return rewriter.notifyMatchFailure(
163 op,
"target is not a memref.reinterpret_cast");
// Bail out unless the cast selects a single scalar element (see
// isScalarSlice above).
165 if (!isScalarSlice(rc))
166 return rewriter.notifyMatchFailure(
167 op,
"reinterpret_cast does not match scalar slice");
169 Location loc = op.
getLoc();
// `dst` deliberately bypasses the cast: we store into the cast's source.
171 Value src = op.getSource();
172 Value dst = rc.getSource();
174 auto dstType = cast<MemRefType>(dst.
getType());
175 unsigned dstRank = dstType.getRank();
179 auto srcType = cast<MemRefType>(src.
getType());
// Load indices are all `zero` (a constant defined in an elided line):
// isScalarSlice guarantees the source view holds one element.
180 Repeated<Value> loadIndices(srcType.getRank(), zero);
// The cast carries exactly one offset; it becomes the store index along the
// single non-unit destination dimension.
181 auto offsets = rc.getMixedOffsets();
182 assert(offsets.size() == 1 &&
"Expecting single offset");
183 OpFoldResult offset = offsets[0];
// Place the offset on the trailing dim when dim 0 is unit, else on dim 0.
185 unsigned offsetDim = dstType.getDimSize(0) == 1 ? dstRank - 1 : 0;
186 SmallVector<Value> storeIndices(dstRank, zero);
// NOTE(review): `storeOffset` is materialized from `offset` in an elided
// line — presumably via getValueOrCreateConstantIndexOp; confirm.
187 storeIndices[offsetDim] = storeOffset;
// The cast is erased first; it is dead once the copy is replaced.
191 rewriter.eraseOp(rc);
193 Value val = memref::LoadOp::create(rewriter, loc, src, loadIndices);
194 memref::StoreOp::create(rewriter, loc, val, dst, storeIndices);
// Finally drop the original copy (success return is in an elided line).
196 rewriter.eraseOp(op);
// Summary of a memref shape that is logically one-dimensional (at most one
// non-unit dimension).
// NOTE(review): additional members (e.g. `allOnes`, referenced by callers
// below) are declared in lines elided from this view.
208struct ShapeInfoFor1DMemRef {
// True when the single non-unit dimension (if any) is the leading one.
212 bool isLeadingDimNonUnit =
false;
// Classifies `type` as logically 1-D. Returns std::nullopt when more than one
// dimension is non-unit; a default ShapeInfoFor1DMemRef when every dim is 1;
// otherwise (per the elided tail) reports where the non-unit dim sits.
220static std::optional<ShapeInfoFor1DMemRef>
221getShapeInfoFor1DMemRef(MemRefType type) {
// Count the non-unit extents. NOTE(review): `shape` and the variable that
// receives this count (`nonUnitCount`, orig. ~222-225) are declared in
// elided lines.
224 llvm::count_if(
shape, [](
int64_t dim) {
return dim != 1; });
// All-unit shape: default-constructed info (isLeadingDimNonUnit == false).
226 if (nonUnitCount == 0)
227 return ShapeInfoFor1DMemRef{};
// More than one non-unit dim: not logically 1-D (return elided).
229 if (nonUnitCount > 1)
// Exactly one non-unit dim; the aggregate's remaining initializers
// (orig. 236-237) are elided.
235 return ShapeInfoFor1DMemRef{
false,
// True iff the cast's single offset is the static constant 0.
// NOTE(review): the declaration of `offsets` (orig. 240-242) is elided —
// presumably rc.getStaticOffsets(), given the isDynamic check; confirm.
239static bool hasStaticZeroOffset(memref::ReinterpretCastOp rc) {
243 assert(offsets.size() == 1 &&
"Expecting single offset");
244 return !ShapedType::isDynamic(offsets[0]) && offsets[0] == 0;
// Returns the compile-time constant behind index value `v`, if any.
// NOTE(review): the body (orig. 248-256) is entirely elided from this view.
247static std::optional<int64_t> getConstantIndex(
Value v) {
// True iff `idx` is a known constant lying outside [0, upperBound).
// Non-constant indices conservatively return false.
// NOTE(review): the `upperBound` parameter line (orig. ~258-259) is elided.
257static bool isConstantIndexExplicitlyOutOfBounds(
Value idx,
260 std::optional<int64_t> idxVal = getConstantIndex(idx);
261 return idxVal && (*idxVal < 0 || *idxVal >= upperBound);
// Predicate: is this reinterpret_cast only adding/removing unit dimensions
// around a single (possibly unit) payload dimension, with zero offset and
// fully static sizes/strides? Early-return bodies between guards are elided.
271static bool isPureRankExpansionOrCollapsingRC(memref::ReinterpretCastOp rc) {
272 auto inputTy = cast<MemRefType>(rc.getSource().getType());
273 auto outputTy = cast<MemRefType>(rc.getResult().getType());
// The cast must not shift the base: static offset 0 required.
278 if (!hasStaticZeroOffset(rc))
// Dynamic sizes or strides are not handled.
283 if (llvm::any_of(rc.getStaticSizes(), ShapedType::isDynamic) ||
284 llvm::any_of(rc.getStaticStrides(), ShapedType::isDynamic))
// Both sides must be logically 1-D (at most one non-unit dim).
290 std::optional<ShapeInfoFor1DMemRef> inputNonUnitDim =
291 getShapeInfoFor1DMemRef(inputTy);
292 std::optional<ShapeInfoFor1DMemRef> outputNonUnitDim =
293 getShapeInfoFor1DMemRef(outputTy);
296 if (!inputNonUnitDim || !outputNonUnitDim)
// Either both sides are all-ones or neither is; the all-ones case is
// resolved in the elided branch after line 303.
301 if (inputNonUnitDim->allOnes != outputNonUnitDim->allOnes)
303 if (inputNonUnitDim->allOnes)
// The single non-unit extent must agree on both sides; it sits at dim 0
// when leading, else at the trailing dim.
// NOTE(review): the outputTy.getDimSize( call opening (orig. 309) is
// elided between these lines.
307 if (inputTy.getDimSize(
308 inputNonUnitDim->isLeadingDimNonUnit ? 0 : inputTy.getRank() - 1) !=
310 outputNonUnitDim->isLeadingDimNonUnit ? 0 : outputTy.getRank() - 1))
// For rank > 1 on both sides, the non-unit dim must stay on the same end
// (leading vs trailing); rank-1 sides are position-agnostic.
315 if (inputTy.getRank() != 1 && outputTy.getRank() != 1 &&
316 inputNonUnitDim->isLeadingDimNonUnit !=
317 outputNonUnitDim->isLeadingDimNonUnit)
// Assert-only helper (hence [[maybe_unused]]): checks that no constant index
// of `load` is provably outside the reinterpret_cast result's extents.
// Loop body / return statements are partially elided from this view.
326[[maybe_unused]]
static bool areIndicesInBounds(memref::LoadOp
load) {
// NOTE(review): no visible null guard on `rc`; the caller asserts this only
// after matching a reinterpret_cast producer — confirm in the full source.
327 auto rc =
load.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
328 auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
// Compare each load index against the extent of its dimension.
330 for (
auto [pos, idx] : llvm::enumerate(
load.getIndices())) {
336 if (isConstantIndexExplicitlyOutOfBounds(idx, rcOutputTy.getDimSize(pos)))
// Pattern: rewrite memref.load-through-reinterpret_cast (where the cast is a
// pure unit-dim rank expansion/collapse) into a load directly from the cast's
// source, remapping the indices. Many guard/return lines are elided here.
363struct RewriteLoadFromReinterpretCast
368 LogicalResult matchAndRewrite(memref::LoadOp op,
369 PatternRewriter &rewriter)
const override {
// Only loads fed by a reinterpret_cast are rewritten (null guard elided).
370 auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
373 op,
"target is not a memref.reinterpret_cast");
374 if (!isPureRankExpansionOrCollapsingRC(rc))
376 op,
"reinterpret_cast is not a pure rank expansion or collapsing of "
377 "a single dimension");
// Debug-build sanity check: constant indices must be in bounds.
379 assert(areIndicesInBounds(op) &&
380 "load from reinterpret_cast indexes out of bounds!");
382 auto rcOutputTy = cast<MemRefType>(rc.getResult().getType());
383 auto rcInputTy = cast<MemRefType>(rc.getSource().getType());
385 int64_t rcOutputRank = rcOutputTy.getRank();
386 int64_t rcInputRank = rcInputTy.getRank();
// Original load indices (over the cast result); the remapped indices over
// the cast source are accumulated in rcInputIdxs.
388 SmallVector<Value> idxs(op.getIndices().begin(), op.getIndices().end());
389 SmallVector<Value> rcInputIdxs;
390 rcInputIdxs.reserve(rcInputRank);
// The higher-rank side carries the unit padding; its shape info tells us
// whether meaningful indices sit at the front or the back.
404 MemRefType expandedTy =
405 rcOutputRank >= rcInputRank ? rcOutputTy : rcInputTy;
406 std::optional<ShapeInfoFor1DMemRef> expandedNonUnitDim =
407 getShapeInfoFor1DMemRef(expandedTy);
408 assert(expandedNonUnitDim &&
"expected a single boundary non-unit dim");
409 bool keepLeadingIndices = expandedNonUnitDim->isLeadingDimNonUnit;
// Case 1: cast expanded rank — drop the indices that address unit dims,
// keeping the leading or trailing rcInputRank of them.
411 if (rcOutputRank >= rcInputRank) {
419 int64_t firstKeptPos =
420 keepLeadingIndices ? 0 : rcOutputRank - rcInputRank;
421 rcInputIdxs.append(idxs.begin() + firstKeptPos,
422 idxs.begin() + firstKeptPos + rcInputRank);
// Case 2 (else-branch opener elided): cast collapsed rank — pad with
// `c0` (a zero constant defined in an elided line) on the unit side.
432 int64_t rankDiff = rcInputRank - rcOutputRank;
434 if (keepLeadingIndices) {
435 rcInputIdxs.append(idxs.begin(), idxs.end());
436 rcInputIdxs.append(rankDiff, c0);
438 rcInputIdxs.append(rankDiff, c0);
439 rcInputIdxs.append(idxs.begin(), idxs.end());
443 assert(rcInputIdxs.size() ==
static_cast<size_t>(rcInputRank) &&
444 "Incorrect number of indices!");
// Replace the load with one on the cast's source; the actual
// replaceOpWithNewOp call (orig. ~447-448) is elided.
446 auto rcInput = rc.getSource();
// Clean up the cast when this load was its only user.
449 if (rc.getResult().hasOneUse())
// Pass driver: marks copies/loads fed by an elidable reinterpret_cast as
// illegal, then runs partial conversion with the two patterns above so every
// such op is rewritten. Base-class list and several lines are elided.
456struct ElideReinterpretCastPass
458 ElideReinterpretCastPass> {
459 void runOnOperation()
override {
// NOTE(review): `ctx` is obtained in an elided line (orig. 460-461),
// presumably &getContext(); pattern population (elided) is expected to
// call populateElideReinterpretCastPatterns.
462 RewritePatternSet patterns(&ctx);
464 ConversionTarget
target(ctx);
// A memref.copy is legal unless its target comes from a scalar-slice
// reinterpret_cast (null-rc early return elided).
465 target.addDynamicallyLegalOp<memref::CopyOp>([](memref::CopyOp op) {
466 auto rc = op.getTarget().getDefiningOp<memref::ReinterpretCastOp>();
469 return !isScalarSlice(rc);
// A memref.load is legal unless it reads through a pure rank
// expansion/collapse cast.
471 target.addDynamicallyLegalOp<memref::LoadOp>([](memref::LoadOp op) {
472 auto rc = op.getMemRef().getDefiningOp<memref::ReinterpretCastOp>();
475 return !isPureRankExpansionOrCollapsingRC(rc);
// Replacement ops (arith constants, memref load/store) stay legal.
477 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
// Signal pass failure when conversion cannot legalize everything
// (the signalPassFailure call is in an elided line).
478 if (
failed(applyPartialConversion(getOperation(),
target,
479 std::move(patterns))))
// Interior of populateElideReinterpretCastPatterns (declaration visible in
// the Doxygen prose below): registers both rewrite patterns. The enclosing
// function's signature and the call's context argument are elided.
488 patterns.
add<CopyToScalarLoadAndStore, RewriteLoadFromReinterpretCast>(
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateElideReinterpretCastPatterns(RewritePatternSet &patterns)
Collects a set of patterns that bypass memref.reinterpret_cast Ops.
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...