27 #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREF
28 #include "mlir/Conversion/Passes.h.inc"
42 matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
46 Type type = op.getType();
49 if (
auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
51 Value zero = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
52 Value one = rewriter.
create<arith::ConstantIndexOp>(loc, 1);
55 Value rank = rewriter.
create<memref::RankOp>(loc, op.getInput());
56 MemRefType allocType =
58 Value shape = rewriter.
create<memref::AllocaOp>(loc, allocType, rank);
64 auto acc = args.front();
65 auto dim = rewriter.
create<memref::DimOp>(loc, op.getInput(), i);
67 rewriter.
create<memref::StoreOp>(loc, dim, shape, i);
68 acc = rewriter.
create<arith::MulIOp>(loc, acc, dim);
70 rewriter.
create<scf::YieldOp>(loc, acc);
78 unrankedType.getElementType());
82 alloc = rewriter.
create<memref::AllocOp>(loc, memrefType, size);
84 rewriter.
create<memref::ReshapeOp>(loc, unrankedType, alloc, shape);
86 MemRefType memrefType = cast<MemRefType>(type);
87 MemRefLayoutAttrInterface layout;
90 layout, memrefType.getMemorySpace());
93 if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
99 for (
int i = 0; i < memrefType.getRank(); ++i) {
100 if (!memrefType.isDynamicDim(i))
103 dynamicOperands.push_back(dim);
107 alloc = rewriter.
create<memref::AllocOp>(loc, allocType, dynamicOperands);
109 if (memrefType != allocType)
111 rewriter.
create<memref::CastOp>(op->getLoc(), memrefType, alloc);
115 rewriter.
create<memref::CopyOp>(loc, op.getInput(), alloc);
123 struct BufferizationToMemRefPass
124 :
public impl::ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
125 BufferizationToMemRefPass() =
default;
127 void runOnOperation()
override {
128 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
130 "root operation must be a builtin.module or a function");
136 if (
auto module = dyn_cast<ModuleOp>(getOperation())) {
140 getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
143 if (deallocOp.getMemrefs().size() > 1 &&
144 !deallocHelperFuncMap.contains(symtableOp)) {
146 func::FuncOp helperFuncOp =
148 builder, getOperation()->getLoc(), symbolTable);
149 deallocHelperFuncMap[symtableOp] = helperFuncOp;
160 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
161 scf::SCFDialect, func::FuncDialect>();
162 target.addIllegalDialect<bufferization::BufferizationDialect>();
172 return std::make_unique<BufferizationToMemRefPass>();
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use i...
func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable)
Construct the library function needed for the fully generic bufferization.dealloc lowering implemente...
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< Pass > createBufferizationToMemRefPass()
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.