MLIR  14.0.0git
BufferizableOpInterface.h
Go to the documentation of this file.
1 //===- BufferizableOpInterface.h - Bufferizable Ops -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
10 #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
11 
12 #include <utility>
13 
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Operation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Support/LLVM.h"
21 #include "llvm/ADT/SetVector.h"
22 
23 namespace mlir {
24 class BlockAndValueMapping;
25 class DominanceInfo;
26 class FuncOp;
27 
28 namespace bufferization {
29 
/// Hard-coded buffer alignment constant. Presumably the alignment requested
/// for newly created buffers -- TODO confirm against the allocation helpers.
/// TODO: derive this value from some HW description.
static constexpr int64_t kBufferAlignments{128};
32 
33 class BufferizableOpInterface;
35 class BufferizationState;
36 
37 /// Options for ComprehensiveBufferize.
39  using AllocationFn = std::function<FailureOr<Value>(
40  OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
41  using DeallocationFn =
42  std::function<LogicalResult(OpBuilder &, Location, Value)>;
43  using MemCpyFn =
44  std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
45 
47 
48  // BufferizationOptions cannot be copied.
49  BufferizationOptions(const BufferizationOptions &other) = delete;
50 
51  /// Return `true` if the op is allowed to be bufferized.
52  bool isOpAllowed(Operation *op) const {
53  if (!dialectFilter.hasValue())
54  return true;
55  return dialectFilter->contains(op->getDialect()->getNamespace());
56  }
57 
58  /// Allow-list the given dialects in the dialect filter. Only ops from
59  /// allow-listed dialects will be bufferized. If no dialect is added, ops from
60  /// any dialect will be bufferized.
61  template <typename... DialectTs>
63  // The following expands a call to addToDialectFilterImpl for each dialect
64  // in 'DialectTs'. This magic is necessary due to a limitation in the places
65  // that a parameter pack can be expanded in c++11.
66  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
67  (void)std::initializer_list<int>{
68  0, (addToDialectFilterImpl<DialectTs>(), 0)...};
69  }
70 
71  /// Try to cast the given op to BufferizableOpInterface if the op is
72  /// allow-listed.
73  BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
74 
75  /// Try to cast the given value to BufferizableOpInterface if the op is
76  /// allow-listed.
77  BufferizableOpInterface dynCastBufferizableOp(Value value) const;
78 
79  /// Helper functions for allocation, deallocation, memory copying.
83 
84  /// Specifies whether returning newly allocated memrefs should be allowed.
85  /// Otherwise, a pass failure is triggered.
86  bool allowReturnMemref = false;
87 
88  /// Specifies whether non-bufferizable ops are allowed in the input. If so,
89  /// bufferization.to_memref and bufferization.to_tensor ops are inserted at
90  /// the boundaries.
91  bool allowUnknownOps = false;
92 
93  /// Specifies whether dealloc ops should be generated along with alloc ops. If
94  /// not, new memory allocations will leak.
95  bool createDeallocs = true;
96 
97  /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
98  /// Should be used only with `testAnalysisOnly = true`.
99  unsigned analysisFuzzerSeed = 0;
100 
101  /// If set to `true`, does not modify the IR apart from adding attributes (for
102  /// checking the results of the analysis) and post analysis steps.
103  bool testAnalysisOnly = false;
104 
105  /// If set to `true`, the IR is annotated with details about RaW conflicts.
106  /// For debugging only. Should be used together with `testAnalysisOnly`.
107  bool printConflicts = false;
108 
109  /// Only bufferize ops from dialects that are allow-listed by the filter.
110  /// All other ops are ignored. This option controls the scope of partial
111  /// bufferization.
112  ///
113  /// Note: If no filter is specified, all ops are bufferized (as long as they
114  /// implement BufferizableOpInterface). If a filter is specified,
115  /// `allowUnknownOps` should be enabled. Otherwise, bufferization would fail
116  /// when encountering an op that is forbidden by the filter.
118 
119 private:
120  /// Allow-list a dialect in the dialect filter.
121  template <typename DialectT>
122  void addToDialectFilterImpl() {
123  if (!dialectFilter.hasValue())
124  dialectFilter.emplace();
125  dialectFilter->insert(DialectT::getDialectNamespace());
126  }
127 };
128 
/// Fine-grained relationship between the buffer of an OpResult and the
/// buffer(s) of its aliasing OpOperand(s). Knowing this relationship enables
/// more precise analysis.
enum class BufferRelation {
  /// No particular relationship is known.
  None,
  // TODO: ResultContainsOperand,
  // TODO: OperandContainsResult,
  /// The buffers are equivalent.
  Equivalent,
};
136 
137 /// Return `true` if the given value is a BlockArgument of a FuncOp.
139 
140 /// Dialect-specific bufferization state. Analysis/bufferization information
141 /// that is specific to ops from a certain dialect can be stored in derived
142 /// variants of this struct.
144  DialectBufferizationState() = default;
145 
146  virtual ~DialectBufferizationState() = default;
147 
148  // Copying state is forbidden. Always pass as reference.
150 };
151 
152 /// BufferizationState provides a variety of helper functions for dealing with
153 /// tensor values and memref buffers.
155 public:
156  /// Determine which OpOperand* will alias with `result` if the op is
157  /// bufferized in place. Return an empty vector if the op is not bufferizable.
158  SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
159 
160  /// Determine which OpResult will alias with `opOperand` if the op is
161  /// bufferized in place. Return an empty OpResult if the op is not
162  /// bufferizable.
163  OpResult getAliasingOpResult(OpOperand &opOperand) const;
164 
165  /// Return true if `opOperand` bufferizes to a memory read. Return `true` if
166  /// the op is not bufferizable.
167  bool bufferizesToMemoryRead(OpOperand &opOperand) const;
168 
169  /// Return `true` if `opOperand` bufferizes to a memory write. Return `true` if
170  /// the op is not bufferizable.
171  bool bufferizesToMemoryWrite(OpOperand &opOperand) const;
172 
173  /// Return `true` if `opOperand` neither reads nor writes but bufferizes to
174  /// an alias. Return `false` if the op is not bufferizable.
175  bool bufferizesToAliasOnly(OpOperand &opOperand) const;
176 
177  /// Return true if the given value is read by an op that bufferizes to a
178  /// memory read. Also takes into account ops that create an alias but do not
179  /// read by themselves (e.g., ExtractSliceOp).
180  bool isValueRead(Value value) const;
181 
182  /// Starting from `value`, follow the use-def chain in reverse, always
183  /// selecting the aliasing OpOperands. Find and return Values for which
184  /// `condition` evaluates to true. OpOperands of such matching Values are not
185  /// traversed any further.
186  ///
187  /// When reaching the end of a chain (BlockArgument or Value without aliasing
188  /// OpOperands), also return the last Value of that chain.
189  ///
190  /// Example:
191  ///
192  /// 8
193  /// |
194  /// 6* 7* +-----+----+
195  /// | | | |
196  /// 2* 3 4* 5
197  /// | | | |
198  /// +----------+----------+----------+
199  /// |
200  /// 1
201  ///
202  /// In the above example, Values with a star satisfy the condition. When
203  /// starting the traversal from Value 1, the resulting SetVector is:
204  /// { 2, 7, 8, 5 }
205  SetVector<Value> findValueInReverseUseDefChain(
206  Value value, llvm::function_ref<bool(Value)> condition) const;
207 
208  /// Find the Values of the last preceding write of a given Value.
209  ///
210  /// Note: Unknown ops are handled conservatively and assumed to be writes.
211  /// Furthermore, BlockArguments are also assumed to be writes. There is no
212  /// analysis across block boundaries.
213  ///
214  /// Note: When reaching an end of the reverse SSA use-def chain, that value
215  /// is returned regardless of whether it is a memory write or not.
216  SetVector<Value> findLastPrecedingWrite(Value value) const;
217 
218  /// Return `true` if the given OpOperand has been decided to bufferize in place.
219  virtual bool isInPlace(OpOperand &opOperand) const = 0;
220 
221  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
222  virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
223 
224  /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
225  /// a new buffer and copy over data from the existing buffer if out-of-place
226  /// bufferization was decided.
228  getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
229  bool forceInPlace = false,
230  Optional<Operation *> customCopyInsertionPoint = None) const;
231 
232  /// Return dialect-specific bufferization state.
233  template <typename StateT>
235  auto it = dialectState.find(name);
236  if (it == dialectState.end())
237  return None;
238  return static_cast<const StateT *>(it->getSecond().get());
239  }
240 
241  /// Return dialect-specific bufferization state or create one if none exists.
242  template <typename StateT>
243  StateT &getOrCreateDialectState(StringRef name) {
244  // Create state if it does not exist yet.
245  if (!dialectState.count(name))
246  dialectState[name] = std::make_unique<StateT>();
247  return static_cast<StateT &>(*dialectState[name]);
248  }
249 
250  /// Return a reference to the BufferizationOptions.
251  const BufferizationOptions &getOptions() const { return options; }
252 
253 protected:
255 
256  // BufferizationState should be passed as a reference.
257  BufferizationState(const BufferizationState &) = delete;
258 
259  ~BufferizationState() = default;
260 
261 private:
262  /// Dialect-specific bufferization state.
264 
265  /// A reference to current bufferization options.
267 };
268 
269 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
270 /// must be replaced with memref values.
272  ValueRange values);
273 
274 /// Replace an op with a new op. The new op must have the same number of
275 /// results as the replaced op. The new op may not return any tensor values.
276 template <typename OpTy, typename... Args>
278  Args &&...args) {
279  auto newOp = rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
280  replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
281  return newOp;
282 }
283 
284 /// Return a MemRefType with the same shape as `shapedType` and the specified
285 /// `layout` and `memorySpace`. With the default (empty) `layout`, the result
286 /// is contiguous, i.e. it has a canonical/empty layout map.
287 MemRefType getContiguousMemRefType(ShapedType shapedType,
288  MemRefLayoutAttrInterface layout = {},
289  Attribute memorySpace = {});
290 
291 /// Return an UnrankedMemRefType with the given element type and memory space.
293  Attribute memorySpace = {});
294 
295 /// Return a MemRefType to which the `tensorType` can be bufferized in a
296 /// composable fashion. The layout must be the most dynamic possible and
297 /// canonicalize away once bufferization is finished.
298 MemRefType getDynamicMemRefType(RankedTensorType tensorType,
299  unsigned addressSpace = 0);
300 
301 /// Creates a memref allocation.
302 FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
303  ArrayRef<Value> dynShape,
305 
306 /// Creates a memref allocation for the given shaped value. This function may
307 /// perform additional optimizations such as buffer allocation hoisting. If
308 /// `deallocMemref` is set, a deallocation op is inserted at the point where
309 /// the allocation goes out of scope.
311  bool deallocMemref,
312  const BufferizationOptions &options);
313 
314 /// Creates a memref deallocation. The given memref buffer must have been
315 /// allocated using `createAlloc`.
316 LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
317  const BufferizationOptions &options);
318 
319 /// Creates a memcpy between two given buffers.
321  const BufferizationOptions &options);
322 
323 } // namespace bufferization
324 } // namespace mlir
325 
326 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
327 
328 namespace mlir {
329 namespace bufferization {
330 
331 /// AllocationHoistingBarrierOnly is an external implementation of
332 /// BufferizableOpInterface for ops that are (not yet) bufferizable, but are
333 /// known to be allocation hoisting barriers. All interface methods (except for
334 /// `isAllocationHoistingBarrier`) are implemented conservatively.
335 template <typename OpTy>
337  : public BufferizableOpInterface::ExternalModel<
338  AllocationHoistingBarrierOnly<OpTy>, OpTy> {
340  const BufferizationState &state) const {
341  return true;
342  }
343 
345  const BufferizationState &state) const {
346  return true;
347  }
348 
351  const BufferizationState &state) const {
352  return {};
353  }
354 
356  const BufferizationState &state) const {
357  return OpResult();
358  }
359 
361  const BufferizationState &state) const {
362  return BufferRelation::None;
363  }
364 
366  const BufferizationState &state) const {
367  return false;
368  }
369 
371  const BufferizationState &state) const {
372  return failure();
373  }
374 
375  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
376 };
377 
378 } // namespace bufferization
379 } // namespace mlir
380 
381 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
Dialect-specific bufferization state.
This is a value defined by a result of an operation.
Definition: Value.h:423
static constexpr int64_t kBufferAlignments
bool testAnalysisOnly
If set to true, does not modify the IR apart from adding attributes (for checking the results of the ...
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
void addToDialectFilter()
Allow-list the given dialects in the dialect filter.
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Optional< AllocationFn > allocationFn
Helper functions for allocation, deallocation, memory copying.
Optional< DenseSet< StringRef > > dialectFilter
Only bufferize ops from dialects that are allowed-listed by the filter.
bool allowUnknownOps
Specifies whether not bufferizable ops are allowed in the input.
static constexpr const bool value
std::function< FailureOr< Value >(OpBuilder &, Location, MemRefType, ArrayRef< Value >)> AllocationFn
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, const BufferizationState &state) const
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const BufferizationState &state) const
std::function< LogicalResult(OpBuilder &, Location, Value)> DeallocationFn
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:77
BufferRelation bufferRelation(Operation *op, OpResult opResult, const BufferizationState &state) const
MemRefType getContiguousMemRefType(ShapedType shapedType, MemRefLayoutAttrInterface layout={}, Attribute memorySpace={})
Return a contiguous MemRefType (i.e.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
bool createDeallocs
Specifies whether dealloc ops should be generated along with alloc ops.
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const
Try to cast the given op to BufferizableOpInterface if the op is allow listed.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
bool isFunctionArgument(Value value)
Return true if the given value is a BlockArgument of a FuncOp.
Options for ComprehensiveBufferize.
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const BufferizationState &state) const
LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, const BufferizationOptions &options)
Creates a memref deallocation.
AllocationHoistingBarrierOnly is an external implementation of BufferizableOpInterface for ops that a...
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationState &state) const
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to, const BufferizationOptions &options)
Creates a memcpy between two given buffers.
MemRefType getDynamicMemRefType(RankedTensorType tensorType, unsigned addressSpace=0)
Return a MemRefType to which the tensorType can be bufferized in a composable fashion.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
Optional< const StateT * > getDialectState(StringRef name) const
Return dialect-specific bufferization state.
bool printConflicts
If set to true, the IR is annotated with details about RaW conflicts.
SmallVector< OpOperand * > getAliasingOpOperand(Operation *op, OpResult opResult, const BufferizationState &state) const
static llvm::ManagedStatic< PassManagerOptions > options
static bool isInPlace(Value val)
Returns true if tensor has an in-place annotation.
OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op, Args &&...args)
Replace an op with a new op.
BufferizationState provides a variety of helper functions for dealing with tensor values and memref b...
bool isOpAllowed(Operation *op) const
Return true if the op is allowed to be bufferized.
unsigned analysisFuzzerSeed
Seed for the analysis fuzzer.
StringRef getNamespace() const
Definition: Dialect.h:58
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:103
This class represents an operand of an operation.
Definition: Value.h:249
StateT & getOrCreateDialectState(StringRef name)
Return dialect-specific bufferization state or create one if none exists.
const BufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
BufferRelation
Specify fine-grain relationship between buffers to enable more analysis.
bool allowReturnMemref
Specifies whether returning newly allocated memrefs should be allowed.
bool isWritable(Operation *op, Value value, const BufferizationState &state) const
result_range getResults()
Definition: Operation.h:284
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
UnrankedMemRefType getUnrankedMemRefType(Type elementType, Attribute memorySpace={})
Return an UnrankedMemRefType with the given element type and memory space.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:688
FailureOr< Value > createAlloc(OpBuilder &b, Location loc, MemRefType type, ArrayRef< Value > dynShape, const BufferizationOptions &options)
Creates a memref allocation.