MLIR  20.0.0git
BufferizationToMemRef.cpp
Go to the documentation of this file.
1 //===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Bufferization dialect to MemRef
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/Pass/Pass.h"
25 
26 namespace mlir {
27 #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREF
28 #include "mlir/Conversion/Passes.h.inc"
29 } // namespace mlir
30 
31 using namespace mlir;
32 
33 namespace {
34 /// The CloneOpConversion transforms all bufferization clone operations into
35 /// memref alloc and memref copy operations. In the dynamic-shape case, it also
36 /// emits additional dim and constant operations to determine the shape. This
37 /// conversion does not resolve memory leaks if it is used alone.
38 struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
40 
41  LogicalResult
42  matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
43  ConversionPatternRewriter &rewriter) const override {
44  Location loc = op->getLoc();
45 
46  Type type = op.getType();
47  Value alloc;
48 
49  if (auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
50  // Constants
51  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
52  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
53 
54  // Dynamically evaluate the size and shape of the unranked memref
55  Value rank = rewriter.create<memref::RankOp>(loc, op.getInput());
56  MemRefType allocType =
57  MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
58  Value shape = rewriter.create<memref::AllocaOp>(loc, allocType, rank);
59 
60  // Create a loop to query dimension sizes, store them as a shape, and
61  // compute the total size of the memref
62  auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
63  ValueRange args) {
64  auto acc = args.front();
65  auto dim = rewriter.create<memref::DimOp>(loc, op.getInput(), i);
66 
67  rewriter.create<memref::StoreOp>(loc, dim, shape, i);
68  acc = rewriter.create<arith::MulIOp>(loc, acc, dim);
69 
70  rewriter.create<scf::YieldOp>(loc, acc);
71  };
72  auto size = rewriter
73  .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one),
74  loopBody)
75  .getResult(0);
76 
77  MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
78  unrankedType.getElementType());
79 
80  // Allocate new memref with 1D dynamic shape, then reshape into the
81  // shape of the original unranked memref
82  alloc = rewriter.create<memref::AllocOp>(loc, memrefType, size);
83  alloc =
84  rewriter.create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
85  } else {
86  MemRefType memrefType = cast<MemRefType>(type);
87  MemRefLayoutAttrInterface layout;
88  auto allocType =
89  MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
90  layout, memrefType.getMemorySpace());
91  // Since this implementation always allocates, certain result types of
92  // the clone op cannot be lowered.
93  if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
94  return failure();
95 
96  // Transform a clone operation into alloc + copy operation and pay
97  // attention to the shape dimensions.
98  SmallVector<Value, 4> dynamicOperands;
99  for (int i = 0; i < memrefType.getRank(); ++i) {
100  if (!memrefType.isDynamicDim(i))
101  continue;
102  Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
103  dynamicOperands.push_back(dim);
104  }
105 
106  // Allocate a memref with identity layout.
107  alloc = rewriter.create<memref::AllocOp>(loc, allocType, dynamicOperands);
108  // Cast the allocation to the specified type if needed.
109  if (memrefType != allocType)
110  alloc =
111  rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
112  }
113 
114  rewriter.replaceOp(op, alloc);
115  rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
116  return success();
117  }
118 };
119 
120 } // namespace
121 
122 namespace {
/// Pass that converts bufferization dialect ops to memref/arith/scf/func ops.
/// It registers CloneOpConversion plus the dealloc lowering patterns (which
/// receive deallocHelperFuncMap). It then runs a partial conversion in which
/// the whole bufferization dialect is illegal.
123 struct BufferizationToMemRefPass
124  : public impl::ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
125  BufferizationToMemRefPass() = default;
126 
127  void runOnOperation() override {
// Only a builtin.module or a function-like op is accepted as the root;
// anything else is reported as an error and fails the pass.
128  if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
129  emitError(getOperation()->getLoc(),
130  "root operation must be a builtin.module or a function");
131  signalPassFailure();
132  return;
133  }
134 
// Maps each symbol-table op to the dealloc helper function built for it.
// The map is only populated when the root is a module; for a function root
// it stays empty. NOTE(review): confirm how the dealloc lowering handles
// multi-memref deallocs when no helper is available.
135  bufferization::DeallocHelperMap deallocHelperFuncMap;
136  if (auto module = dyn_cast<ModuleOp>(getOperation())) {
137  OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
138 
139  // Build dealloc helper function if there are deallocs.
// A helper is built at most once per symbol-table op, and only for deallocs
// that take more than one memref.
// NOTE(review): this rendering omits the initializer of symtableOp (source
// line 142) and the callee producing helperFuncOp (source line 147).
// Presumably they are getParentWithTrait<OpTrait::SymbolTable>() and
// bufferization::buildDeallocationLibraryFunction(...); confirm against the
// real file.
140  getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
141  Operation *symtableOp =
143  if (deallocOp.getMemrefs().size() > 1 &&
144  !deallocHelperFuncMap.contains(symtableOp)) {
145  SymbolTable symbolTable(symtableOp);
146  func::FuncOp helperFuncOp =
148  builder, getOperation()->getLoc(), symbolTable);
149  deallocHelperFuncMap[symtableOp] = helperFuncOp;
150  }
151  });
152  }
153 
154  RewritePatternSet patterns(&getContext());
155  patterns.add<CloneOpConversion>(patterns.getContext());
// NOTE(review): the call that takes the next two arguments (source line 156)
// is missing from this rendering. Presumably it is
// bufferization::populateBufferizationDeallocLoweringPattern(...).
157  patterns, deallocHelperFuncMap);
158 
// The emitted memref/arith/scf/func ops are legal, and every bufferization
// op must be rewritten. Because this is a partial conversion, ops outside
// these dialects are not required to be converted.
159  ConversionTarget target(getContext());
160  target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
161  scf::SCFDialect, func::FuncDialect>();
162  target.addIllegalDialect<bufferization::BufferizationDialect>();
163 
// Any bufferization op that no pattern could rewrite makes the pass fail.
164  if (failed(applyPartialConversion(getOperation(), target,
165  std::move(patterns))))
166  signalPassFailure();
167  }
168 };
169 } // namespace
170 
171 std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
172  return std::make_unique<BufferizationToMemRefPass>();
173 }
static MLIRContext * getContext(OpFoldResult val)
IndexType getIndexType()
Definition: Builders.cpp:95
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:248
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use i...
func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable)
Construct the library function needed for the fully generic bufferization.dealloc lowering implemente...
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< Pass > createBufferizationToMemRefPass()
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.