19 #define GEN_PASS_DEF_EXPANDREALLOC
20 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
54 ExpandReallocOpPattern(
MLIRContext *ctx,
bool emitDeallocs)
57 LogicalResult matchAndRewrite(memref::ReallocOp op,
60 assert(op.getType().getRank() == 1 &&
61 "result MemRef must have exactly one rank");
62 assert(op.getSource().getType().getRank() == 1 &&
63 "source MemRef must have exactly one rank");
64 assert(op.getType().getLayout().isIdentity() &&
65 "result MemRef must have identity layout (or none)");
66 assert(op.getSource().getType().getLayout().isIdentity() &&
67 "source MemRef must have identity layout (or none)");
71 cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
72 OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
73 if (ShapedType::isDynamic(inputSize)) {
75 rewriter.getIndexAttr(0));
76 currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero)
83 OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
85 : rewriter.getIndexAttr(outputSize);
91 Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
93 auto ifOp = rewriter.create<scf::IfOp>(
100 if (op.getDynamicResultSize())
101 dynamicSizeOperands.push_back(op.getDynamicResultSize());
103 Value newAlloc = builder.create<memref::AllocOp>(
105 op.getAlignmentAttr());
110 Value subview = builder.create<memref::SubViewOp>(
114 builder.create<memref::CopyOp>(loc, op.getSource(), subview);
119 builder.create<memref::DeallocOp>(loc, op.getSource());
121 builder.create<scf::YieldOp>(loc, newAlloc);
129 Value casted = builder.
create<memref::ReinterpretCastOp>(
133 builder.
create<scf::YieldOp>(loc, casted);
136 rewriter.replaceOp(op, ifOp.getResult(0));
141 const bool emitDeallocs;
144 struct ExpandReallocPass
145 :
public memref::impl::ExpandReallocBase<ExpandReallocPass> {
146 ExpandReallocPass(
bool emitDeallocs)
147 : memref::
impl::ExpandReallocBase<ExpandReallocPass>() {
148 this->emitDeallocs.setValue(emitDeallocs);
150 void runOnOperation()
override {
157 target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
158 memref::MemRefDialect>();
159 target.addIllegalOp<memref::ReallocOp>();
161 std::move(patterns))))
170 patterns.
add<ExpandReallocOpPattern>(patterns.
getContext(), emitDeallocs);
174 return std::make_unique<ExpandReallocPass>(emitDeallocs);
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
std::unique_ptr< Pass > createExpandReallocPass(bool emitDeallocs=true)
Creates an operation pass to expand memref.realloc operations into their components.
void populateExpandReallocPatterns(RewritePatternSet &patterns, bool emitDeallocs=true)
Appends patterns for expanding memref.realloc operations.
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.