19#define GEN_PASS_DEF_EXPANDREALLOCPASS
20#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
/// Constructs the realloc-expansion pattern.
///
/// \param ctx          Context forwarded to the `OpRewritePattern` base.
/// \param emitDeallocs Stored in the member of the same name. Presumably it
///                     decides whether the source buffer is deallocated after
///                     being copied into the new allocation. The guarding
///                     condition is not visible in this excerpt — confirm
///                     against `matchAndRewrite`.
54 ExpandReallocOpPattern(MLIRContext *ctx,
bool emitDeallocs)
55 : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {}
57 LogicalResult matchAndRewrite(memref::ReallocOp op,
58 PatternRewriter &rewriter)
const final {
59 Location loc = op.getLoc();
60 assert(op.getType().getRank() == 1 &&
61 "result MemRef must have exactly one rank");
62 assert(op.getSource().getType().getRank() == 1 &&
63 "source MemRef must have exactly one rank");
64 assert(op.getType().getLayout().isIdentity() &&
65 "result MemRef must have identity layout (or none)");
66 assert(op.getSource().getType().getLayout().isIdentity() &&
67 "source MemRef must have identity layout (or none)");
71 cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
72 OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
73 if (ShapedType::isDynamic(inputSize)) {
75 rewriter.getIndexAttr(0));
76 currSize = memref::DimOp::create(rewriter, loc, op.getSource(), dimZero)
82 cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
83 OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
84 ? OpFoldResult{op.getDynamicResultSize()}
85 : rewriter.getIndexAttr(outputSize);
91 Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
93 auto ifOp = scf::IfOp::create(
95 [&](OpBuilder &builder, Location loc) {
99 SmallVector<Value> dynamicSizeOperands;
100 if (op.getDynamicResultSize())
101 dynamicSizeOperands.push_back(op.getDynamicResultSize());
103 Value newAlloc = memref::AllocOp::create(
104 builder, loc, op.getResult().getType(), dynamicSizeOperands,
105 op.getAlignmentAttr());
110 Value subview = memref::SubViewOp::create(
111 builder, loc, newAlloc,
112 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
113 ArrayRef<OpFoldResult>{currSize},
114 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
115 memref::CopyOp::create(builder, loc, op.getSource(), subview);
120 memref::DeallocOp::create(builder, loc, op.getSource());
122 scf::YieldOp::create(builder, loc, newAlloc);
130 Value casted = memref::ReinterpretCastOp::create(
131 builder, loc, cast<MemRefType>(op.getResult().getType()),
132 op.getSource(), rewriter.getIndexAttr(0),
135 scf::YieldOp::create(builder, loc, casted);
138 rewriter.replaceOp(op, ifOp.getResult(0));
146struct ExpandReallocPass
157 target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
158 memref::MemRefDialect>();
159 target.addIllegalOp<memref::ReallocOp>();
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
OpT getOperation()
Return the current operation being transformed.
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
void signalPassFailure()
Signal that some invariant was broken when running.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
::mlir::Pass::Option< bool > emitDeallocs
ExpandReallocPassBase Base
void populateExpandReallocPatterns(RewritePatternSet &patterns, bool emitDeallocs=true)
Appends patterns for expanding memref.realloc operations.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.