MLIR 22.0.0git
BufferizableOpInterface.h
Go to the documentation of this file.
1//===- BufferizableOpInterface.h - Bufferizable Ops -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
10#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
11
12#include "mlir/IR/Operation.h"
14#include "mlir/Support/LLVM.h"
15#include "llvm/ADT/DenseMapInfoVariant.h"
16#include "llvm/ADT/SetVector.h"
17#include <optional>
18
19#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
21
22namespace mlir {
23class OpBuilder;
24namespace func {
25class FuncOp;
26}
27
28namespace bufferization {
29
30class AnalysisState;
31class BufferizableOpInterface;
32
/// Describes how two buffers relate to each other. A more precise relation
/// gives the analysis more information to work with.
enum class BufferRelation {
  /// Nothing is known about the relationship between the buffers.
  Unknown,
  // TODO: ResultContainsOperand,
  // TODO: OperandContainsResult,
  /// The buffers are equivalent.
  Equivalent
};
40
41/// A maybe aliasing OpOperand. If `isDefinite` is `true`, the OpOperand is
42/// guaranteed to alias at runtime.
43struct AliasingOpOperand {
44 AliasingOpOperand(OpOperand *opOperand, BufferRelation relation,
45 bool isDefinite = true)
46 : opOperand(opOperand), relation(relation), isDefinite(isDefinite) {}
47
48 OpOperand *opOperand;
49 BufferRelation relation;
50 bool isDefinite;
51};
52
53/// A maybe aliasing Value. If `isDefinite` is `true`, the Value is guaranteed
54/// to alias at runtime.
55struct AliasingValue {
56 AliasingValue(Value value, BufferRelation relation, bool isDefinite = true)
57 : value(value), relation(relation), isDefinite(isDefinite) {}
58
59 Value value;
60 BufferRelation relation;
61 bool isDefinite;
62};
63
64template <typename T>
65class AliasList {
66public:
67 /// Create an empty list of aliases.
68 AliasList() = default;
69
70 /// Create a list of aliases.
71 AliasList(std::initializer_list<T> elems) {
72 for (T alias : elems)
73 addAlias(alias);
74 }
75
76 /// Create a list of aliases.
77 AliasList(SmallVector<T> &&aliases) : aliases(std::move(aliases)) {}
78
79 ArrayRef<T> getAliases() const { return aliases; }
80
81 size_t getNumAliases() const { return aliases.size(); }
82
83 void addAlias(T alias) { aliases.push_back(alias); }
84
85 auto begin() const { return aliases.begin(); }
86 auto end() const { return aliases.end(); }
87
88private:
89 /// The list of aliases.
90 SmallVector<T> aliases;
91};
92
93/// A list of possible aliasing OpOperands. This list models the runtime
94/// aliasing relationship for a Value.
95using AliasingOpOperandList = AliasList<AliasingOpOperand>;
96
97/// A list of possible aliasing Values. This list models the runtime aliasing
98/// relationship for an OpOperand.
99using AliasingValueList = AliasList<AliasingValue>;
100
101class OpFilter {
102public:
103 /// An op filter entry. Filters can be used to specify which ops should be
104 /// processed by the bufferization.
105 struct Entry {
106 /// If the filter function evaluates to `true`, the filter matches.
107 using FilterFn = std::function<bool(Operation *)>;
108
109 /// Filter type: A filter can either be a DENY filter or an ALLOW filter.
110 enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
111
112 FilterFn fn;
113 FilterType type;
114 };
115
116 /// Return whether the op is allowed or not.
117 ///
118 /// If the filter does not have an ALLOW rule, ops are allowed by default,
119 /// unless they are explicitly marked as DENY. If the filter has at least one
120 /// ALLOW rule, ops are denied by default and only allowed if they match
121 /// an ALLOW rule and no DENY rule.
122 bool isOpAllowed(Operation *op) const;
123
124 /// Allow the given dialects.
125 ///
126 /// This function adds one or multiple ALLOW entries.
127 template <typename... DialectTs>
128 void allowDialect() {
129 // The following expands a call to allowDialectImpl for each dialect
130 // in 'DialectTs'.
131 (allowDialectImpl<DialectTs>(), ...);
132 }
133
134 /// Deny the given dialects.
135 ///
136 /// This function adds one or multiple DENY entries.
137 template <typename... DialectTs>
138 void denyDialect() {
139 (denyDialectImpl<DialectTs>(), ...);
140 }
141
142 /// Allow the given dialect.
143 ///
144 /// This function adds an ALLOW entry.
145 void allowDialect(StringRef dialectNamespace) {
146 Entry::FilterFn filterFn = [=](Operation *op) {
147 return op->getName().getDialectNamespace() == dialectNamespace;
148 };
149 entries.push_back(Entry{filterFn, Entry::FilterType::ALLOW});
150 }
151
152 /// Deny the given dialect.
153 ///
154 /// This function adds a DENY entry.
155 void denyDialect(StringRef dialectNamespace) {
156 Entry::FilterFn filterFn = [=](Operation *op) {
157 return op->getName().getDialectNamespace() == dialectNamespace;
158 };
159 entries.push_back(Entry{filterFn, Entry::FilterType::DENY});
160 }
161
162 /// Allow the given ops.
163 ///
164 /// This function adds one or multiple ALLOW entries.
165 template <typename... OpTys>
166 void allowOperation() {
167 (allowOperationImpl<OpTys>(), ...);
168 }
169
170 /// Deny the given ops.
171 ///
172 /// This function adds one or multiple DENY entries.
173 template <typename... OpTys>
174 void denyOperation() {
175 (denyOperationImpl<OpTys>(), ...);
176 }
177
178 /// Allow the given op.
179 ///
180 /// This function adds an ALLOW entry.
181 void allowOperation(StringRef opName) {
182 Entry::FilterFn filterFn = [=](Operation *op) {
183 return op->getName().getStringRef() == opName;
184 };
185 allowOperation(filterFn);
186 }
187
188 /// Deny the given op.
189 ///
190 /// This function adds a DENY entry.
191 void denyOperation(StringRef opName) {
192 Entry::FilterFn filterFn = [=](Operation *op) {
193 return op->getName().getStringRef() == opName;
194 };
195 denyOperation(filterFn);
196 }
197
198 /// Allow ops that are matched by `fn`.
199 ///
200 /// This function adds an ALLOW entry.
201 void allowOperation(Entry::FilterFn fn) {
202 entries.push_back(Entry{fn, Entry::FilterType::ALLOW});
203 }
204
205 /// Deny ops that are matched by `fn`.
206 ///
207 /// This function adds a DENY entry.
208 void denyOperation(Entry::FilterFn fn) {
209 entries.push_back(Entry{fn, Entry::FilterType::DENY});
210 }
211
212private:
213 /// Return `true` if the filter has at least one ALLOW rule.
214 bool hasAllowRule() const {
215 for (const Entry &e : entries)
216 if (e.type == Entry::FilterType::ALLOW)
217 return true;
218 return false;
219 }
220
221 /// Allow a dialect.
222 template <typename DialectT>
223 void allowDialectImpl() {
224 allowDialect(DialectT::getDialectNamespace());
225 }
226
227 /// Deny a dialect.
228 template <typename DialectT>
229 void denyDialectImpl() {
230 denyDialect(DialectT::getDialectNamespace());
231 }
232
233 /// Allow an op.
234 template <typename OpTy>
235 void allowOperationImpl() {
236 allowOperation(OpTy::getOperationName());
237 }
238
239 /// Deny an op.
240 template <typename OpTy>
241 void denyOperationImpl() {
242 denyOperation(OpTy::getOperationName());
243 }
244
245 /// A list of filter entries that determine whether an op should be allowed or
246 /// denied. If the filter has an ALLOW rule, only ops that are allowed and not
247 /// denied are allowed. If the filter does not have an ALLOW rule, only ops
248 /// that are not denied are allowed.
249 SmallVector<Entry> entries;
250};
251
252/// Options for BufferizableOpInterface-based bufferization.
253struct BufferizationOptions {
254 /// Allocator function: Generate a memref allocation with the given type,
255 /// dynamic extents and alignment.
256 using AllocationFn = std::function<FailureOr<Value>(
257 OpBuilder &, Location, MemRefType, ValueRange, unsigned int)>;
258 /// Memcpy function: Generate a memcpy between two buffers.
259 using MemCpyFn =
260 std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
261 /// Initializer function for analysis state.
262 using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
263 /// Tensor-like -> Buffer-like type conversion.
264 /// Parameters: tensor-like type, memory space, func op, bufferization options
265 using FunctionArgTypeConverterFn =
266 std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
267 func::FuncOp, const BufferizationOptions &)>;
268 /// Tensor -> MemRef type conversion.
269 /// Parameters: tensor type, memory space, bufferization options
270 using UnknownTypeConverterFn = std::function<BaseMemRefType(
271 TensorType, Attribute memorySpace, const BufferizationOptions &)>;
272 // Produce a MemorySpace attribute from a tensor type
273 using DefaultMemorySpaceFn =
274 std::function<std::optional<Attribute>(TensorType t)>;
275
276 BufferizationOptions();
277
278 /// Try to cast the given op to BufferizableOpInterface if the op is allow
279 /// listed.
280 BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
281
282 /// Try to cast the given value to BufferizableOpInterface if the op is allow
283 /// listed.
284 BufferizableOpInterface dynCastBufferizableOp(Value value) const;
285
286 /// A filter that specifies which ops should be bufferized and which ops
287 /// should be ignored.
288 OpFilter opFilter;
289
290 /// Return `true` if the given op should be bufferized.
291 bool isOpAllowed(Operation *op) const;
292
293 /// Helper functions for allocation and memory copying.
294 std::optional<AllocationFn> allocationFn;
295 std::optional<MemCpyFn> memCpyFn;
296
297 /// Create a memref allocation with the given type and dynamic extents.
298 FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
299 ValueRange dynShape) const;
300
301 /// Creates a memcpy between two given buffers.
302 LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from,
303 Value to) const;
304
305 /// Specifies whether not bufferizable ops are allowed in the input. If so,
306 /// bufferization.to_buffer and bufferization.to_tensor ops are inserted at
307 /// the boundaries.
308 bool allowUnknownOps = false;
309
310 /// Specifies whether function boundaries (ops in the func dialect) should be
311 /// bufferized or not.
312 bool bufferizeFunctionBoundaries = false;
313
314 // Specifies whether to account for parallel regions in RaW analysis. If true,
315 // then writes inside of parallel regions that write to buffers defined
316 // outside of the parallel region will be given a new buffer.
317 bool checkParallelRegions = true;
318
319 /// This function controls buffer types on function signatures. Sets
320 /// `functionArgTypeConverterFn` and `inferFunctionResultLayout` accordingly.
321 ///
322 /// * InferLayoutMap: All function parameter types have a fully dynamic layout
323 /// map, but function result types are inferred from the body of the
324 /// function.
325 /// * FullyDynamicLayoutMap: All function parameter types and result types
326 /// have a fully dynamic layout map. This option is most efficient because
327 /// any layout map can be casted to a fully dynamic one.
328 /// * IdentityLayoutMap: All function parameter types and result types have a
329 /// static identity layout (i.e., no layout map). This option may introduce
330 /// additional buffer allocs and copies because layout maps cannot be casted
331 /// away.
332 ///
333 /// Note: Inferred layout maps may not be desireable when interacting with
334 /// external functions, because the generated function signatures will be less
335 /// predictable.
336 void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
337
338 /// Type conversion from tensors to buffers. This type conversion is used to
339 /// determine bufferized function argument and result types.
340 ///
341 /// By default, if tensor is a (builtin) tensor type, it is converted to a
342 /// memref type with a fully dynamic layout map; if tensor is a (generic)
343 /// tensor-like type, it is converted using TensorLikeType::getBufferType().
344 ///
345 /// If `bufferizeFunctionBoundaries` is not set, this function isn't used.
346 FunctionArgTypeConverterFn functionArgTypeConverterFn = nullptr;
347
348 /// If true, function result types are inferred from the body of the function.
349 /// Otherwise, function result type is determined by
350 /// `functionArgTypeConverterFn`.
351 ///
352 /// If `bufferizeFunctionBoundaries` is not set, this flag has no effect.
353 bool inferFunctionResultLayout = true;
354
355 /// Type conversion from tensors to memrefs. This type conversion is used if
356 /// no memref type could be inferred during bufferization. By default, returns
357 /// a memref type with a fully dynamic layout map.
358 UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
359
360 // Use during type conversion to determine the memory space for memref based
361 // on the original tensor type if the memory space cannot be inferred.
362 // Returning std::nullopt will cause bufferization to fail (useful to indicate
363 // failure to determine memory space for a tensor type).
364 DefaultMemorySpaceFn defaultMemorySpaceFn =
365 [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
366
367 /// If set to `true`, the analysis is skipped. A buffer is copied before every
368 /// write. This flag cannot be used together with `testAnalysisOnly = true`.
369 bool copyBeforeWrite = false;
370
371 /// If set to `true`, does not modify the IR apart from adding attributes (for
372 /// checking the results of the analysis) and post analysis steps.
373 bool testAnalysisOnly = false;
374
375 /// If set to `true`, the IR is annotated with details about RaW conflicts.
376 /// For debugging only. Should be used together with `testAnalysisOnly`.
377 bool printConflicts = false;
378
379 /// Buffer alignment for new memory allocations.
380 unsigned int bufferAlignment = 64;
381
382 /// Initializer functions for analysis state. These can be used to
383 /// initialize dialect-specific analysis state.
384 SmallVector<AnalysisStateInitFn> stateInitializers;
385};
386
/// Knobs that steer the traversal in `findValueInReverseUseDefChain`.
struct TraversalConfig {
  /// If set, leaves (values with no further OpOperands to follow) are
  /// returned even when they do not match the specified filter.
  bool alwaysIncludeLeaves = true;

