21 namespace ml_program {
24 template <
typename Interface,
typename Op>
25 struct ExternalModelBase
26 :
public BufferizableOpInterface::ExternalModel<Interface, Op> {
40 struct GlobalOpInterface
41 :
public ExternalModelBase<GlobalOpInterface, GlobalOp> {
58 auto globalOp = cast<GlobalOp>(op);
59 if (!globalOp.getValue().has_value())
60 return globalOp.emitError(
"global op must have a value");
64 auto tensorType = cast<TensorType>(globalOp.getType());
67 auto replacement = replaceOpWithNewBufferizedOp<memref::GlobalOp>(
68 rewriter, globalOp, globalOp.getSymName(),
69 globalOp.getSymVisibilityAttr(),
70 cast<MemRefType>(memrefType),
71 globalOp.getValue().value(),
72 !globalOp.getIsMutable(),
81 struct GlobalLoadOpInterface
82 :
public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
101 auto globalLoadOp = cast<GlobalLoadOp>(op);
103 auto tensorType = cast<TensorType>(globalLoadOp.getType());
106 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
107 rewriter, globalLoadOp, memrefType,
108 globalLoadOp.getGlobalAttr().getLeafReference());
116 struct GlobalStoreOpInterface
117 :
public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
132 auto globalStoreOp = cast<GlobalStoreOp>(op);
134 auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
137 auto loc = globalStoreOp.getLoc();
138 auto targetMemref = rewriter.
create<memref::GetGlobalOp>(
139 loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
143 if (failed(sourceMemref)) {
148 options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
149 if (failed(memcpy)) {
152 rewriter.
eraseOp(globalStoreOp);
161 GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
162 GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
163 GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
BufferizationState provides information about the state of the IR during the bufferization process.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
void insertSymbol(Operation *op, BufferizationState &state)
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
void removeSymbol(Operation *op, BufferizationState &state)
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
BufferRelation
Specifies a fine-grain relationship between buffers to enable more analysis.
Include the generated interface declarations.
Options for BufferizableOpInterface-based bufferization.