9 #ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ 10 #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ 15 #include "llvm/ADT/SetVector.h" 20 namespace bufferization {
23 class BufferizableOpInterface;
32 using FilterFn = std::function<bool(Operation *)>;
52 template <
typename... DialectTs>
56 (allowDialectImpl<DialectTs>(), ...);
62 template <
typename... DialectTs>
64 (denyDialectImpl<DialectTs>(), ...);
72 return op->getDialect()->getNamespace() == dialectNamespace;
74 entries.push_back(
Entry{filterFn, Entry::FilterType::ALLOW});
80 template <
typename... OpTys>
82 (allowOperationImpl<OpTys>(), ...);
88 template <
typename... OpTys>
90 (denyOperationImpl<OpTys>(), ...);
98 return op->getName().getStringRef() == opName;
108 return op->getName().getStringRef() == opName;
117 entries.push_back(
Entry{
fn, Entry::FilterType::ALLOW});
124 entries.push_back(
Entry{
fn, Entry::FilterType::DENY});
129 bool hasAllowRule()
const {
130 for (
const Entry &e : entries)
131 if (e.type == Entry::FilterType::ALLOW)
137 template <
typename DialectT>
138 void allowDialectImpl() {
143 template <
typename DialectT>
144 void denyDialectImpl() {
149 template <
typename OpTy>
150 void allowOperationImpl() {
155 template <
typename OpTy>
156 void denyOperationImpl() {
176 std::function<LogicalResult(OpBuilder &, Location, Value)>;
179 std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
184 std::function<std::unique_ptr<DialectAnalysisState>()>;
192 IdentityLayoutMap = 1,
193 FullyDynamicLayoutMap = 2
200 BufferizableOpInterface dynCastBufferizableOp(
Operation *op)
const;
204 BufferizableOpInterface dynCastBufferizableOp(
Value value)
const;
220 ValueRange dynShape)
const;
225 Value allocatedBuffer)
const;
234 bool allowUnknownOps =
false;
238 bool bufferizeFunctionBoundaries =
false;
253 bool enforceAliasingInvariants =
true;
274 LayoutMapOption::InferLayoutMap;
280 UnknownTypeConverterFn unknownTypeConverterFn =
nullptr;
284 bool createDeallocs =
true;
288 unsigned analysisFuzzerSeed = 0;
292 bool testAnalysisOnly =
false;
296 bool printConflicts =
false;
299 unsigned int bufferAlignment = 128;
347 bool bufferizesToMemoryRead(
OpOperand &opOperand)
const;
351 bool bufferizesToMemoryWrite(
OpOperand &opOperand)
const;
355 bool bufferizesToAliasOnly(
OpOperand &opOperand)
const;
359 bool canOmitTensorCopy(
OpOperand &opOperand)
const;
364 bool isValueRead(
Value value)
const;
403 virtual bool isInPlace(
OpOperand &opOperand)
const;
406 virtual bool areEquivalentBufferizedValues(
Value v1,
Value v2)
const;
409 virtual bool areAliasingBufferizedValues(
Value v1,
Value v2)
const;
412 virtual bool hasUndefinedContents(
OpOperand *opOperand)
const;
419 virtual bool isTensorYielded(
Value tensor)
const;
423 auto it = dialectState.find(name);
424 return it != dialectState.end();
428 template <
typename StateT>
430 auto it = dialectState.find(name);
431 if (it == dialectState.end())
433 return static_cast<const StateT *
>(it->getSecond().get());
437 template <
typename StateT>
440 if (!hasDialectState(name))
441 dialectState[name] = std::make_unique<StateT>();
442 return static_cast<StateT &
>(*dialectState[name]);
446 std::unique_ptr<DialectAnalysisState> state) {
447 assert(!dialectState.count(name) &&
"dialect state already initialized");
448 dialectState[name] = std::move(state);
501 template <
typename OpTy,
typename... Args>
504 auto newOp = rewriter.
create<OpTy>(op->
getLoc(), std::forward<Args>(args)...);
528 MemRefLayoutAttrInterface layout = {},
529 unsigned memorySpace = 0);
534 unsigned memorySpace = 0);
539 unsigned memorySpace = 0);
544 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc" 546 #endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_ Include the generated interface declarations.
void allowOperation(StringRef opName)
Allow the given op.
bool isOpAllowed(Operation *op) const
Return whether the op is allowed or not.
Optional< const StateT * > getDialectState(StringRef name) const
Return dialect-specific bufferization state.
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, bool escape, const BufferizationOptions &options, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
FilterType
Filter type: A filter can either be a DENY filter or an ALLOW filter.
void denyOperation(StringRef opName)
Deny the given op.
Operation is a basic unit of execution within MLIR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
This is a value defined by a result of an operation.
const BufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
void allowOperation()
Allow the given ops.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
void allowDialect()
Allow the given dialects.
Optional< DeallocationFn > deallocationFn
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two buffers.
bool hasDialectState(StringRef name) const
Return true if the given dialect state exists.
Optional< AllocationFn > allocationFn
Helper functions for allocation, deallocation, memory copying.
std::function< std::unique_ptr< DialectAnalysisState >()> DialectStateInitFn
Initializer function for dialect-specific analysis state.
bool shouldDeallocateOpResult(OpResult opResult, const BufferizationOptions &options)
Return true if the buffer of given OpResult should be deallocated.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, unsigned memorySpace=0)
Return a MemRef type with a static identity layout (i.e., no layout map).
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
AnalysisState provides a variety of helper functions for dealing with tensor values.
This class represents an efficient way to signal success or failure.
std::function< LogicalResult(OpBuilder &, Location, Value)> DeallocationFn
Deallocator function: Deallocate a buffer that was allocated with AllocatorFn.
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization.
This class provides support for representing a failure result, or a valid value of type T...
Dialect-specific analysis state.
SmallVector< AnalysisStateInitFn > stateInitializers
Initializer functions for analysis state.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, unsigned memorySpace=0)
Return a MemRefType to which the type of the given value can be bufferized.
void denyOperation()
Deny the given ops.
bool allocationDoesNotEscape(OpResult opResult)
Return true if the allocation of the given op is guaranteed to not escape the containing block...
Location getLoc()
The source location the operation was defined or derived from.
bool isFunctionArgument(Value value)
Return true if the given value is a BlockArgument of a func::FuncOp.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Options for BufferizableOpInterface-based bufferization.
void denyDialect()
Deny the given dialects.
std::function< BaseMemRefType(Value, unsigned, const BufferizationOptions &)> UnknownTypeConverterFn
Tensor -> MemRef type converter.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static llvm::ManagedStatic< PassManagerOptions > options
void allowDialect(StringRef dialectNamespace)
Allow the given dialect.
OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op, Args &&...args)
Replace an op with a new op.
This class provides a shared interface for ranked and unranked memref types.
void allowOperation(Entry::FilterFn fn)
Allow ops that are matched by fn.
This class represents an operand of an operation.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, unsigned memorySpace=0)
Return a MemRef type with fully dynamic layout.
StateT & getOrCreateDialectState(StringRef name)
Return dialect-specific analysis state or create one if none exists.
BufferRelation
Specify fine-grain relationship between buffers to enable more analysis.
void insertDialectState(StringRef name, std::unique_ptr< DialectAnalysisState > state)
void denyOperation(Entry::FilterFn fn)
Deny ops that are matched by fn.
Optional< MemCpyFn > memCpyFn
std::function< FailureOr< Value >(OpBuilder &, Location, MemRefType, ValueRange, unsigned int)> AllocationFn
Allocator function: Generate a memref allocation with the given type, dynamic extents and alignment...
result_range getResults()
This class helps build Operations.
This class provides an abstraction over the different types of ranges over Values.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::function< void(AnalysisState &)> AnalysisStateInitFn
Initializer function for analysis state.