#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"

namespace bufferization {
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

#define DEBUG_TYPE "bufferizable-op-interface"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X))

using namespace bufferization;
if (auto bufferizableOp = options.dynCastBufferizableOp(op))

if (auto iter = enclosingRepetitiveRegionCache.find_as(op);
    iter != enclosingRepetitiveRegionCache.end())
return enclosingRepetitiveRegionCache[op] =

if (auto iter = enclosingRepetitiveRegionCache.find_as(value);
    iter != enclosingRepetitiveRegionCache.end())

visitedRegions.push_back(region);

enclosingRepetitiveRegionCache[value] = region;
for (Region *r : visitedRegions)
  enclosingRepetitiveRegionCache[r] = region;

if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
    iter != enclosingRepetitiveRegionCache.end())

enclosingRepetitiveRegionCache[block] = region;
for (Region *r : visitedRegions)
  enclosingRepetitiveRegionCache[r] = region;
if (bufferizableOp &&

       "expected that all parallel regions are also repetitive regions");

if (auto opResult = llvm::dyn_cast<OpResult>(value))
  return opResult.getDefiningOp();
return llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
  tensor = shapedValue;
} else if (llvm::isa<MemRefType>(shapedValue.getType())) {
  tensor = b.create<ToTensorOp>(loc, shapedValue);
} else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
           llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {

      ->emitError("copying of unranked tensors is not implemented");

llvm_unreachable("expected RankedTensorType or MemRefType");

RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType());

bool reifiedShapes = false;
if (llvm::isa<RankedTensorType>(shapedValue.getType()) &&
    llvm::isa<OpResult>(shapedValue)) {

  reifiedShapes = true;

      resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
  for (const auto &dim : enumerate(tensorType.getShape()))
    if (ShapedType::isDynamic(dim.value()))
      dynamicSizes.push_back(shape[dim.index()].get<Value>());

auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes,

if (failed(copyBufferType))

Attribute memorySpace = copyBufferType->getMemorySpace();

allocTensorOp.setMemorySpaceAttr(memorySpace);
return allocTensorOp.getResult();
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(

Type operandType = opOperand.get().getType();
if (!llvm::isa<TensorType>(operandType))

if (state.isInPlace(opOperand))

if (llvm::isa<UnrankedTensorType>(operandType))
  return op->emitError("copying of unranked tensors is not implemented");

    isa<OpResult>(aliasingValues.getAliases()[0].value) &&
    !state.bufferizesToMemoryWrite(opOperand) &&
    state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
            .getNumAliases() == 1 &&
    !isa<UnrankedTensorType>(
        aliasingValues.getAliases()[0].value.getType())) {

outOfPlaceValues.push_back(value);
if (!state.canOmitTensorCopy(opOperand))
  copiedOpValues.insert(value);

outOfPlaceOpOperands.push_back(&opOperand);
if (!state.canOmitTensorCopy(opOperand))
  copiedOpOperands.insert(&opOperand);

for (OpOperand *opOperand : outOfPlaceOpOperands) {

      rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
      copiedOpOperands.contains(opOperand));

for (Value value : outOfPlaceValues) {

      rewriter, op->getLoc(), value, state.getOptions(),
      copiedOpValues.count(value));

if (use->getOwner() == copy->getDefiningOp())

if (isa<tensor::DimOp>(use->getOwner()))
bool isAllowed = !hasAllowRule();
for (const Entry &entry : entries) {
  bool filterResult = entry.fn(op);
  switch (entry.type) {

    isAllowed |= filterResult;
    llvm::cast<TensorType>(value.getType()), memorySpace);

    : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
      unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());

BufferizableOpInterface

auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);

return bufferizableOp;

BufferizableOpInterface

    LayoutMapOption layoutMapOption) {

if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)

    layoutMapOption == LayoutMapOption::InferLayoutMap;
if (auto bbArg = llvm::dyn_cast<BlockArgument>(value)) {

if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
  return bufferizableOp.getAliasingOpOperands(value, *this);

if (auto bufferizableOp =
  return bufferizableOp.getAliasingValues(opOperand, *this);

if (auto bufferizableOp =
  return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);

if (auto bufferizableOp =
  return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);

if (auto bufferizableOp =
  return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);

auto opResult = llvm::dyn_cast<OpResult>(value);

return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this);
assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");

workingSet.push_back(&use);

while (!workingSet.empty()) {
  OpOperand *uMaybeReading = workingSet.pop_back_val();
  if (visited.contains(uMaybeReading))

  visited.insert(uMaybeReading);

  for (OpOperand &use : alias.value.getUses())
    workingSet.push_back(&use);

workingSet.insert(value);

while (!workingSet.empty()) {
  Value value = workingSet.pop_back_val();

    result.insert(value);

  visited.insert(value);

  if (condition(value)) {
    result.insert(value);

      result.insert(value);

      result.insert(value);

      result.insert(value);

      result.insert(value);

        a.opOperand->get().getType() != value.getType() &&

        result.insert(value);

    workingSet.insert(a.opOperand->get());

llvm::none_of(aliases,

if (isa<ToMemrefOp>(opOperand.getOwner()))
auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
                                 rankedTensorType.getRank()) &&
       "to_memref would be invalid: mismatching ranks");

auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
assert(tensorType && "unexpected non-tensor type");

if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
  return toTensorOp.getMemref();

    .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
assert(llvm::isa<TensorType>(value.getType()) && "unexpected non-tensor type");
invocationStack.push_back(value);

    llvm::make_scope_exit([&]() { invocationStack.pop_back(); });

auto bufferizableOp = options.dynCastBufferizableOp(op);

  return bufferizableOp.getBufferType(value, options, invocationStack);

if (!options.defaultMemorySpace.has_value())
  return op->emitError("could not infer memory space");

       "expected one value per OpResult");

Value replacement = values[opResult.getResultNumber()];
if (llvm::isa<TensorType>(opResult.getType())) {

  assert((llvm::isa<MemRefType>(replacement.getType()) ||
          llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
         "tensor op result should be replaced with a memref value");

  replacement = rewriter.create<bufferization::ToTensorOp>(
      replacement.getLoc(), replacement);

replacements.push_back(replacement);
    .create<memref::AllocOp>(loc, type, dynShape,

return b.create<memref::AllocOp>(loc, type, dynShape).getResult();

return (*memCpyFn)(b, loc, from, to);

b.create<memref::CopyOp>(loc, from, to);

auto bbArg = llvm::dyn_cast<BlockArgument>(value);

return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
                             MemRefLayoutAttrInterface layout,

auto tensorType = llvm::cast<TensorType>(value.getType());

if (auto unrankedTensorType =
        llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
  assert(!layout && "UnrankedTensorType cannot have a layout map");

auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);

    rankedTensorType.getElementType(), layout,

return options.unknownTypeConverterFn(value, memorySpace, options);

if (auto unrankedTensorType =
        llvm::dyn_cast<UnrankedTensorType>(tensorType)) {

auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
int64_t dynamicOffset = ShapedType::kDynamic;

                                  ShapedType::kDynamic);

                               dynamicOffset, dynamicStrides);

    rankedTensorType.getElementType(), stridedLayout,

if (auto unrankedTensorType =
        llvm::dyn_cast<UnrankedTensorType>(tensorType)) {

auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
MemRefLayoutAttrInterface layout = {};

    rankedTensorType.getElementType(), layout,
auto bufferizableOp = cast<BufferizableOpInterface>(opResult.getDefiningOp());

    bufferizableOp.getAliasingOpOperands(opResult, state);

return state.bufferizesToMemoryWrite(*alias.opOperand);

auto isMemoryWriteInsideOp = [&](Value v) {

  return state.bufferizesToMemoryWrite(v);

                                      isMemoryWriteInsideOp, config)

if (!llvm::isa<TensorType>(opOperand.get().getType()))

for (const auto &it : aliasingValues)
  if (it.value == value)
    result.emplace_back(&opOperand, it.relation, it.isDefinite);
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");

if (llvm::isa<BlockArgument>(value))

auto opResult = llvm::cast<OpResult>(value);

Value equivalentOperand = aliases.getAliases().front().opOperand->get();

if (!options.defaultMemorySpace.has_value())
  return op->emitError("could not infer memory space");

                               BufferizableOpInterface bufferizableOp,
                               unsigned index) {
assert(index < bufferizableOp->getNumRegions() && "invalid region index");
auto regionInterface =
    dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
if (!regionInterface)

return regionInterface.isRepetitiveRegion(index);

if (auto bbArg = dyn_cast<BlockArgument>(value))
  if (bbArg.getOwner() != &bbArg.getOwner()->getParent()->getBlocks().front())

if (isa<TensorType>(operand.get().getType()))

if (llvm::isa<TensorType>(result.getType()))

if (!region.getBlocks().empty())
  for (BlockArgument bbArg : region.getBlocks().front().getArguments())
static void ensureToMemrefOpIsValid(Value tensor, Type memrefType)
static void setInsertionPointAfter(OpBuilder &b, Value value)
static bool isRepetitiveRegion(Region *region, const BufferizationOptions &options)
#define MLIR_DEFINE_EXPLICIT_TYPE_ID(CLASS_NAME)
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
IntegerAttr getI64IntegerAttr(int64_t value)
This class provides support for representing a failure result, or a valid value of type T.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Block * getBlock()
Returns the operation block that contains this operation.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MutableArrayRef< OpOperand > getOpOperands()
bool isAncestor(Operation *other)
Return true if this operation is an ancestor of the other operation.
result_range getOpResults()
Region * getParentRegion()
Returns the region to which the instruction belongs.
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Region * getParentRegion()
Return the region containing this region or nullptr if the region is attached to a top-level operation.
unsigned getRegionNumber()
Return the number of this region in the parent operation.
Operation * getParentOp()
Return the parent operation this region is attached to.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Region * getParentRegion()
Return the Region in which this Value is defined.
size_t getNumAliases() const
ArrayRef< T > getAliases() const
AnalysisState provides a variety of helper functions for dealing with tensor values.
bool isValueRead(Value value) const
Return true if the given value is read by an op that bufferizes to a memory read.
AliasingValueList getAliasingValues(OpOperand &opOperand) const
Determine which Value will alias with opOperand if the op is bufferized in place.
virtual bool areAliasingBufferizedValues(Value v1, Value v2) const
Return true if v1 and v2 may bufferize to aliasing buffers.
virtual bool hasUndefinedContents(OpOperand *opOperand) const
Return true if the given tensor has undefined contents.
bool canOmitTensorCopy(OpOperand &opOperand) const
Return true if a copy can always be avoided when allocating a new tensor for the given OpOperand.
bool bufferizesToMemoryWrite(OpOperand &opOperand) const
Return true if opOperand bufferizes to a memory write.
virtual bool isInPlace(OpOperand &opOperand) const
Return true if the given OpResult has been decided to bufferize inplace.
bool bufferizesToAliasOnly(OpOperand &opOperand) const
Return true if opOperand does neither read nor write but bufferizes to an alias.
AliasingOpOperandList getAliasingOpOperands(Value value) const
Determine which OpOperand* will alias with value if the op is bufferized in place.
AnalysisState(const BufferizationOptions &options)
Region * getEnclosingRepetitiveRegion(Operation *op, const BufferizationOptions &options)
Return the closest enclosing repetitive region around the given op.
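A minimal usage sketch for the query above, assuming an existing Operation *op and BufferizationOptions options (both placeholders, not defined in this file):

// Sketch: a null result means `op` is not nested inside any repetitive
// (loop-like) region; otherwise we can ask the region for its index.
if (Region *repetitiveRegion = getEnclosingRepetitiveRegion(op, options))
  llvm::dbgs() << "enclosing repetitive region #"
               << repetitiveRegion->getRegionNumber() << "\n";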
const BufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
bool bufferizesToMemoryRead(OpOperand &opOperand) const
Return true if opOperand bufferizes to a memory read.
SetVector< Value > findValueInReverseUseDefChain(Value value, llvm::function_ref< bool(Value)> condition, TraversalConfig config=TraversalConfig()) const
Starting from value, follow the use-def chain in reverse, always selecting the aliasing OpOperands.
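As a rough usage sketch (assuming an AnalysisState state and a tensor Value v; the stopping lambda is purely illustrative):

// Walk the aliasing OpOperands backwards from `v` and collect the values
// on each path at which the walk stops, here: block arguments.
SetVector<Value> roots = state.findValueInReverseUseDefChain(
    v, /*condition=*/[](Value alias) { return llvm::isa<BlockArgument>(alias); });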
SetVector< Value > findDefinitions(Value value) const
Find the values that may define the contents of the given value at runtime.
virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const
Return true if v1 and v2 bufferize to equivalent buffers.
virtual void resetCache()
bool isOpAllowed(Operation *op) const
Return whether the op is allowed or not.
Operation * getOwner() const
Return the owner of this operand.
AliasingOpOperandList defaultGetAliasingOpOperands(Value value, const AnalysisState &state)
This is the default implementation of BufferizableOpInterface::getAliasingOpOperands.
bool defaultResultBufferizesToMemoryWrite(OpResult opResult, const AnalysisState &state)
This is the default implementation of BufferizableOpInterface::resultBufferizesToMemoryWrite.
AliasingValueList unknownGetAliasingValues(OpOperand &opOperand)
This is the default implementation of getAliasingValues in case the owner op does not implement the BufferizableOpInterface.
bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp, unsigned index)
This is the default implementation of BufferizableOpInterface::isRepetitiveRegion.
AliasingOpOperandList unknownGetAliasingOpOperands(Value value)
This is the default implementation of getAliasingOpOperands in case the defining op does not implement the BufferizableOpInterface.
FailureOr< BaseMemRefType > defaultGetBufferType(Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack)
This is the default implementation of BufferizableOpInterface::getBufferType.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
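This helper is typically the last step of a bufferize implementation. A hypothetical sketch follows; the op being rewritten, its memref lowering, and the name bufferizeExample are placeholders and not part of this file:

LogicalResult bufferizeExample(RewriterBase &rewriter, Operation *op,
                               const BufferizationOptions &options) {
  // Look up the buffer of the single tensor operand.
  FailureOr<Value> srcBuffer = getBuffer(rewriter, op->getOperand(0), options);
  if (failed(srcBuffer))
    return failure();
  // ... build the memref-based computation here; `memrefResult` stands in
  // for its result (illustration only) ...
  Value memrefResult = *srcBuffer;
  // Replace the tensor result with the memref value; a to_tensor op is
  // inserted where the tensor result still has uses.
  replaceOpWithBufferizedValues(rewriter, op, memrefResult);
  return success();
}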
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
Operation * getOwnerOfValue(Value value)
Return the owner of the given value.
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
Region * getParallelRegion(Region *region, const BufferizationOptions &options)
If region is a parallel region, return region.
Region * getNextEnclosingRepetitiveRegion(Region *region, const BufferizationOptions &options)
Assuming that the given region is repetitive, find the next enclosing repetitive region.
AliasList< AliasingOpOperand > AliasingOpOperandList
A list of possible aliasing OpOperands.
bool isFunctionArgument(Value value)
Return true if the given value is a BlockArgument of a func::FuncOp.
FailureOr< Value > allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, const BufferizationOptions &options, bool copy=true)
Create an AllocTensorOp for the given shaped value (memref or tensor).
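A small sketch of the allocation helper above, assuming an OpBuilder b, Location loc, a shaped Value shapedValue, and BufferizationOptions options, inside a function that returns LogicalResult:

// Allocate a new tensor with the same shape as `shapedValue`; with
// /*copy=*/true the original contents are copied into the allocation.
FailureOr<Value> allocated =
    allocateTensorForShapedValue(b, loc, shapedValue, options, /*copy=*/true);
if (failed(allocated))
  return failure();
Value newTensor = *allocated;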
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
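For example, assuming a tensor Value v and BufferizationOptions options, the future buffer type can be inspected without rewriting any IR (sketch only):

FailureOr<BaseMemRefType> bufferType = getBufferType(v, options);
if (succeeded(bufferType))
  llvm::dbgs() << "bufferizes to: " << *bufferType << "\n";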
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
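A sketch contrasting the two type helpers above, assuming an MLIRContext *ctx:

auto tensorType =
    RankedTensorType::get({4, ShapedType::kDynamic}, Float32Type::get(ctx));
// memref<4x?xf32> with the default (identity) layout.
BaseMemRefType identityType = getMemRefTypeWithStaticIdentityLayout(tensorType);
// memref<4x?xf32, strided<[?, ?], offset: ?>> with a fully dynamic layout.
BaseMemRefType dynamicType = getMemRefTypeWithFullyDynamicLayout(tensorType);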
void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, SmallVector< Value > &dynamicDims)
Populate dynamicDims with tensor::DimOp / memref::DimOp results for all dynamic dimensions of the given shaped value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
A maybe aliasing OpOperand.
Options for BufferizableOpInterface-based bufferization.
std::function< void(AnalysisState &)> AnalysisStateInitFn
Initializer function for analysis state.
void setFunctionBoundaryTypeConversion(LayoutMapOption layoutMapOption)
This function controls buffer types on function signatures.
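For instance (a configuration sketch, not tied to any particular pass):

BufferizationOptions options;
// Also bufferize across func.func boundaries and use identity-layout
// memrefs in function signatures.
options.bufferizeFunctionBoundaries = true;
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);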
BufferizableOpInterface dynCastBufferizableOp(Operation *op) const
Try to cast the given op to BufferizableOpInterface if the op is allow listed.
bool inferFunctionResultLayout
If true, function result types are inferred from the body of the function.
unsigned int bufferAlignment
Buffer alignment for new memory allocations.
FunctionArgTypeConverterFn functionArgTypeConverterFn
Type converter from tensors to memrefs.
std::optional< AllocationFn > allocationFn
Helper functions for allocation and memory copying.
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
bool isOpAllowed(Operation *op) const
Return true if the given op should be bufferized.
std::optional< MemCpyFn > memCpyFn
bool bufferizeFunctionBoundaries
Specifies whether function boundaries (ops in the func dialect) should be bufferized or not.
FailureOr< Value > createAlloc(OpBuilder &b, Location loc, MemRefType type, ValueRange dynShape) const
Create a memref allocation with the given type and dynamic extents.
LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const
Creates a memcpy between two given buffers.
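A sketch combining the two hooks above, assuming an OpBuilder b, Location loc, a MemRefType type, dynamic extents dynShape, a source buffer src, and a surrounding function that returns LogicalResult:

// Allocate through the configured allocationFn (memref.alloc by default)
// and copy `src` into the new buffer via memCpyFn (memref.copy by default).
FailureOr<Value> newBuffer = options.createAlloc(b, loc, type, dynShape);
if (failed(newBuffer))
  return failure();
if (failed(options.createMemCpy(b, loc, /*from=*/src, /*to=*/*newBuffer)))
  return failure();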
SmallVector< AnalysisStateInitFn > stateInitializers
Initializer functions for analysis state.
Traversal parameters for findValueInReverseUseDefChain.
bool followUnknownOps
Specifies whether unknown/non-bufferizable/ops not included in the OpFilter of BufferizationOptions should be followed.
bool alwaysIncludeLeaves
Specifies if leaves (that do not have further OpOperands to follow) should be returned even if they do not match the condition.
bool followSameTypeOrCastsOnly
Specifies whether OpOperands with a different type that are not the result of a CastOpInterface op should be followed.
bool followInPlaceOnly
Specifies whether out-of-place/undecided OpOperands should be followed.
bool followEquivalentOnly
Specifies whether non-equivalent OpOperands should be followed.
bool revisitAlreadyVisitedValues
Specifies whether already visited values should be visited again.
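These flags feed findValueInReverseUseDefChain. A short sketch, again assuming an AnalysisState state and a tensor Value v:

TraversalConfig config;
// Also step through ops that do not implement BufferizableOpInterface and
// keep chain leaves even when they fail the condition.
config.followUnknownOps = true;
config.alwaysIncludeLeaves = true;
SetVector<Value> found = state.findValueInReverseUseDefChain(
    v, [](Value alias) { return llvm::isa<OpResult>(alias); }, config);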