19#define GEN_PASS_DEF_EXPANDREALLOCPASS
20#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
54 ExpandReallocOpPattern(MLIRContext *ctx,
bool emitDeallocs)
55 : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {}
// Rewrites a rank-1, identity-layout memref.realloc into explicit ops and
// replaces the realloc with the result of an scf.if:
//   - "grow" branch: memref.alloc a buffer of the result type (with the op's
//     dynamic result size and alignment), copy the source into its leading
//     subview of `currSize` elements, release the source with
//     memref.dealloc, and yield the new buffer;
//   - other branch: memref.reinterpret_cast the source itself to the result
//     type with offset 0, size `targetSize`, stride 1, and yield that view.
// NOTE(review): several lines of this method are missing from this view
// (e.g. the head of the `inputSize`/`outputSize` declarations, the cmpi
// operands, and the guard around the dealloc); comments below describe only
// what is visible.
57 LogicalResult matchAndRewrite(memref::ReallocOp op,
58 PatternRewriter &rewriter)
const final {
59 Location loc = op.getLoc();
// Only 1-D memrefs with identity layouts are handled; enforced via asserts.
60 assert(op.getType().getRank() == 1 &&
61 "result MemRef must have exactly one rank");
62 assert(op.getSource().getType().getRank() == 1 &&
63 "source MemRef must have exactly one rank");
64 assert(op.getType().getLayout().isIdentity() &&
65 "result MemRef must have identity layout (or none)");
66 assert(op.getSource().getType().getLayout().isIdentity() &&
67 "source MemRef must have identity layout (or none)");
// Current size of the source: the static dim 0 as an index attribute, or a
// memref.dim on dim 0 when that dimension is dynamic.
71 cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
72 OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
73 if (ShapedType::isDynamic(inputSize)) {
75 rewriter.getIndexAttr(0));
76 currSize = memref::DimOp::create(rewriter, loc, op.getSource(), dimZero)
// Target size: the op's dynamic result-size operand when the result dim is
// dynamic, otherwise the static result size as an index attribute.
82 cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
83 OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
84 ? OpFoldResult{op.getDynamicResultSize()}
85 : rewriter.getIndexAttr(outputSize);
// Unsigned less-than comparison of the sizes picks the scf.if branch below
// (the compared operands are on a line not visible here).
91 Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
93 auto ifOp = scf::IfOp::create(
95 [&](OpBuilder &builder, Location loc) {
// Grow path: allocate a fresh buffer of the result type.
99 SmallVector<Value> dynamicSizeOperands;
100 if (op.getDynamicResultSize())
101 dynamicSizeOperands.push_back(op.getDynamicResultSize());
103 Value newAlloc = memref::AllocOp::create(
104 builder, loc, op.getResult().getType(), dynamicSizeOperands,
105 op.getAlignmentAttr());
// Copy the old contents into the first `currSize` elements (offset 0,
// stride 1) of the new buffer.
110 Value subview = memref::SubViewOp::create(
111 builder, loc, newAlloc,
112 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
113 ArrayRef<OpFoldResult>{currSize},
114 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
115 memref::CopyOp::create(builder, loc, op.getSource(), subview);
// NOTE(review): the condition guarding this dealloc is not visible here;
// presumably it is the `emitDeallocs` member.
120 memref::DeallocOp::create(builder, loc, op.getSource());
122 scf::YieldOp::create(builder, loc, newAlloc);
124 [&](OpBuilder &builder, Location loc) {
// Non-grow path: reuse the source buffer, reinterpreted with the target
// size (offset 0, stride 1) and the result type.
130 Value casted = memref::ReinterpretCastOp::create(
131 builder, loc, cast<MemRefType>(op.getResult().getType()),
132 op.getSource(), rewriter.getIndexAttr(0),
133 ArrayRef<OpFoldResult>{targetSize},
134 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
135 scf::YieldOp::create(builder, loc, casted);
// The scf.if's single result replaces the realloc.
138 rewriter.replaceOp(op, ifOp.getResult(0));
143 const bool emitDeallocs;
// Pass that removes memref.realloc ops by partial dialect conversion: the
// arith, scf and memref dialects are marked legal and memref.realloc is
// marked illegal, so applyPartialConversion must rewrite every realloc for
// the conversion to succeed.
// NOTE(review): the pattern set handed to applyPartialConversion and the
// failure handling are not visible here; presumably the patterns come from
// populateExpandReallocPatterns and a failed conversion signals pass failure.
146struct ExpandReallocPass
147 :
public memref::impl::ExpandReallocPassBase<ExpandReallocPass> {
150 void runOnOperation()
override {
155 ConversionTarget
target(ctx);
157 target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
158 memref::MemRefDialect>();
159 target.addIllegalOp<memref::ReallocOp>();
160 if (
failed(applyPartialConversion(getOperation(),
target,
void populateExpandReallocPatterns(RewritePatternSet &patterns, bool emitDeallocs=true)
Appends patterns for expanding memref.realloc operations.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.