MLIR  19.0.0git
BufferizationToMemRef.cpp
Go to the documentation of this file.
1 //===- BufferizationToMemRef.cpp - Bufferization to MemRef conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Bufferization dialect to MemRef
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/Pass/Pass.h"
26 
27 namespace mlir {
28 #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREF
29 #include "mlir/Conversion/Passes.h.inc"
30 } // namespace mlir
31 
32 using namespace mlir;
33 
34 namespace {
35 /// The CloneOpConversion transforms all bufferization clone operations into
36 /// memref alloc and memref copy operations. In the dynamic-shape case, it also
37 /// emits additional dim and constant operations to determine the shape. This
38 /// conversion does not resolve memory leaks if it is used alone.
39 struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> {
41 
43  matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
44  ConversionPatternRewriter &rewriter) const override {
45  // Check for unranked memref types which are currently not supported.
46  Type type = op.getType();
47  if (isa<UnrankedMemRefType>(type)) {
48  return rewriter.notifyMatchFailure(
49  op, "UnrankedMemRefType is not supported.");
50  }
51  MemRefType memrefType = cast<MemRefType>(type);
52  MemRefLayoutAttrInterface layout;
53  auto allocType =
54  MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
55  layout, memrefType.getMemorySpace());
56  // Since this implementation always allocates, certain result types of the
57  // clone op cannot be lowered.
58  if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
59  return failure();
60 
61  // Transform a clone operation into alloc + copy operation and pay
62  // attention to the shape dimensions.
63  Location loc = op->getLoc();
64  SmallVector<Value, 4> dynamicOperands;
65  for (int i = 0; i < memrefType.getRank(); ++i) {
66  if (!memrefType.isDynamicDim(i))
67  continue;
68  Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
69  dynamicOperands.push_back(dim);
70  }
71 
72  // Allocate a memref with identity layout.
73  Value alloc = rewriter.create<memref::AllocOp>(op->getLoc(), allocType,
74  dynamicOperands);
75  // Cast the allocation to the specified type if needed.
76  if (memrefType != allocType)
77  alloc = rewriter.create<memref::CastOp>(op->getLoc(), memrefType, alloc);
78  rewriter.replaceOp(op, alloc);
79  rewriter.create<memref::CopyOp>(loc, op.getInput(), alloc);
80  return success();
81  }
82 };
83 
84 } // namespace
85 
86 namespace {
87 struct BufferizationToMemRefPass
88  : public impl::ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
89  BufferizationToMemRefPass() = default;
90 
91  void runOnOperation() override {
92  if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
93  emitError(getOperation()->getLoc(),
94  "root operation must be a builtin.module or a function");
95  signalPassFailure();
96  return;
97  }
98 
99  func::FuncOp helperFuncOp;
100  if (auto module = dyn_cast<ModuleOp>(getOperation())) {
101  OpBuilder builder =
102  OpBuilder::atBlockBegin(&module.getBodyRegion().front());
103  SymbolTable symbolTable(module);
104 
105  // Build dealloc helper function if there are deallocs.
106  getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
107  if (deallocOp.getMemrefs().size() > 1) {
108  helperFuncOp = bufferization::buildDeallocationLibraryFunction(
109  builder, getOperation()->getLoc(), symbolTable);
110  return WalkResult::interrupt();
111  }
112  return WalkResult::advance();
113  });
114  }
115 
116  RewritePatternSet patterns(&getContext());
117  patterns.add<CloneOpConversion>(patterns.getContext());
119  helperFuncOp);
120 
121  ConversionTarget target(getContext());
122  target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
123  scf::SCFDialect, func::FuncDialect>();
124  target.addIllegalDialect<bufferization::BufferizationDialect>();
125 
126  if (failed(applyPartialConversion(getOperation(), target,
127  std::move(patterns))))
128  signalPassFailure();
129  }
130 };
131 } // namespace
132 
133 std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
134  return std::make_unique<BufferizationToMemRefPass>();
135 }
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still inside the block.
Definition: Builders.h:242
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Definition: SymbolTable.h:24
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
static WalkResult advance()
Definition: Visitors.h:52
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use in other transformation passes.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
std::unique_ptr< Pass > createBufferizationToMemRefPass()
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26