MLIR 23.0.0git
ExpandRealloc.cpp
Go to the documentation of this file.
1//===- ExpandRealloc.cpp - Expand memref.realloc ops into its components --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11
16
17namespace mlir {
18namespace memref {
19#define GEN_PASS_DEF_EXPANDREALLOCPASS
20#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
21} // namespace memref
22} // namespace mlir
23
24using namespace mlir;
25
26namespace {
27
28/// The `realloc` operation performs a conditional allocation and copy to
29/// increase the size of a buffer if necessary. This pattern converts the
30/// `realloc` operation into this sequence of simpler operations.
31
32/// Example of an expansion:
33/// ```mlir
34/// %realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
35/// ```
36/// is expanded to
37/// ```mlir
38/// %c0 = arith.constant 0 : index
39/// %dim = memref.dim %alloc, %c0 : memref<?xf32>
40/// %is_old_smaller = arith.cmpi ult, %dim, %arg1
41/// %realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
42/// %new_alloc = memref.alloc(%size) : memref<?xf32>
43/// %subview = memref.subview %new_alloc[0] [%dim] [1]
44/// memref.copy %alloc, %subview
45/// memref.dealloc %alloc
46/// scf.yield %alloc_0 : memref<?xf32>
47/// } else {
48/// %reinterpret_cast = memref.reinterpret_cast %alloc to
49/// offset: [0], sizes: [%size], strides: [1]
50/// scf.yield %reinterpret_cast : memref<?xf32>
51/// }
52/// ```
53struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
54 ExpandReallocOpPattern(MLIRContext *ctx, bool emitDeallocs)
55 : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {}
56
57 LogicalResult matchAndRewrite(memref::ReallocOp op,
58 PatternRewriter &rewriter) const final {
59 Location loc = op.getLoc();
60 assert(op.getType().getRank() == 1 &&
61 "result MemRef must have exactly one rank");
62 assert(op.getSource().getType().getRank() == 1 &&
63 "source MemRef must have exactly one rank");
64 assert(op.getType().getLayout().isIdentity() &&
65 "result MemRef must have identity layout (or none)");
66 assert(op.getSource().getType().getLayout().isIdentity() &&
67 "source MemRef must have identity layout (or none)");
68
69 // Get the size of the original buffer.
70 int64_t inputSize =
71 cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
72 OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
73 if (ShapedType::isDynamic(inputSize)) {
74 Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
75 rewriter.getIndexAttr(0));
76 currSize = memref::DimOp::create(rewriter, loc, op.getSource(), dimZero)
77 .getResult();
78 }
79
80 // Get the requested size that the new buffer should have.
81 int64_t outputSize =
82 cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
83 OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
84 ? OpFoldResult{op.getDynamicResultSize()}
85 : rewriter.getIndexAttr(outputSize);
86
87 // Only allocate a new buffer and copy over the values in the old buffer if
88 // the old buffer is smaller than the requested size.
89 Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize);
90 Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize);
91 Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
92 lhs, rhs);
93 auto ifOp = scf::IfOp::create(
94 rewriter, loc, cond,
95 [&](OpBuilder &builder, Location loc) {
96 // Allocate the new buffer. If it is a dynamic memref we need to pass
97 // an additional operand for the size at runtime, otherwise the static
98 // size is encoded in the result type.
99 SmallVector<Value> dynamicSizeOperands;
100 if (op.getDynamicResultSize())
101 dynamicSizeOperands.push_back(op.getDynamicResultSize());
102
103 Value newAlloc = memref::AllocOp::create(
104 builder, loc, op.getResult().getType(), dynamicSizeOperands,
105 op.getAlignmentAttr());
106
107 // Take a subview of the new (bigger) buffer such that we can copy the
108 // old values over (the copy operation requires both operands to have
109 // the same shape).
110 Value subview = memref::SubViewOp::create(
111 builder, loc, newAlloc,
112 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
113 ArrayRef<OpFoldResult>{currSize},
114 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
115 memref::CopyOp::create(builder, loc, op.getSource(), subview);
116
117 // Insert the deallocation of the old buffer only if requested
118 // (enabled by default).
120 memref::DeallocOp::create(builder, loc, op.getSource());
122 scf::YieldOp::create(builder, loc, newAlloc);
123 },
124 [&](OpBuilder &builder, Location loc) {
125 // We need to reinterpret-cast here because either the input or output
126 // type might be static, which means we need to cast from static to
127 // dynamic or vice-versa. If both are static and the original buffer
128 // is already bigger than the requested size, the cast represents a
129 // subview operation.
130 Value casted = memref::ReinterpretCastOp::create(
131 builder, loc, cast<MemRefType>(op.getResult().getType()),
132 op.getSource(), rewriter.getIndexAttr(0),
133 ArrayRef<OpFoldResult>{targetSize},
134 ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
135 scf::YieldOp::create(builder, loc, casted);
136 });
138 rewriter.replaceOp(op, ifOp.getResult(0));
139 return success();
141
142private:
143 const bool emitDeallocs;
144};
145
146struct ExpandReallocPass
147 : public memref::impl::ExpandReallocPassBase<ExpandReallocPass> {
149
150 void runOnOperation() override {
151 MLIRContext &ctx = getContext();
152
153 RewritePatternSet patterns(&ctx);
156
157 target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
158 memref::MemRefDialect>();
159 target.addIllegalOp<memref::ReallocOp>();
160 if (failed(applyPartialConversion(getOperation(), target,
161 std::move(patterns))))
163 }
165
166} // namespace
167
169 bool emitDeallocs) {
170 patterns.add<ExpandReallocOpPattern>(patterns.getContext(), emitDeallocs);
return success()
lhs
b getContext())
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
OpT getOperation()
Return the current operation being transformed.
Definition Pass.h:389
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
void signalPassFailure()
Signal that some invariant was broken when running.
Definition Pass.h:226
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
::mlir::Pass::Option< bool > emitDeallocs
void populateExpandReallocPatterns(RewritePatternSet &patterns, bool emitDeallocs=true)
Appends patterns for expanding memref.realloc operations.
Include the generated interface declarations.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...