19 #define GEN_PASS_DEF_EXPANDREALLOC
20 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
54 ExpandReallocOpPattern(
MLIRContext *ctx,
bool emitDeallocs)
57 LogicalResult matchAndRewrite(memref::ReallocOp op,
60 assert(op.getType().getRank() == 1 &&
61 "result MemRef must have exactly one rank");
62 assert(op.getSource().getType().getRank() == 1 &&
63 "source MemRef must have exactly one rank");
64 assert(op.getType().getLayout().isIdentity() &&
65 "result MemRef must have identity layout (or none)");
66 assert(op.getSource().getType().getLayout().isIdentity() &&
67 "source MemRef must have identity layout (or none)");
71 cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
72 OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
73 if (ShapedType::isDynamic(inputSize)) {
75 rewriter.getIndexAttr(0));
76 currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero)
82 cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
83 OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
85 : rewriter.getIndexAttr(outputSize);
91 Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
93 auto ifOp = rewriter.create<scf::IfOp>(
100 if (op.getDynamicResultSize())
101 dynamicSizeOperands.push_back(op.getDynamicResultSize());
103 Value newAlloc = builder.create<memref::AllocOp>(
104 loc, op.getResult().getType(), dynamicSizeOperands,
105 op.getAlignmentAttr());
110 Value subview = builder.create<memref::SubViewOp>(
114 builder.create<memref::CopyOp>(loc, op.getSource(), subview);
119 builder.create<memref::DeallocOp>(loc, op.getSource());
121 builder.create<scf::YieldOp>(loc, newAlloc);
129 Value casted = builder.
create<memref::ReinterpretCastOp>(
130 loc, cast<MemRefType>(op.getResult().getType()), op.getSource(),
133 builder.
create<scf::YieldOp>(loc, casted);
136 rewriter.replaceOp(op, ifOp.getResult(0));
141 const bool emitDeallocs;
144 struct ExpandReallocPass
145 :
public memref::impl::ExpandReallocBase<ExpandReallocPass> {
146 ExpandReallocPass(
bool emitDeallocs)
147 : memref::
impl::ExpandReallocBase<ExpandReallocPass>() {
148 this->emitDeallocs.setValue(emitDeallocs);
150 void runOnOperation()
override {
157 target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
158 memref::MemRefDialect>();
159 target.addIllegalOp<memref::ReallocOp>();
174 return std::make_unique<ExpandReallocPass>(emitDeallocs);
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
std::unique_ptr< Pass > createExpandReallocPass(bool emitDeallocs=true)
Creates an operation pass to expand memref.realloc operations into their components.
void populateExpandReallocPatterns(RewritePatternSet &patterns, bool emitDeallocs=true)
Appends patterns for expanding memref.realloc operations.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...