26 #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREFPASS
27 #include "mlir/Conversion/Passes.h.inc"
41 matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
45 Type type = op.getType();
48 if (
auto unrankedType = dyn_cast<UnrankedMemRefType>(type)) {
54 Value rank = memref::RankOp::create(rewriter, loc, op.getInput());
55 MemRefType allocType =
57 Value shape = memref::AllocaOp::create(rewriter, loc, allocType, rank);
63 auto acc = args.front();
64 auto dim = memref::DimOp::create(rewriter, loc, op.getInput(), i);
66 memref::StoreOp::create(rewriter, loc, dim, shape, i);
67 acc = arith::MulIOp::create(rewriter, loc, acc, dim);
69 scf::YieldOp::create(rewriter, loc, acc);
71 auto size = scf::ForOp::create(rewriter, loc, zero, rank, one,
76 unrankedType.getElementType());
80 alloc = memref::AllocOp::create(rewriter, loc, memrefType, size);
82 memref::ReshapeOp::create(rewriter, loc, unrankedType, alloc, shape);
84 MemRefType memrefType = cast<MemRefType>(type);
85 MemRefLayoutAttrInterface layout;
88 layout, memrefType.getMemorySpace());
91 if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
97 for (
int i = 0; i < memrefType.getRank(); ++i) {
98 if (!memrefType.isDynamicDim(i))
101 dynamicOperands.push_back(dim);
106 memref::AllocOp::create(rewriter, loc, allocType, dynamicOperands);
108 if (memrefType != allocType)
110 memref::CastOp::create(rewriter, op->getLoc(), memrefType, alloc);
113 memref::CopyOp::create(rewriter, loc, op.getInput(), alloc);
122 struct BufferizationToMemRefPass
123 :
public impl::ConvertBufferizationToMemRefPassBase<
124 BufferizationToMemRefPass> {
125 BufferizationToMemRefPass() =
default;
127 void runOnOperation()
override {
128 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
130 "root operation must be a builtin.module or a function");
136 if (
auto module = dyn_cast<ModuleOp>(getOperation())) {
140 getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
143 if (deallocOp.getMemrefs().size() > 1 &&
144 !deallocHelperFuncMap.contains(symtableOp)) {
146 func::FuncOp helperFuncOp =
148 builder, getOperation()->getLoc(), symbolTable);
149 deallocHelperFuncMap[symtableOp] = helperFuncOp;
160 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
161 scf::SCFDialect, func::FuncDialect>();
162 target.addIllegalDialect<bufferization::BufferizationDialect>();
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still inside the block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use in other transformation passes.
func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc, SymbolTable &symbolTable)
Construct the library function needed for the fully generic bufferization.dealloc lowering implemented in the LowerDeallocations pass.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.