MLIR  20.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
14 
15 using namespace mlir;
16 using namespace mlir::bufferization;
17 using namespace mlir::ml_program;
18 
19 namespace mlir {
20 namespace ml_program {
21 namespace {
22 
23 template <typename Interface, typename Op>
24 struct ExternalModelBase
25  : public BufferizableOpInterface::ExternalModel<Interface, Op> {
26 
27  AliasingValueList getAliasingValues(Operation *, OpOperand &,
28  const AnalysisState &) const {
29  return {};
30  }
31 
32  BufferRelation bufferRelation(Operation *, OpResult,
33  const AnalysisState &) const {
35  }
36 };
37 
38 /// Bufferization of ml_program.global into a memref.global
39 struct GlobalOpInterface
40  : public ExternalModelBase<GlobalOpInterface, GlobalOp> {
41 
42  bool bufferizesToMemoryRead(Operation *, OpOperand &,
43  const AnalysisState &) const {
44  return false;
45  }
46 
47  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
48  const AnalysisState &) const {
49  return false;
50  }
51 
52  bool hasTensorSemantics(Operation *) const { return true; }
53 
54  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
55  const BufferizationOptions &) const {
56  auto globalOp = cast<GlobalOp>(op);
57  if (!globalOp.getValue().has_value())
58  return globalOp.emitError("global op must have a value");
59 
60  auto tensorType = cast<TensorType>(globalOp.getType());
61  auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
62 
63  replaceOpWithNewBufferizedOp<memref::GlobalOp>(
64  rewriter, globalOp, globalOp.getSymName(),
65  /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
66  /*type=*/cast<MemRefType>(memrefType),
67  /*initial_value=*/globalOp.getValue().value(),
68  /*constant=*/!globalOp.getIsMutable(),
69  /*alignment=*/nullptr);
70 
71  return success();
72  }
73 };
74 
75 /// Bufferization of ml_program.global_load into a memref.get_global
76 struct GlobalLoadOpInterface
77  : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
78 
79  bool bufferizesToMemoryRead(Operation *, OpOperand &,
80  const AnalysisState &) const {
81  return false;
82  }
83 
84  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
85  const AnalysisState &) const {
86  return false;
87  }
88 
89  bool isWritable(Operation *, Value, const AnalysisState &) const {
90  return false;
91  }
92 
93  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
94  const BufferizationOptions &) const {
95  auto globalLoadOp = cast<GlobalLoadOp>(op);
96 
97  auto tensorType = cast<TensorType>(globalLoadOp.getType());
98  auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
99 
100  replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
101  rewriter, globalLoadOp, memrefType,
102  globalLoadOp.getGlobalAttr().getLeafReference());
103 
104  return success();
105  }
106 };
107 
108 /// Bufferization of ml_program.global_store into a memref.get_global and
109 /// memcpy
110 struct GlobalStoreOpInterface
111  : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
112 
113  bool bufferizesToMemoryRead(Operation *, OpOperand &,
114  const AnalysisState &) const {
115  return false;
116  }
117 
118  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
119  const AnalysisState &) const {
120  return true;
121  }
122 
123  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
124  const BufferizationOptions &options) const {
125  auto globalStoreOp = cast<GlobalStoreOp>(op);
126 
127  auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
128  auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
129 
130  auto loc = globalStoreOp.getLoc();
131  auto targetMemref = rewriter.create<memref::GetGlobalOp>(
132  loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
133 
134  auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
135  if (failed(sourceMemref)) {
136  return failure();
137  }
138 
139  auto memcpy =
140  options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
141  if (failed(memcpy)) {
142  return failure();
143  }
144  rewriter.eraseOp(globalStoreOp);
145 
146  return success();
147  }
148 };
149 } // namespace
150 
152  registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
153  GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
154  GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
155  GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
156  });
157 }
158 } // namespace ml_program
159 } // namespace mlir
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
BufferRelation
Specifies a fine-grain relationship between buffers to enable more analysis.
Include the generated interface declarations.
Options for BufferizableOpInterface-based bufferization.