9#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
10#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZABLEOPINTERFACE_H_
15#include "llvm/ADT/DenseMapInfoVariant.h"
16#include "llvm/ADT/SetVector.h"
19#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
31class BufferizableOpInterface;
34enum class BufferRelation {
43struct AliasingOpOperand {
44 AliasingOpOperand(OpOperand *opOperand, BufferRelation relation,
45 bool isDefinite =
true)
46 : opOperand(opOperand), relation(relation), isDefinite(isDefinite) {}
49 BufferRelation relation;
56 AliasingValue(Value value, BufferRelation relation,
bool isDefinite =
true)
57 : value(value), relation(relation), isDefinite(isDefinite) {}
60 BufferRelation relation;
68 AliasList() =
default;
71 AliasList(std::initializer_list<T> elems) {
77 AliasList(SmallVector<T> &&aliases) : aliases(std::move(aliases)) {}
79 ArrayRef<T> getAliases()
const {
return aliases; }
81 size_t getNumAliases()
const {
return aliases.size(); }
83 void addAlias(T alias) { aliases.push_back(alias); }
85 auto begin()
const {
return aliases.begin(); }
86 auto end()
const {
return aliases.end(); }
90 SmallVector<T> aliases;
95using AliasingOpOperandList = AliasList<AliasingOpOperand>;
99using AliasingValueList = AliasList<AliasingValue>;
107 using FilterFn = std::function<bool(Operation *)>;
110 enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
122 bool isOpAllowed(Operation *op)
const;
127 template <
typename... DialectTs>
128 void allowDialect() {
131 (allowDialectImpl<DialectTs>(), ...);
137 template <
typename... DialectTs>
139 (denyDialectImpl<DialectTs>(), ...);
145 void allowDialect(StringRef dialectNamespace) {
146 Entry::FilterFn filterFn = [=](Operation *op) {
147 return op->getName().getDialectNamespace() == dialectNamespace;
149 entries.push_back(Entry{filterFn, Entry::FilterType::ALLOW});
155 void denyDialect(StringRef dialectNamespace) {
156 Entry::FilterFn filterFn = [=](Operation *op) {
157 return op->getName().getDialectNamespace() == dialectNamespace;
159 entries.push_back(Entry{filterFn, Entry::FilterType::DENY});
165 template <
typename... OpTys>
166 void allowOperation() {
167 (allowOperationImpl<OpTys>(), ...);
173 template <
typename... OpTys>
174 void denyOperation() {
175 (denyOperationImpl<OpTys>(), ...);
181 void allowOperation(StringRef opName) {
182 Entry::FilterFn filterFn = [=](Operation *op) {
183 return op->getName().getStringRef() == opName;
185 allowOperation(filterFn);
191 void denyOperation(StringRef opName) {
192 Entry::FilterFn filterFn = [=](Operation *op) {
193 return op->getName().getStringRef() == opName;
195 denyOperation(filterFn);
201 void allowOperation(Entry::FilterFn fn) {
202 entries.push_back(Entry{fn, Entry::FilterType::ALLOW});
208 void denyOperation(Entry::FilterFn fn) {
209 entries.push_back(Entry{fn, Entry::FilterType::DENY});
214 bool hasAllowRule()
const {
215 for (
const Entry &e : entries)
216 if (e.type == Entry::FilterType::ALLOW)
222 template <
typename DialectT>
223 void allowDialectImpl() {
224 allowDialect(DialectT::getDialectNamespace());
228 template <
typename DialectT>
229 void denyDialectImpl() {
230 denyDialect(DialectT::getDialectNamespace());
234 template <
typename OpTy>
235 void allowOperationImpl() {
236 allowOperation(OpTy::getOperationName());
240 template <
typename OpTy>
241 void denyOperationImpl() {
242 denyOperation(OpTy::getOperationName());
249 SmallVector<Entry> entries;
253struct BufferizationOptions {
257 OpBuilder &, Location, MemRefType,
ValueRange,
unsigned int)>;
260 std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
262 using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
265 using FunctionArgTypeConverterFn =
266 std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
267 func::FuncOp,
const BufferizationOptions &)>;
270 using UnknownTypeConverterFn = std::function<BaseMemRefType(
271 TensorType, Attribute memorySpace,
const BufferizationOptions &)>;
273 using DefaultMemorySpaceFn =
274 std::function<std::optional<Attribute>(TensorType t)>;
276 BufferizationOptions();
280 BufferizableOpInterface dynCastBufferizableOp(Operation *op)
const;
284 BufferizableOpInterface dynCastBufferizableOp(Value value)
const;
291 bool isOpAllowed(Operation *op)
const;
294 std::optional<AllocationFn> allocationFn;
295 std::optional<MemCpyFn> memCpyFn;
298 FailureOr<Value> createAlloc(OpBuilder &
b, Location loc, MemRefType type,
302 LogicalResult createMemCpy(OpBuilder &
b, Location loc, Value from,
308 bool allowUnknownOps =
false;
312 bool bufferizeFunctionBoundaries =
false;
317 bool checkParallelRegions =
true;
336 void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption);
346 FunctionArgTypeConverterFn functionArgTypeConverterFn =
nullptr;
353 bool inferFunctionResultLayout =
true;
358 UnknownTypeConverterFn unknownTypeConverterFn =
nullptr;
364 DefaultMemorySpaceFn defaultMemorySpaceFn =
365 [](TensorType t) -> std::optional<Attribute> {
return Attribute(); };
369 bool copyBeforeWrite =
false;
373 bool testAnalysisOnly =
false;
377 bool printConflicts =
false;
380 unsigned int bufferAlignment = 64;
384 SmallVector<AnalysisStateInitFn> stateInitializers;
388struct TraversalConfig {
391 bool alwaysIncludeLeaves =
true;
394 bool followInPlaceOnly =
false;
397 bool followEquivalentOnly =
false;
401 bool followUnknownOps =
false;
405 bool followSameTypeOrCastsOnly =
false;
409 bool revisitAlreadyVisitedValues =
false;
419 AliasingOpOperandList getAliasingOpOperands(Value value)
const;
423 AliasingValueList getAliasingValues(OpOperand &opOperand)
const;
427 bool bufferizesToMemoryRead(OpOperand &opOperand)
const;
431 bool bufferizesToMemoryWrite(OpOperand &opOperand)
const;
436 bool bufferizesToMemoryWrite(Value value)
const;
440 bool bufferizesToAliasOnly(OpOperand &opOperand)
const;
444 bool canOmitTensorCopy(OpOperand &opOperand)
const;
449 bool isValueRead(Value value)
const;
478 SetVector<Value> findValueInReverseUseDefChain(
479 OpOperand *opOperand, llvm::function_ref<
bool(Value)> condition,
480 TraversalConfig config = TraversalConfig(),
481 llvm::DenseSet<OpOperand *> *visitedOpOperands =
nullptr)
const;
515 SetVector<Value> findDefinitions(OpOperand *opOperand)
const;
518 virtual bool isInPlace(OpOperand &opOperand)
const;
521 virtual bool areEquivalentBufferizedValues(Value v1, Value v2)
const;
524 virtual bool areAliasingBufferizedValues(Value v1, Value v2)
const;
527 virtual bool hasUndefinedContents(OpOperand *opOperand)
const;
530 const BufferizationOptions &getOptions()
const {
return options; }
532 AnalysisState(
const BufferizationOptions &
options);
535 AnalysisState(
const AnalysisState &) =
delete;
537 virtual ~AnalysisState() =
default;
539 static bool classof(
const AnalysisState *base) {
return true; }
541 TypeID
getType()
const {
return type; }
545 const BufferizationOptions &
options);
550 const BufferizationOptions &
options);
554 const BufferizationOptions &
options);
556 virtual void resetCache();
564 AnalysisState(
const BufferizationOptions &
options, TypeID type);
568 const BufferizationOptions &
options;
574 DenseMap<std::variant<Operation *, Block *, Region *, Value>, Region *>
575 enclosingRepetitiveRegionCache;
579 DenseMap<std::pair<Operation *, Operation *>,
bool>
580 insideMutuallyExclusiveRegionsCache;
585class BufferizationState {
588 SymbolTableCollection &getSymbolTables();
594 SymbolTableCollection symbolTables;
601allocateTensorForShapedValue(OpBuilder &
b, Location loc, Value shapedValue,
602 const BufferizationOptions &
options,
603 const BufferizationState &state,
bool copy =
true);
608FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
609 const BufferizationOptions &
options,
610 const BufferizationState &state);
621 const BufferizationOptions &
options,
622 const BufferizationState &state);
635 const BufferizationOptions &
options,
636 const BufferizationState &state,
637 SmallVector<Value> &invocationStack);
643bool hasTensorSemantics(Operation *op);
647void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
652template <
typename OpTy,
typename... Args>
653OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
656 OpTy::create(rewriter, op->getLoc(), std::forward<Args>(args)...);
657 replaceOpWithBufferizedValues(rewriter, op, newOp->getResults());
674 const BufferizationOptions &
options,
675 MemRefLayoutAttrInterface layout = {},
676 Attribute memorySpace =
nullptr);
681getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
682 Attribute memorySpace =
nullptr);
687getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
688 Attribute memorySpace =
nullptr);
696Region *getNextEnclosingRepetitiveRegion(Region *region,
697 const BufferizationOptions &
options);
705Region *getParallelRegion(Region *region,
const BufferizationOptions &
options);
711AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
712 const AnalysisState &state);
717FailureOr<BufferLikeType>
718defaultGetBufferType(Value value,
const BufferizationOptions &
options,
719 const BufferizationState &state,
720 SmallVector<Value> &invocationStack);
725bool defaultResultBufferizesToMemoryWrite(OpResult opResult,
726 const AnalysisState &state);
731bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
736AliasingOpOperandList unknownGetAliasingOpOperands(Value value);
740AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
744bool defaultHasTensorSemantics(Operation *op);
752FailureOr<BaseMemRefType> asMemRefType(FailureOr<BufferLikeType> bufferType);
757bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
769#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc"
bufferization::BufferResultsToOutParamsOpts::AllocationFn AllocationFn
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static llvm::ManagedStatic< PassManagerOptions > options
static RankedTensorType getBufferType(const SparseTensorType &stt, bool needTmpCOO)
#define MLIR_DECLARE_EXPLICIT_TYPE_ID(CLASS_NAME)
static Operation * getOwnerOfValue(Value value)
This class helps build Operations.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
bool insideMutuallyExclusiveRegions(Operation *a, Operation *b)
Return true if a and b are in mutually exclusive regions as per RegionBranchOpInterface.
Region * getEnclosingRepetitiveRegion(Operation *op)
Return the first enclosing region of the given op that may be executed repetitively as per RegionBran...