MLIR  21.0.0git
BufferizableOpInterfaceImpl.cpp
Go to the documentation of this file.
1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 
16 using namespace mlir;
17 using namespace mlir::bufferization;
18 using namespace mlir::ml_program;
19 
20 namespace mlir {
21 namespace ml_program {
22 namespace {
23 
24 template <typename Interface, typename Op>
25 struct ExternalModelBase
26  : public BufferizableOpInterface::ExternalModel<Interface, Op> {
27 
28  AliasingValueList getAliasingValues(Operation *, OpOperand &,
29  const AnalysisState &) const {
30  return {};
31  }
32 
33  BufferRelation bufferRelation(Operation *, OpResult,
34  const AnalysisState &) const {
36  }
37 };
38 
39 /// Bufferization of ml_program.global into a memref.global
40 struct GlobalOpInterface
41  : public ExternalModelBase<GlobalOpInterface, GlobalOp> {
42 
43  bool bufferizesToMemoryRead(Operation *, OpOperand &,
44  const AnalysisState &) const {
45  return false;
46  }
47 
48  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
49  const AnalysisState &) const {
50  return false;
51  }
52 
53  bool hasTensorSemantics(Operation *) const { return true; }
54 
55  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
56  const BufferizationOptions &,
57  BufferizationState &state) const {
58  auto globalOp = cast<GlobalOp>(op);
59  if (!globalOp.getValue().has_value())
60  return globalOp.emitError("global op must have a value");
61 
62  bufferization::removeSymbol(globalOp, state);
63 
64  auto tensorType = cast<TensorType>(globalOp.getType());
65  auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
66 
67  auto replacement = replaceOpWithNewBufferizedOp<memref::GlobalOp>(
68  rewriter, globalOp, globalOp.getSymName(),
69  /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
70  /*type=*/cast<MemRefType>(memrefType),
71  /*initial_value=*/globalOp.getValue().value(),
72  /*constant=*/!globalOp.getIsMutable(),
73  /*alignment=*/nullptr);
74 
75  bufferization::insertSymbol(replacement, state);
76  return success();
77  }
78 };
79 
80 /// Bufferization of ml_program.global_load into a memref.get_global
81 struct GlobalLoadOpInterface
82  : public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
83 
84  bool bufferizesToMemoryRead(Operation *, OpOperand &,
85  const AnalysisState &) const {
86  return false;
87  }
88 
89  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
90  const AnalysisState &) const {
91  return false;
92  }
93 
94  bool isWritable(Operation *, Value, const AnalysisState &) const {
95  return false;
96  }
97 
98  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
99  const BufferizationOptions &,
100  BufferizationState &state) const {
101  auto globalLoadOp = cast<GlobalLoadOp>(op);
102 
103  auto tensorType = cast<TensorType>(globalLoadOp.getType());
104  auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
105 
106  replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
107  rewriter, globalLoadOp, memrefType,
108  globalLoadOp.getGlobalAttr().getLeafReference());
109 
110  return success();
111  }
112 };
113 
114 /// Bufferization of ml_program.global_store into a memref.get_global and
115 /// memcpy
116 struct GlobalStoreOpInterface
117  : public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
118 
119  bool bufferizesToMemoryRead(Operation *, OpOperand &,
120  const AnalysisState &) const {
121  return false;
122  }
123 
124  bool bufferizesToMemoryWrite(Operation *, OpOperand &,
125  const AnalysisState &) const {
126  return true;
127  }
128 
129  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
131  BufferizationState &state) const {
132  auto globalStoreOp = cast<GlobalStoreOp>(op);
133 
134  auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
135  auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
136 
137  auto loc = globalStoreOp.getLoc();
138  auto targetMemref = rewriter.create<memref::GetGlobalOp>(
139  loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
140 
141  auto sourceMemref =
142  getBuffer(rewriter, globalStoreOp.getValue(), options, state);
143  if (failed(sourceMemref)) {
144  return failure();
145  }
146 
147  auto memcpy =
148  options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
149  if (failed(memcpy)) {
150  return failure();
151  }
152  rewriter.eraseOp(globalStoreOp);
153 
154  return success();
155  }
156 };
157 } // namespace
158 
160  registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
161  GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
162  GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
163  GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
164  });
165 }
166 } // namespace ml_program
167 } // namespace mlir
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents an operand of an operation.
Definition: Value.h:257
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:358
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
BufferizationState provides information about the state of the IR during the bufferization process.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
void insertSymbol(Operation *op, BufferizationState &state)
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
void removeSymbol(Operation *op, BufferizationState &state)
bool hasTensorSemantics(Operation *op)
Return "true" if the given op has tensor semantics and should be bufferized.
BufferRelation
Specifies a fine-grain relationship between buffers to enable more analysis.
Include the generated interface declarations.
Options for BufferizableOpInterface-based bufferization.