28 #define GEN_PASS_DEF_CONVERTBUFFERIZATIONTOMEMREF
29 #include "mlir/Conversion/Passes.h.inc"
43 matchAndRewrite(bufferization::CloneOp op, OpAdaptor adaptor,
46 Type type = op.getType();
47 if (isa<UnrankedMemRefType>(type)) {
49 op,
"UnrankedMemRefType is not supported.");
51 MemRefType memrefType = cast<MemRefType>(type);
52 MemRefLayoutAttrInterface layout;
55 layout, memrefType.getMemorySpace());
58 if (!memref::CastOp::areCastCompatible({allocType}, {memrefType}))
65 for (
int i = 0; i < memrefType.getRank(); ++i) {
66 if (!memrefType.isDynamicDim(i))
69 dynamicOperands.push_back(dim);
76 if (memrefType != allocType)
77 alloc = rewriter.
create<memref::CastOp>(op->
getLoc(), memrefType, alloc);
79 rewriter.
create<memref::CopyOp>(loc, op.getInput(), alloc);
87 struct BufferizationToMemRefPass
88 :
public impl::ConvertBufferizationToMemRefBase<BufferizationToMemRefPass> {
89 BufferizationToMemRefPass() =
default;
91 void runOnOperation()
override {
92 if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
94 "root operation must be a builtin.module or a function");
99 func::FuncOp helperFuncOp;
100 if (
auto module = dyn_cast<ModuleOp>(getOperation())) {
106 getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
107 if (deallocOp.getMemrefs().size() > 1) {
108 helperFuncOp = bufferization::buildDeallocationLibraryFunction(
109 builder, getOperation()->getLoc(), symbolTable);
110 return WalkResult::interrupt();
117 patterns.add<CloneOpConversion>(patterns.getContext());
122 target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
123 scf::SCFDialect, func::FuncDialect>();
124 target.addIllegalDialect<bufferization::BufferizationDialect>();
127 std::move(patterns))))
134 return std::make_unique<BufferizationToMemRefPass>();
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Location getLoc()
The source location the operation was defined or derived from.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
void populateBufferizationDeallocLoweringPattern(RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc)
Adds the conversion pattern of the bufferization.dealloc operation to the given pattern set for use i...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< Pass > createBufferizationToMemRefPass()
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.