19 #define GEN_PASS_DEF_EXPANDREALLOCPASS
20 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
54 ExpandReallocOpPattern(
MLIRContext *ctx,
bool emitDeallocs)
57 LogicalResult matchAndRewrite(memref::ReallocOp op,
60 assert(op.getType().getRank() == 1 &&
61 "result MemRef must have exactly one rank");
62 assert(op.getSource().getType().getRank() == 1 &&
63 "source MemRef must have exactly one rank");
64 assert(op.getType().getLayout().isIdentity() &&
65 "result MemRef must have identity layout (or none)");
66 assert(op.getSource().getType().getLayout().isIdentity() &&
67 "source MemRef must have identity layout (or none)");
71 cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
72 OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
73 if (ShapedType::isDynamic(inputSize)) {
75 rewriter.getIndexAttr(0));
76 currSize = memref::DimOp::create(rewriter, loc, op.getSource(), dimZero)
82 cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
83 OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
85 : rewriter.getIndexAttr(outputSize);
91 Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
93 auto ifOp = scf::IfOp::create(
100 if (op.getDynamicResultSize())
101 dynamicSizeOperands.push_back(op.getDynamicResultSize());
103 Value newAlloc = memref::AllocOp::create(
104 builder, loc, op.getResult().getType(), dynamicSizeOperands,
105 op.getAlignmentAttr());
110 Value subview = memref::SubViewOp::create(
111 builder, loc, newAlloc,
115 memref::CopyOp::create(builder, loc, op.getSource(), subview);
120 memref::DeallocOp::create(builder, loc, op.getSource());
122 scf::YieldOp::create(builder, loc, newAlloc);
130 Value casted = memref::ReinterpretCastOp::create(
131 builder, loc, cast<MemRefType>(op.getResult().getType()),
132 op.getSource(), rewriter.getIndexAttr(0),
135 scf::YieldOp::create(builder, loc, casted);
138 rewriter.replaceOp(op, ifOp.getResult(0));
143 const bool emitDeallocs;
146 struct ExpandReallocPass
147 :
public memref::impl::ExpandReallocPassBase<ExpandReallocPass> {
150 void runOnOperation()
override {
157 target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
158 memref::MemRefDialect>();
159 target.addIllegalOp<memref::ReallocOp>();
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void populateExpandReallocPatterns(RewritePatternSet &patterns, bool emitDeallocs=true)
Appends patterns for expanding memref.realloc operations.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.