  /// Specifies whether out-of-place/undecided OpOperands should be followed.
  bool followInPlaceOnly = false;

  /// Specifies whether non-equivalent OpOperands should be followed.
  bool followEquivalentOnly = false;

  /// Specifies whether unknown ops, non-bufferizable ops, and ops that are not
  /// included in the OpFilter of BufferizationOptions should be followed.
  bool followUnknownOps = false;

  /// Specifies whether OpOperands with a different type that are not the
  /// result of a CastOpInterface op should be followed.
  bool followSameTypeOrCastsOnly = false;

  /// Specifies whether already visited values should be visited again.
  /// (Note: This can result in infinite looping.)
  bool revisitAlreadyVisitedValues = false;
};
411
412/// AnalysisState provides a variety of helper functions for dealing with
413/// tensor values.
414class AnalysisState {
415public:
416 /// Determine which OpOperand* will alias with `value` if the op is
417 /// bufferized in place. Return all tensor OpOperand* if the op is not
418 /// bufferizable.
419 AliasingOpOperandList getAliasingOpOperands(Value value) const;
420
421 /// Determine which Value will alias with `opOperand` if the op is bufferized
422 /// in place. Return all tensor Values if the op is not bufferizable.
423 AliasingValueList getAliasingValues(OpOperand &opOperand) const;
424
425 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if
426 /// the op is not bufferizable.
427 bool bufferizesToMemoryRead(OpOperand &opOperand) const;
428
429 /// Return true if `opOperand` bufferizes to a memory write. Return true` if
430 /// the op is not bufferizable.
431 bool bufferizesToMemoryWrite(OpOperand &opOperand) const;
432
433 /// Return true if the given `value` bufferizes to a memory write. Return
434 /// true if the value is a block argument. Return `true` if the defining op is
435 /// not bufferizable. Otherwise, consult the BufferizableOpInterface.
436 bool bufferizesToMemoryWrite(Value value) const;
437
438 /// Return true if `opOperand` does neither read nor write but bufferizes to
439 /// an alias. Return false if the op is not bufferizable.
440 bool bufferizesToAliasOnly(OpOperand &opOperand) const;
441
442 /// Return true if a copy can always be avoided when allocating a new tensor
443 /// for the given OpOperand.
444 bool canOmitTensorCopy(OpOperand &opOperand) const;
445
446 /// Return true if the given value is read by an op that bufferizes to a
447 /// memory read. Also takes into account ops that create an alias but do not
448 /// read by themselves (e.g., ExtractSliceOp).
449 bool isValueRead(Value value) const;
450
451 /// Starting from `opOperand`, follow the use-def chain in reverse, always
452 /// selecting the aliasing OpOperands. Find and return Values for which
453 /// `condition` evaluates to true. OpOperands of such matching Values are not
454 /// traversed any further, the visited aliasing opOperands will be preserved
455 /// through `visitedOpOperands`.
456 ///
457 /// When reaching the end of a chain, also return the last Value of that
458 /// chain if `config.alwaysIncludeLeaves` is set.
459 ///
460 /// Example:
461 ///
462 /// 8
463 /// |
464 /// 6* 7* +-----+----+
465 /// | | | |
466 /// 2* 3 4* 5
467 /// | | | |
468 /// +----------+----------+----------+
469 /// |
470 /// 1
471 ///
472 /// In the above example, Values with a star satisfy the condition. When
473 /// starting the traversal from Value 1, the resulting SetVector is:
474 /// { 2, 7, 8, 5 }
475 ///
476 /// Additional stopping conditions for the traversal can be specified in
477 /// `config`.
478 SetVector<Value> findValueInReverseUseDefChain(
479 OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
480 TraversalConfig config = TraversalConfig(),
481 llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
482
483 /// Find the values that may define the contents of the given value at
484 /// runtime. A block argument is always a definition. An OpResult is a
485 /// definition if it bufferizes to memory write. If it does not bufferize to
486 /// a memory write but has aliasing operands, we continue the lookup on these
487 /// values.
488 ///
489 /// Example: %r = tensor.insert %f into %t[%c0] : tensor<?xf32>
490 /// findDefinitions(%r) = {%r} because %r bufferizes to memory write.
491 ///
492 /// Example: %r = tensor.empty() : tensor<10xf32>
493 /// findDefinitions(%r) = {} because tensor.empty does not the define the
494 /// contents of its result (i.e., it does not bufferize to a memory write)
495 /// and it has no aliasing OpOperands.
496 ///
497 /// Example:
498 /// %a = arith.constant ... : tensor<10xf32>
499 /// %b1 = tensor.insert %f into %t : tensor<50xf32>
500 /// %b2 = tensor.extract_slice %b1[0][10][1] : tensor<50xf32> tensor<10xf32>
501 /// %r = arith.select %cond, %a, %b : tensor<10xf32>
502 /// findDefinitions(%r) = {%a, %b1}. %r and %b2 are skipped (lookup continues
503 /// in the operands) because their defining ops do not define the contents of
504 /// the tensor.
505 ///
506 /// Example:
507 /// %a = tensor.empty() : tensor<10xf32>
508 /// %b = arith.constant ... : tensor<10xf32>
509 /// %r = arith.select %cond, %a, %b : tensor<10xf32>
510 /// findDefinitions(%r) = {%b}. %a is excluded because it does not define the
511 /// contents of the tensor.
512 ///
513 /// Note: OpResults of unknown ops are handled conservatively and assumed to
514 /// be definitions.
515 SetVector<Value> findDefinitions(OpOperand *opOperand) const;
516
517 /// Return `true` if the given OpResult has been decided to bufferize inplace.
518 virtual bool isInPlace(OpOperand &opOperand) const;
519
520 /// Return true if `v1` and `v2` bufferize to equivalent buffers.
521 virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const;
522
523 /// Return true if `v1` and `v2` may bufferize to aliasing buffers.
524 virtual bool areAliasingBufferizedValues(Value v1, Value v2) const;
525
526 /// Return `true` if the given tensor has undefined contents.
527 virtual bool hasUndefinedContents(OpOperand *opOperand) const;
528
529 /// Return a reference to the BufferizationOptions.
530 const BufferizationOptions &getOptions() const { return options; }
531
532 AnalysisState(const BufferizationOptions &options);
533
534 // AnalysisState should be passed as a reference.
535 AnalysisState(const AnalysisState &) = delete;
536
537 virtual ~AnalysisState() = default;
538
539 static bool classof(const AnalysisState *base) { return true; }
540
541 TypeID getType() const { return type; }
542
543 /// Return the closest enclosing repetitive region around the given op.
544 Region *getEnclosingRepetitiveRegion(Operation *op,
545 const BufferizationOptions &options);
546
547 /// Return the closest enclosing repetitive region around the place where the
548 /// given value is defined.
549 Region *getEnclosingRepetitiveRegion(Value value,
550 const BufferizationOptions &options);
551
552 /// Return the closest enclosing repetitive region around the given block.
553 Region *getEnclosingRepetitiveRegion(Block *block,
554 const BufferizationOptions &options);
555
556 virtual void resetCache();
557
558 /// Checks whether `op0` and `op1` are inside mutually exclusive regions.
559 /// The logic defers to `mlir::insideMutuallyExclusiveRegions`, but the
560 /// result is cached.
561 bool insideMutuallyExclusiveRegions(Operation *op0, Operation *op1);
562
563protected:
564 AnalysisState(const BufferizationOptions &options, TypeID type);
565
566private:
567 /// A reference to current bufferization options.
568 const BufferizationOptions &options;
569
570 /// The type of analysis.
571 TypeID type;
572
573 /// Cache containing closest ancestor repetitive Region.
574 DenseMap<std::variant<Operation *, Block *, Region *, Value>, Region *>
575 enclosingRepetitiveRegionCache;
576
577 /// Cache that specifies whether the two operations are in mutually exclusive
578 /// regions.
579 DenseMap<std::pair<Operation *, Operation *>, bool>
580 insideMutuallyExclusiveRegionsCache;
581};
582
583/// BufferizationState provides information about the state of the IR during the
584/// bufferization process.
585class BufferizationState {
586public:
587 /// Get a reference to the collection of cached symbol tables.
588 SymbolTableCollection &getSymbolTables();
589
590private:
591 /// The cached symbol tables.
592 /// The user is expected to update / invalidate the cached symbol tables if
593 /// the bufferized operation has the Symbol or SymbolTable traits.
594 SymbolTableCollection symbolTables;
595};
596
597/// Create an AllocTensorOp for the given shaped value (memref or tensor).
598/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
599/// undefined contents is allocated.
600FailureOr<Value>
601allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
602 const BufferizationOptions &options,
603 const BufferizationState &state, bool copy = true);
604
605/// Lookup the buffer for the given value. If the value was not bufferized
606/// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
607/// from which the memref operand is returned.
608FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
609 const BufferizationOptions &options,
610 const BufferizationState &state);
611
612/// Return the buffer type for a given Value (tensor) after bufferization
613/// without bufferizing any IR.
614///
615/// Note: It should be sufficient to call `getBuffer()->getType()` in most
616/// cases. However, when a buffer type should be predicted without modifying any
617/// IR, this function can be used.
618///
619/// This function is a wrapper around BufferizableOpInterface::getBufferType.
620FailureOr<BufferLikeType> getBufferType(Value value,
621 const BufferizationOptions &options,
622 const BufferizationState &state);
623
624/// Return the buffer type for a given Value (tensor) after bufferization
625/// without bufferizing any IR. This function (and not the other overload
626/// without `invocationStack`) can be used from `getBufferType` implementations
627/// of the `BufferizableOpInterface`.
628///
629/// Note: It should be sufficient to call `getBuffer()->getType()` in most
630/// cases. However, when a buffer type should be predicted without modifying any
631/// IR, this function can be used.
632///
633/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
634FailureOr<BufferLikeType> getBufferType(Value value,
635 const BufferizationOptions &options,
636 const BufferizationState &state,
637 SmallVector<Value> &invocationStack);
638
639/// Return "true" if the given op has tensor semantics and should be bufferized.
640/// If the op is bufferizable, the BufferizableOpInterface is queried.
641/// Otherwise, an op has tensor semantics if it has tensor operands, tensor
642/// op results and/or tensor block arguments.
643bool hasTensorSemantics(Operation *op);
644
645/// Replace an op with replacement values. The op is deleted. Tensor OpResults
646/// must be replaced with memref values.
647void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
648 ValueRange values);
649
650/// Replace an op with a new op. The new op must have the same number of
651/// results as the replaced op. The new op may not return any tensor values.
652template <typename OpTy, typename... Args>
653OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
654 Args &&...args) {
655 auto newOp =
656 OpTy::create(rewriter, op->getLoc(), std::forward<Args>(args)...);
657 replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
658 return newOp;
659}
660
661/// Return a MemRefType to which the TensorType can be bufferized.
662///
663/// If possible, op bufferization implementations should not use this function
664/// and instead infer precise memref types for tensor results by themselves.
665///
666/// Unless a layout map was specified, `options.unknownTypeConverterFn`
667/// determines what kind of layout map will be used. For best composability
668/// (without copies), the fully dynamic layout map is used by default.
669///
670/// Note: Canonicalization patterns could clean up layout maps and infer more
671/// precise layout maps after bufferization. However, many possible
672/// canonicalizations are currently not implemented.
673BaseMemRefType getMemRefType(TensorType tensorType,
674 const BufferizationOptions &options,
675 MemRefLayoutAttrInterface layout = {},
676 Attribute memorySpace = nullptr);
677
678/// Return a MemRef type with fully dynamic layout. If the given tensor type
679/// is unranked, return an unranked MemRef type.
680BaseMemRefType
681getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
682 Attribute memorySpace = nullptr);
683
684/// Return a MemRef type with a static identity layout (i.e., no layout map). If
685/// the given tensor type is unranked, return an unranked MemRef type.
686BaseMemRefType
687getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
688 Attribute memorySpace = nullptr);
689
690/// Return the owner of the given value. In case of a BlockArgument that is the
691/// owner of the block. In case of an OpResult that is the defining op.
692Operation *getOwnerOfValue(Value value);
693
694/// Assuming that the given region is repetitive, find the next enclosing
695/// repetitive region.
696Region *getNextEnclosingRepetitiveRegion(Region *region,
697 const BufferizationOptions &options);
698
699/// If `region` is a parallel region, return `region`. Otherwise, find the first
700/// enclosing parallel region of `region`. If there is no such region, return
701/// "nullptr".
702///
703/// Note: Whether a region is parallel or sequential is queried from the
704/// `BufferizableOpInterface`.
705Region *getParallelRegion(Region *region, const BufferizationOptions &options);
706
707namespace detail {
708/// This is the default implementation of
709/// BufferizableOpInterface::getAliasingOpOperands. Should not be called from
710/// other places.
711AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
712 const AnalysisState &state);
713
714/// This is the default implementation of
715/// BufferizableOpInterface::getBufferType. Should not be called from other
716/// places.
717FailureOr<BufferLikeType>
718defaultGetBufferType(Value value, const BufferizationOptions &options,
719 const BufferizationState &state,
720 SmallVector<Value> &invocationStack);
721
722/// This is the default implementation of
723/// BufferizableOpInterface::resultBufferizesToMemoryWrite. Should not be called
724/// from other places.
725bool defaultResultBufferizesToMemoryWrite(OpResult opResult,
726 const AnalysisState &state);
727
728/// This is the default implementation of
729/// BufferizableOpInterface::isRepetitiveRegion. Should not be called from other
730/// places.
731bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
732 unsigned index);
733
734/// This is the default implementation of getAliasingOpOperands in case the
735/// defining op does not implement the BufferizableOpInterface.
736AliasingOpOperandList unknownGetAliasingOpOperands(Value value);
737
738/// This is the default implementation of getAliasingValues in case the owner
739/// op does not implement the BufferizableOpInterface.
740AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
741
742/// This is the default implementation of
743/// BufferizableOpInterface::hasTensorSemantics
744bool defaultHasTensorSemantics(Operation *op);
745
746/// This is a helper function used when buffer type is guaranteed to be memref.
747/// It performs two actions: failure state checking and an explicit llvm::cast<>
748/// from the buffer-like type interface to a BaseMemRefType. This allows easier
749/// management of differences in C++ types at the API boundaries. Valid buffer
750/// type is casted to the memref type. Otherwise, the failure state is
751/// propagated i.e. asMemRefType(mlir::failure()) returns mlir::failure().
752FailureOr<BaseMemRefType> asMemRefType(FailureOr<BufferLikeType> bufferType);
753
754/// This function is a free-standing helper that relies on
755/// bufferization::TensorLikeTypeInterface to verify the types in tensor and
756/// buffer worlds match.
757bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
758} // namespace detail
759
760} // namespace bufferization
761} // namespace mlir
762
763MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
764
765//===----------------------------------------------------------------------===//
766// Bufferization Interfaces
767//===----------------------------------------------------------------------===//
768
769#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
770
771#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
bufferization::BufferResultsToOutParamsOpts::AllocationFn AllocationFn
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
static llvm::ManagedStatic< PassManagerOptions > options
static RankedTensorType getBufferType(const SparseTensorType &stt, bool needTmpCOO)
#define MLIR_DECLARE_EXPLICIT_TYPE_ID(CLASS_NAME)
Definition TypeID.h:321
static Operation * getOwnerOfValue(Value value)
This class helps build Operations.
Definition Builders.h:207
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...