26#define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREFPASS
27#include "mlir/Conversion/Passes.h.inc"
// Conversion pattern that rewrites a `bufferization.clone` op into an explicit
// allocation plus a `memref.copy` of the input, with one path for unranked
// result types and one for ranked result types.
// NOTE(review): this listing has the original line numbers fused into each
// line and skips some original lines (see the gaps in the numbering), so a few
// statements below appear without their beginning or continuation.
37struct CloneOpConversion :
public OpConversionPattern<bufferization::CloneOp> {
// Reuse the constructors of the base pattern class.
38 using OpConversionPattern<bufferization::CloneOp>::OpConversionPattern;
// `op` is the clone being rewritten and all new ops are created through
// `rewriter`; `adaptor` is not referenced in the lines visible here.
41 matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
42 ConversionPatternRewriter &rewriter)
const override {
43 Location loc = op->getLoc();
// The result type of the clone decides which path is taken.
45 Type type = op.getType();
// Unranked path: the shape is only known at runtime, so it is read dimension
// by dimension and the data is allocated as a flat 1-D buffer.
48 if (
auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
// Runtime rank of the input.
54 Value rank = memref::RankOp::create(rewriter, loc, op.getInput());
// 1-D, dynamically sized buffer of `index` values (one slot per dimension),
// created with memref.alloca and sized by `rank`.
55 MemRefType allocType =
56 MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType());
57 Value shape = memref::AllocaOp::create(rewriter, loc, allocType, rank);
// Loop body for induction value `i`: read dimension `i` of the input, store
// it into `shape[i]`, and multiply it into the running accumulator, which is
// yielded to the next iteration.
// NOTE(review): the rest of the lambda signature (it introduces `args`) is on
// a line missing from this listing; the body creates its ops through the
// captured `rewriter`, not the `builder` parameter.
61 auto loopBody = [&](OpBuilder &builder, Location loc, Value i,
63 auto acc = args.front();
64 auto dim = memref::DimOp::create(rewriter, loc, op.getInput(), i);
66 memref::StoreOp::create(rewriter, loc, dim, shape, i);
67 acc = arith::MulIOp::create(rewriter, loc, acc, dim);
69 scf::YieldOp::create(rewriter, loc, acc);
// scf.for from `zero` up to `rank` in steps of `one`; its result is named
// `size`. `zero`, `one` and the remaining arguments of this call are on lines
// not shown here (presumably index constants and the initial accumulator —
// TODO confirm against the full file).
71 auto size = scf::ForOp::create(rewriter, loc, zero, rank, one,
// Flat 1-D dynamically sized memref with the clone's element type; `size` is
// passed as its dynamic extent.
75 MemRefType memrefType = MemRefType::get({ShapedType::kDynamic},
76 unrankedType.getElementType());
// `alloc` is declared on a line missing from this listing.
80 alloc = memref::AllocOp::create(rewriter, loc, memrefType, size);
// Reshape the flat buffer to the unranked result type using the collected
// `shape`. NOTE(review): the left-hand side of this statement is not visible;
// presumably the result is assigned back to `alloc` — confirm.
82 memref::ReshapeOp::create(rewriter, loc, unrankedType, alloc, shape);
// Ranked path (the `else` itself is on a skipped line). The allocation type
// keeps the shape, element type and memory space of the result type but uses
// the default-constructed `layout` declared below instead of the result's own
// layout.
84 MemRefType memrefType = cast<MemRefType>(type);
85 MemRefLayoutAttrInterface layout;
87 MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
88 layout, memrefType.getMemorySpace());
// Guard on cast compatibility between the allocation type and the result
// type; the statement executed when they are incompatible is not visible
// (presumably a match failure — confirm).
91 if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
// Collect one `memref.dim` value of the input per dynamic dimension of the
// result type; these become the dynamic sizes of the allocation.
96 SmallVector<Value, 4> dynamicOperands;
97 for (
int i = 0; i < memrefType.getRank(); ++i) {
// Static dimensions need no operand; the statement controlled by this `if`
// is on a skipped line (presumably `continue`).
98 if (!memrefType.isDynamicDim(i))
100 Value dim = rewriter.createOrFold<memref::DimOp>(loc, op.getInput(), i);
101 dynamicOperands.push_back(dim);
106 memref::AllocOp::create(rewriter, loc, allocType, dynamicOperands);
// When the allocation type differs from the requested result type, cast to
// the result type.
108 if (memrefType != allocType)
110 memref::CastOp::create(rewriter, op->getLoc(), memrefType, alloc);
// Copy the input into the new buffer and make it the replacement for the
// clone op.
113 memref::CopyOp::create(rewriter, loc, op.getInput(), alloc);
114 rewriter.replaceOp(op, alloc);
// Pass that runs a partial dialect conversion over its root operation in which
// the bufferization dialect is illegal and the memref, arith, scf and func
// dialects are legal.
// NOTE(review): this listing has the original line numbers fused into each
// line and skips some original lines (the base-class name, several statement
// beginnings and all closing braces are not visible).
122struct BufferizationToMemRefPass
124 BufferizationToMemRefPass> {
125 BufferizationToMemRefPass() =
default;
127 void runOnOperation()
override {
// Only builtin.module and function-like roots are accepted; the string below
// is the diagnostic text (the call that emits it is on a skipped line).
128 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
130 "root operation must be a builtin.module or a function");
// Module roots get extra handling. NOTE(review): the body of this branch and
// where it ends are not recoverable from the listing.
136 if (
auto module = dyn_cast<ModuleOp>(getOperation())) {
// Visit every bufferization.dealloc under the root. For a dealloc with more
// than one memref whose nearest enclosing symbol-table op has no entry in
// `deallocHelperFuncMap` yet, build a helper function and record it, keyed
// by that symbol-table op.
140 getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
141 Operation *symtableOp =
142 deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
143 if (deallocOp.getMemrefs().size() > 1 &&
144 !deallocHelperFuncMap.contains(symtableOp)) {
145 SymbolTable symbolTable(symtableOp);
// The callee name is on a skipped line; the argument list matches
// buildDeallocationLibraryFunction(builder, loc, symbolTable) — confirm.
146 func::FuncOp helperFuncOp =
148 builder, getOperation()->getLoc(), symbolTable);
149 deallocHelperFuncMap[symtableOp] = helperFuncOp;
// Conversion target: ops of these four dialects may remain, while any
// bufferization dialect op is illegal.
160 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
161 scf::SCFDialect, func::FuncDialect>();
162 target.addIllegalDialect<bufferization::BufferizationDialect>();
// Run the partial conversion on the root and branch on failure; the rest of
// the call and the failure handling are cut off in this listing.
164 if (
failed(applyPartialConversion(getOperation(),
target,
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still inside the block.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use in other transformation passes.
llvm::DenseMap< Operation *, func::FuncOp > DeallocHelperMap
Maps from symbol table to its corresponding dealloc helper function.
func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable)
Construct the library function needed for the fully generic bufferization.dealloc lowering implemented in the LowerDeallocations pass.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns