MLIR  22.0.0git
BufferizationToMemRef.cpp
Go to the documentation of this file.
1 //===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Bufferization dialect to MemRef
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
22 #include "mlir/IR/BuiltinTypes.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREFPASS
27 #include "mlir/Conversion/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 
32 namespace {
33 /// The CloneOpConversion transforms all bufferization clone operations into
34 /// memref alloc and memref copy operations. In the dynamic-shape case, it also
35 /// emits additional dim and constant operations to determine the shape. This
36 /// conversion does not resolve memory leaks if it is used alone.
37 struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
39 
40  LogicalResult
41  matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
42  ConversionPatternRewriter &rewriter) const override {
43  Location loc = op->getLoc();
44 
45  Type type = op.getType();
46  Value alloc;
47 
48  if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
49  // Constants
50  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
51  Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
52 
53  // Dynamically evaluate the size and shape of the unranked memref
54  Value rank = memref::RankOp::create(rewriter, loc, op.getInput());
55  MemRefType allocType =
56  MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
57  Value shape = memref::AllocaOp::create(rewriter, loc, allocType, rank);
58 
59  // Create a loop to query dimension sizes, store them as a shape, and
60  // compute the total size of the memref
61  auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
62  ValueRange args) {
63  auto acc = args.front();
64  auto dim = memref::DimOp::create(rewriter, loc, op.getInput(), i);
65 
66  memref::StoreOp::create(rewriter, loc, dim, shape, i);
67  acc = arith::MulIOp::create(rewriter, loc, acc, dim);
68 
69  scf::YieldOp::create(rewriter, loc, acc);
70  };
71  auto size = scf::ForOp::create(rewriter, loc, zero, rank, one,
72  ValueRange(one), loopBody)
73  .getResult(0);
74 
75  MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
76  unrankedType.getElementType());
77 
78  // Allocate new memref with 1D dynamic shape, then reshape into the
79  // shape of the original unranked memref
80  alloc = memref::AllocOp::create(rewriter, loc, memrefType, size);
81  alloc =
82  memref::ReshapeOp::create(rewriter, loc, unrankedType, alloc, shape);
83  } else {
84  MemRefType memrefType = cast<MemRefType>(type);
85  MemRefLayoutAttrInterface layout;
86  auto allocType =
87  MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
88  layout, memrefType.getMemorySpace());
89  // Since this implementation always allocates, certain result types of
90  // the clone op cannot be lowered.
91  if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
92  return failure();
93 
94  // Transform a clone operation into alloc + copy operation and pay
95  // attention to the shape dimensions.
96  SmallVector<Value, 4> dynamicOperands;
97  for (int i = 0; i < memrefType.getRank(); ++i) {
98  if (!memrefType.isDynamicDim(i))
99  continue;
100  Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
101  dynamicOperands.push_back(dim);
102  }
103 
104  // Allocate a memref with identity layout.
105  alloc =
106  memref::AllocOp::create(rewriter, loc, allocType, dynamicOperands);
107  // Cast the allocation to the specified type if needed.
108  if (memrefType != allocType)
109  alloc =
110  memref::CastOp::create(rewriter, op->getLoc(), memrefType, alloc);
111  }
112 
113  memref::CopyOp::create(rewriter, loc, op.getInput(), alloc);
114  rewriter.replaceOp(op, alloc);
115  return success();
116  }
117 };
118 
119 } // namespace
120 
121 namespace {
122 struct BufferizationToMemRefPass
123  : public impl::ConvertBufferizationToMemRefPassBase<
124  BufferizationToMemRefPass> {
125  BufferizationToMemRefPass() = default;
126 
127  void runOnOperation() override {
128  if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
129  emitError(getOperation()->getLoc(),
130  "root operation must be a builtin.module or a function");
131  signalPassFailure();
132  return;
133  }
134 
135  bufferization::DeallocHelperMap deallocHelperFuncMap;
136  if (auto module = dyn_cast<ModuleOp>(getOperation())) {
137  OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
138 
139  // Build dealloc helper function if there are deallocs.
140  getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
141  Operation *symtableOp =
143  if (deallocOp.getMemrefs().size() > 1 &&
144  !deallocHelperFuncMap.contains(symtableOp)) {
145  SymbolTable symbolTable(symtableOp);
146  func::FuncOp helperFuncOp =
148  builder, getOperation()->getLoc(), symbolTable);
149  deallocHelperFuncMap[symtableOp] = helperFuncOp;
150  }
151  });
152  }
153 
155  patterns.add<CloneOpConversion>(patterns.getContext());
157  patterns, deallocHelperFuncMap);
158 
159  ConversionTarget target(getContext());
160  target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
161  scf::SCFDialect, func::FuncDialect>();
162  target.addIllegalDialect<bufferization::BufferizationDialect>();
163 
164  if (failed(applyPartialConversion(getOperation(), target,
165  std::move(patterns))))
166  signalPassFailure();
167  }
168 };
169 } // namespace
static MLIRContext * getContext(OpFoldResult val)
IndexType getIndexType()
Definition: Builders.cpp:50
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:238
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use i...
func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable)
Construct the library function needed for the fully generic bufferization.dealloc lowering implemente...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.