20 namespace ml_program {
23 template <
typename Interface,
typename Op>
24 struct ExternalModelBase
25 :
public BufferizableOpInterface::ExternalModel<Interface, Op> {
39 struct GlobalOpInterface
40 :
public ExternalModelBase<GlobalOpInterface, GlobalOp> {
56 auto globalOp = cast<GlobalOp>(op);
57 if (!globalOp.getValue().has_value())
58 return globalOp.emitError(
"global op must have a value");
60 auto tensorType = cast<TensorType>(globalOp.getType());
63 replaceOpWithNewBufferizedOp<memref::GlobalOp>(
64 rewriter, globalOp, globalOp.getSymName(),
65 globalOp.getSymVisibilityAttr(),
66 cast<MemRefType>(memrefType),
67 globalOp.getValue().value(),
68 !globalOp.getIsMutable(),
76 struct GlobalLoadOpInterface
77 :
public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
95 auto globalLoadOp = cast<GlobalLoadOp>(op);
97 auto tensorType = cast<TensorType>(globalLoadOp.getType());
100 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
101 rewriter, globalLoadOp, memrefType,
102 globalLoadOp.getGlobalAttr().getLeafReference());
110 struct GlobalStoreOpInterface
111 :
public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
125 auto globalStoreOp = cast<GlobalStoreOp>(op);
127 auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
130 auto loc = globalStoreOp.getLoc();
131 auto targetMemref = rewriter.
create<memref::GetGlobalOp>(
132 loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
135 if (failed(sourceMemref)) {
140 options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
141 if (failed(memcpy)) {
144 rewriter.
eraseOp(globalStoreOp);
153 GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
154 GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
155 GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
BufferRelation
Specifies a fine-grain relationship between buffers to enable more analysis.
Include the generated interface declarations.
Options for BufferizableOpInterface-based bufferization.