MLIR  22.0.0git
ExpandRealloc.cpp
Go to the documentation of this file.
1 //===- ExpandRealloc.cpp - Expand memref.realloc ops into its components --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
16 
17 namespace mlir {
18 namespace memref {
19 #define GEN_PASS_DEF_EXPANDREALLOCPASS
20 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
21 } // namespace memref
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 namespace {
27 
28 /// The `realloc` operation performs a conditional allocation and copy to
29 /// increase the size of a buffer if necessary. This pattern converts the
30 /// `realloc` operation into this sequence of simpler operations.
31 
32 /// Example of an expansion:
33 /// ```mlir
34 /// %realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
35 /// ```
36 /// is expanded to
37 /// ```mlir
38 /// %c0 = arith.constant 0 : index
39 /// %dim = memref.dim %alloc, %c0 : memref<?xf32>
40 /// %is_old_smaller = arith.cmpi ult, %dim, %arg1
41 /// %realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
42 /// %new_alloc = memref.alloc(%size) : memref<?xf32>
43 /// %subview = memref.subview %new_alloc[0] [%dim] [1]
44 /// memref.copy %alloc, %subview
45 /// memref.dealloc %alloc
46 /// scf.yield %alloc_0 : memref<?xf32>
47 /// } else {
48 /// %reinterpret_cast = memref.reinterpret_cast %alloc to
49 /// offset: [0], sizes: [%size], strides: [1]
50 /// scf.yield %reinterpret_cast : memref<?xf32>
51 /// }
52 /// ```
53 struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
54  ExpandReallocOpPattern(MLIRContext *ctx, bool emitDeallocs)
55  : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {}
56 
57  LogicalResult matchAndRewrite(memref::ReallocOp op,
58  PatternRewriter &rewriter) const final {
59  Location loc = op.getLoc();
60  assert(op.getType().getRank() == 1 &&
61  "result MemRef must have exactly one rank");
62  assert(op.getSource().getType().getRank() == 1 &&
63  "source MemRef must have exactly one rank");
64  assert(op.getType().getLayout().isIdentity() &&
65  "result MemRef must have identity layout (or none)");
66  assert(op.getSource().getType().getLayout().isIdentity() &&
67  "source MemRef must have identity layout (or none)");
68 
69  // Get the size of the original buffer.
70  int64_t inputSize =
71  cast<BaseMemRefType>(op.getSource().getType()).getDimSize(0);
72  OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
73  if (ShapedType::isDynamic(inputSize)) {
74  Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
75  rewriter.getIndexAttr(0));
76  currSize = memref::DimOp::create(rewriter, loc, op.getSource(), dimZero)
77  .getResult();
78  }
79 
80  // Get the requested size that the new buffer should have.
81  int64_t outputSize =
82  cast<BaseMemRefType>(op.getResult().getType()).getDimSize(0);
83  OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
84  ? OpFoldResult{op.getDynamicResultSize()}
85  : rewriter.getIndexAttr(outputSize);
86 
87  // Only allocate a new buffer and copy over the values in the old buffer if
88  // the old buffer is smaller than the requested size.
89  Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize);
90  Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize);
91  Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
92  lhs, rhs);
93  auto ifOp = scf::IfOp::create(
94  rewriter, loc, cond,
95  [&](OpBuilder &builder, Location loc) {
96  // Allocate the new buffer. If it is a dynamic memref we need to pass
97  // an additional operand for the size at runtime, otherwise the static
98  // size is encoded in the result type.
99  SmallVector<Value> dynamicSizeOperands;
100  if (op.getDynamicResultSize())
101  dynamicSizeOperands.push_back(op.getDynamicResultSize());
102 
103  Value newAlloc = memref::AllocOp::create(
104  builder, loc, op.getResult().getType(), dynamicSizeOperands,
105  op.getAlignmentAttr());
106 
107  // Take a subview of the new (bigger) buffer such that we can copy the
108  // old values over (the copy operation requires both operands to have
109  // the same shape).
110  Value subview = memref::SubViewOp::create(
111  builder, loc, newAlloc,
112  ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
113  ArrayRef<OpFoldResult>{currSize},
114  ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
115  memref::CopyOp::create(builder, loc, op.getSource(), subview);
116 
117  // Insert the deallocation of the old buffer only if requested
118  // (enabled by default).
119  if (emitDeallocs)
120  memref::DeallocOp::create(builder, loc, op.getSource());
121 
122  scf::YieldOp::create(builder, loc, newAlloc);
123  },
124  [&](OpBuilder &builder, Location loc) {
125  // We need to reinterpret-cast here because either the input or output
126  // type might be static, which means we need to cast from static to
127  // dynamic or vice-versa. If both are static and the original buffer
128  // is already bigger than the requested size, the cast represents a
129  // subview operation.
130  Value casted = memref::ReinterpretCastOp::create(
131  builder, loc, cast<MemRefType>(op.getResult().getType()),
132  op.getSource(), rewriter.getIndexAttr(0),
133  ArrayRef<OpFoldResult>{targetSize},
134  ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
135  scf::YieldOp::create(builder, loc, casted);
136  });
137 
138  rewriter.replaceOp(op, ifOp.getResult(0));
139  return success();
140  }
141 
142 private:
143  const bool emitDeallocs;
144 };
145 
146 struct ExpandReallocPass
147  : public memref::impl::ExpandReallocPassBase<ExpandReallocPass> {
148  using Base::Base;
149 
150  void runOnOperation() override {
151  MLIRContext &ctx = getContext();
152 
154  memref::populateExpandReallocPatterns(patterns, emitDeallocs.getValue());
155  ConversionTarget target(ctx);
156 
157  target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
158  memref::MemRefDialect>();
159  target.addIllegalOp<memref::ReallocOp>();
160  if (failed(applyPartialConversion(getOperation(), target,
161  std::move(patterns))))
162  signalPassFailure();
163  }
164 };
165 
166 } // namespace
167 
169  bool emitDeallocs) {
170  patterns.add<ExpandReallocOpPattern>(patterns.getContext(), emitDeallocs);
171 }
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:205
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateExpandReallocPatterns(RewritePatternSet &patterns, bool emitDeallocs=true)
Appends patterns for expanding memref.realloc operations.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314