16#include "llvm/ADT/Repeated.h"
21#define GEN_PASS_DEF_ELIDEREINTERPRETCASTPASS
22#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
/// Decides whether `rc` reinterprets an identity-layout memref as a slice
/// holding a single element, so that a copy into it can be lowered to one
/// scalar load/store pair.
///
/// Conditions checked here. The statement guarded by each `if` is not part
/// of this excerpt; presumably each one returns false (TODO confirm):
///  - the source memref has an identity layout;
///  - source and result memrefs have the same rank;
///  - every dimension of the result shape is 1;
///  - every size of the cast is static and equal to 1;
///  - result dim 0 is not greater than 1;
///  - finally, the source shape must have exactly one non-unit dimension.
74static bool isScalarSlice(memref::ReinterpretCastOp rc) {
75 auto rcInputTy = dyn_cast<MemRefType>(rc.getSource().getType());
76 auto rcOutputTy = dyn_cast<MemRefType>(rc.getType());
// Only plain (identity-layout) source buffers are handled.
79 if (!rcInputTy.getLayout().isIdentity())
83 unsigned srcRank = rcInputTy.getRank();
84 if (srcRank != rcOutputTy.getRank())
// The result view must be all-ones, i.e. a single element.
90 if (!llvm::all_of(rcOutputTy.getShape(),
91 [](
int64_t dim) { return dim == 1; }))
// Dynamic sizes are rejected; every static size must be 1.
95 if (!llvm::all_of(sizes, [](
int64_t size) {
96 return !ShapedType::isDynamic(size) && size == 1;
// NOTE(review): this looks redundant with the all-ones result-shape check
// above, if that check's guard returns early — confirm before removing.
103 if (rcOutputTy.getDimSize(0) > 1)
// Count non-unit source dims. The predicate takes int64_t, the shape's
// element type, so large static dims are not narrowed (a dim such as
// 2^32 + 1 would otherwise truncate to 1 and be miscounted as a unit dim).
108 std::count_if(rcInputTy.getShape().begin(), rcInputTy.getShape().end(),
109 [](
int64_t dim) { return dim != 1; });
110 return nonUnitCount == 1;
/// Rewrites a `memref.copy` whose target is a scalar-slice
/// `memref.reinterpret_cast` (as decided by isScalarSlice). The copy becomes
/// a single `memref.load` from the copy source followed by a `memref.store`
/// into the reinterpret_cast's own source memref. Both the cast and the copy
/// are then erased.
///
/// Fails to match (via notifyMatchFailure) when the copy target is not
/// defined by a `memref.reinterpret_cast`, or when that cast is not a scalar
/// slice.
155 LogicalResult matchAndRewrite(memref::CopyOp op,
156 PatternRewriter &rewriter)
const final {
// Look through the copy target to the reinterpret_cast that produced it.
157 Value rcOutput = op.getTarget();
158 auto rc = rcOutput.
getDefiningOp<memref::ReinterpretCastOp>();
160 return rewriter.notifyMatchFailure(
161 op,
"target is not a memref.reinterpret_cast");
// Only scalar slices can be replaced by one load/store pair.
163 if (!isScalarSlice(rc))
164 return rewriter.notifyMatchFailure(
165 op,
"reinterpret_cast does not match scalar slice");
167 Location loc = op.
getLoc();
// Read from the copy source; write into the memref underneath the cast.
169 Value src = op.getSource();
170 Value dst = rc.getSource();
172 auto dstType = cast<MemRefType>(dst.
getType());
173 unsigned dstRank = dstType.getRank();
// NOTE(review): `zero` (and `storeOffset` below) are defined on lines not
// visible here. Presumably `zero` is an index constant 0 and `storeOffset` is
// the cast offset materialized as an index Value — confirm.
177 auto srcType = cast<MemRefType>(src.
getType());
178 Repeated<Value> loadIndices(srcType.getRank(), zero);
// The scalar slice is expected to carry exactly one offset.
179 auto offsets = rc.getMixedOffsets();
180 assert(offsets.size() == 1 &&
"Expecting single offset");
181 OpFoldResult offset = offsets[0];
// The offset goes into the last destination dim when dim 0 has size 1.
// Otherwise it goes into dim 0. All other store indices are zero.
183 unsigned offsetDim = dstType.getDimSize(0) == 1 ? dstRank - 1 : 0;
184 SmallVector<Value> storeIndices(dstRank, zero);
185 storeIndices[offsetDim] = storeOffset;
// NOTE(review): any guard around this erase is not visible here. At this
// point `op` still uses `rc`'s result. Erasing `rc` before `op` therefore
// relies on the rewrite driver deferring erasure, as dialect conversion does.
// Confirm this is safe if the patterns are also used with other drivers.
189 rewriter.eraseOp(rc);
// Emit the scalar load/store pair that replaces the copy.
191 Value val = memref::LoadOp::create(rewriter, loc, src, loadIndices);
192 memref::StoreOp::create(rewriter, loc, val, dst, storeIndices);
194 rewriter.eraseOp(op);
/// Pass that applies the copy-to-scalar-load/store rewrite through a partial
/// dialect conversion. A `memref.copy` is illegal when its target comes from
/// a scalar-slice `memref.reinterpret_cast`. The arith and memref dialects are
/// otherwise legal.
199struct ElideReinterpretCastPass
200 :
public memref::impl::ElideReinterpretCastPassBase<
201 ElideReinterpretCastPass> {
202 void runOnOperation()
override {
// NOTE(review): the line that fills `patterns` is not visible here. It
// presumably calls populateElideReinterpretCastPatterns — confirm.
205 RewritePatternSet patterns(&ctx);
207 ConversionTarget
target(ctx);
// A copy stays legal unless its target cast is a scalar slice. The branch
// for copies with no reinterpret_cast target is not visible here.
208 target.addDynamicallyLegalOp<memref::CopyOp>([](memref::CopyOp op) {
209 auto rc = op.getTarget().getDefiningOp<memref::ReinterpretCastOp>();
212 return !isScalarSlice(rc);
214 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
// The failure handling for the conversion is on lines not visible here
// (presumably it signals pass failure — confirm).
215 if (
failed(applyPartialConversion(getOperation(),
target,
216 std::move(patterns))))
// Register the copy-to-scalar-load/store rewrite, built with the pattern
// set's own context.
225 patterns.
add<CopyToScalarLoadAndStore>(patterns.
getContext());
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type getType() const
Return the type of this value.
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateElideReinterpretCastPatterns(RewritePatternSet &patterns)
Collects a set of patterns that bypass memref.reinterpret_cast Ops.
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...