#include "llvm/ADT/ScopeExit.h"

namespace mlir {
namespace bufferization {

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc"

} // namespace bufferization
} // namespace mlir

#define DEBUG_TYPE "bufferizable-op-interface"

using namespace mlir;
using namespace bufferization;
/// Return "true" if the given region is a repetitive region of a bufferizable
/// op.
static bool isRepetitiveRegion(Region *region,
                               const BufferizationOptions &options) {
  Operation *op = region->getParentOp();
  if (auto bufferizableOp = options.dynCastBufferizableOp(op))
    return bufferizableOp.isRepetitiveRegion(region->getRegionNumber());
  return false;
}
Region *AnalysisState::getEnclosingRepetitiveRegion(
    Operation *op, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(op);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;
  return enclosingRepetitiveRegionCache[op] =
             bufferization::getEnclosingRepetitiveRegion(op, options);
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Value value, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(value);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  // Walk up the region tree until a repetitive region is found, remembering
  // all visited regions so that they can be cached as well.
  Region *region = value.getParentRegion();
  SmallVector<Region *> visitedRegions;
  while (region) {
    visitedRegions.push_back(region);
    if (isRepetitiveRegion(region, options))
      break;
    region = region->getParentRegion();
  }
  enclosingRepetitiveRegionCache[value] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}

Region *AnalysisState::getEnclosingRepetitiveRegion(
    Block *block, const BufferizationOptions &options) {
  if (auto iter = enclosingRepetitiveRegionCache.find_as(block);
      iter != enclosingRepetitiveRegionCache.end())
    return iter->second;

  // Same as above, starting from the block's parent region.
  // ...
  enclosingRepetitiveRegionCache[block] = region;
  for (Region *r : visitedRegions)
    enclosingRepetitiveRegionCache[r] = region;
  return region;
}
bool AnalysisState::insideMutuallyExclusiveRegions(Operation *op0,
                                                   Operation *op1) {
  auto key = std::make_pair(op0, op1);
  if (auto iter = insideMutuallyExclusiveRegionsCache.find(key);
      iter != insideMutuallyExclusiveRegionsCache.end())
    return iter->second;
  // The result is symmetric, so cache it for both orderings of the key.
  bool result = ::mlir::insideMutuallyExclusiveRegions(op0, op1);
  insideMutuallyExclusiveRegionsCache[key] = result;
  insideMutuallyExclusiveRegionsCache[std::make_pair(op1, op0)] = result;
  return result;
}

void AnalysisState::resetCache() {
  enclosingRepetitiveRegionCache.clear();
  insideMutuallyExclusiveRegionsCache.clear();
}
// In bufferization::getParallelRegion (return `region` if it is a parallel
// region): every parallel region must also be a repetitive region.
  if (bufferizableOp &&
      bufferizableOp.isParallelRegion(region->getRegionNumber())) {
    assert(isRepetitiveRegion(region, options) &&
           "expected that all parallel regions are also repetitive regions");
    return region;
  }
  // ...
Operation *bufferization::getOwnerOfValue(Value value) {
  if (auto opResult = llvm::dyn_cast<OpResult>(value))
    return opResult.getDefiningOp();
  return llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
}
FailureOr<Value> bufferization::allocateTensorForShapedValue(
    OpBuilder &b, Location loc, Value shapedValue,
    const BufferizationOptions &options, const BufferizationState &state,
    bool copy) {
  // Get a tensor value to allocate for (and possibly copy from).
  Value tensor;
  if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
    tensor = shapedValue;
  } else if (llvm::isa<MemRefType>(shapedValue.getType())) {
    tensor = ToTensorOp::create(
        b, loc, getTensorTypeFromMemRefType(shapedValue.getType()),
        shapedValue);
  } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) ||
             llvm::isa<UnrankedMemRefType>(shapedValue.getType())) {
    return getOwnerOfValue(shapedValue)
        ->emitError("copying of unranked tensors is not implemented");
  } else {
    llvm_unreachable("expected RankedTensorType or MemRefType");
  }
  RankedTensorType tensorType = llvm::cast<RankedTensorType>(tensor.getType());

  // Compute the dynamic sizes of the allocation: try to reify the result
  // shape first, otherwise fall back to dim ops.
  SmallVector<Value> dynamicSizes;
  bool reifiedShapes = false;
  if (llvm::isa<RankedTensorType>(shapedValue.getType()) &&
      llvm::isa<OpResult>(shapedValue)) {
    ReifiedRankedShapedTypeDims resultDims;
    if (succeeded(
            reifyResultShapes(b, shapedValue.getDefiningOp(), resultDims))) {
      reifiedShapes = true;
      const auto &shape =
          resultDims[llvm::cast<OpResult>(shapedValue).getResultNumber()];
      for (const auto &dim : enumerate(tensorType.getShape())) {
        if (ShapedType::isDynamic(dim.value())) {
          dynamicSizes.push_back(
              getValueOrCreateConstantIndexOp(b, loc, shape[dim.index()]));
        }
      }
    }
  }
  if (!reifiedShapes)
    populateDynamicDimSizes(b, loc, tensor, dynamicSizes);

  // Create an AllocTensorOp, optionally copying from the original value.
  auto allocTensorOp = AllocTensorOp::create(b, loc, tensorType, dynamicSizes,
                                             copy ? tensor : Value());
  if (copy)
    return allocTensorOp.getResult();

  // Without a copy operand, the memory space must be set explicitly.
  auto copyBufferType = detail::asMemRefType(getBufferType(tensor, options, state));
  if (failed(copyBufferType))
    return failure();
  std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
  if (!memorySpace)
    memorySpace = options.defaultMemorySpaceFn(tensorType);
  if (memorySpace.has_value())
    allocTensorOp.setMemorySpaceAttr(memorySpace.value());
  return allocTensorOp.getResult();
}
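// Example (illustrative sketch, not part of this file): materializing a copy
// of a tensor operand with allocateTensorForShapedValue(). `rewriter`, `op`,
// `opOperand`, `options`, and `state` are assumed to exist in the caller.
FailureOr<Value> operandCopy = allocateTensorForShapedValue(
    rewriter, op->getLoc(), opOperand.get(), options, state, /*copy=*/true);
if (failed(operandCopy))
  return failure();
rewriter.modifyOpInPlace(op, [&]() { opOperand.set(*operandCopy); });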
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
    RewriterBase &rewriter, const AnalysisState &analysisState,
    const BufferizationState &bufferizationState) {
  OpBuilder::InsertionGuard g(rewriter);
  Operation *op = getOperation();
  SmallVector<OpOperand *> outOfPlaceOpOperands;
  DenseSet<OpOperand *> copiedOpOperands;
  SmallVector<Value> outOfPlaceValues;
  DenseSet<Value> copiedOpValues;

  // Find all out-of-place OpOperands.
  for (OpOperand &opOperand : op->getOpOperands()) {
    Type operandType = opOperand.get().getType();
    if (!llvm::isa<TensorType>(operandType))
      continue;
    if (analysisState.isInPlace(opOperand))
      continue;
    if (llvm::isa<UnrankedTensorType>(operandType))
      return op->emitError("copying of unranked tensors is not implemented");

    AliasingValueList aliasingValues =
        analysisState.getAliasingValues(opOperand);
    if (aliasingValues.getNumAliases() == 1 &&
        isa<OpResult>(aliasingValues.getAliases()[0].value) &&
        !analysisState.bufferizesToMemoryWrite(opOperand) &&
        analysisState
                .getAliasingOpOperands(aliasingValues.getAliases()[0].value)
                .getNumAliases() == 1 &&
        !isa<UnrankedTensorType>(
            aliasingValues.getAliases()[0].value.getType())) {
      // The op itself does not write but creates exactly one alias. Instead of
      // copying the OpOperand, copy the aliasing OpResult, which can be
      // smaller (e.g., for tensor.extract_slice).
      Value value = aliasingValues.getAliases()[0].value;
      outOfPlaceValues.push_back(value);
      if (!analysisState.canOmitTensorCopy(opOperand))
        copiedOpValues.insert(value);
    } else {
      // In all other cases, copy the OpOperand.
      outOfPlaceOpOperands.push_back(&opOperand);
      if (!analysisState.canOmitTensorCopy(opOperand))
        copiedOpOperands.insert(&opOperand);
    }
  }

  // Insert copies of OpOperands before the op.
  rewriter.setInsertionPoint(op);
  for (OpOperand *opOperand : outOfPlaceOpOperands) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
        bufferizationState, copiedOpOperands.contains(opOperand));
    if (failed(copy))
      return failure();
    rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
  }

  // Insert copies of OpResults after the op and redirect their uses.
  rewriter.setInsertionPointAfter(op);
  for (Value value : outOfPlaceValues) {
    FailureOr<Value> copy = allocateTensorForShapedValue(
        rewriter, op->getLoc(), value, analysisState.getOptions(),
        bufferizationState, copiedOpValues.count(value));
    if (failed(copy))
      return failure();
    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
        value.getUses(), [](OpOperand &use) { return &use; }));
    for (OpOperand *use : uses) {
      // Do not update the alloc_tensor op that was just created.
      if (use->getOwner() == copy->getDefiningOp())
        continue;
      // tensor.dim ops may have been created as dynamic extents of the
      // alloc_tensor op; do not update those either.
      if (isa<tensor::DimOp>(use->getOwner()))
        continue;
      rewriter.modifyOpInPlace(use->getOwner(), [&]() { use->set(*copy); });
    }
  }

  return success();
}
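// Example (illustrative sketch, not part of this file): a bufferization driver
// typically resolves tensor copy conflicts for an op right before invoking its
// bufferize() implementation. `rewriter`, `analysisState`, and
// `bufferizationState` are assumed to exist in the caller.
if (auto bufferizableOp =
        analysisState.getOptions().dynCastBufferizableOp(op))
  if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
          rewriter, analysisState, bufferizationState)))
    return failure();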
bool OpFilter::isOpAllowed(Operation *op) const {
  // Allow by default if there is no ALLOW rule; DENY rules always win.
  bool isAllowed = !hasAllowRule();
  for (const Entry &entry : entries) {
    bool filterResult = entry.fn(op);
    switch (entry.type) {
    case Entry::ALLOW:
      isAllowed |= filterResult;
      break;
    case Entry::DENY:
      if (filterResult)
        // A DENY filter matches: the op is not allowed, even if other ALLOW
        // filters match as well.
        return false;
      break;
    };
  }
  return isAllowed;
}
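// Example (illustrative sketch, not part of this file): building an OpFilter.
// ALLOW entries turn the filter into an allow-list; a matching DENY entry
// always wins. The allowDialect/denyOperation helpers are assumed from
// BufferizationOptions.h.
OpFilter filter;
filter.allowDialect<tensor::TensorDialect, linalg::LinalgDialect>();
filter.denyOperation<tensor::GenerateOp>();
bool shouldBufferize = filter.isOpAllowed(op);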
/// Default function arg type converter: use a fully dynamic layout map.
static BufferLikeType
defaultFunctionArgTypeConverter(TensorLikeType type, Attribute memorySpace,
                                func::FuncOp funcOp,
                                const BufferizationOptions &options) {
  if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
    return cast<BufferLikeType>(
        getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
  }
  // Non-builtin tensor-like types convert via their own type interface.
  auto bufferType =
      type.getBufferType(options, [&]() { return funcOp->emitError(); });
  assert(succeeded(bufferType) &&
         "a valid buffer is always expected at function boundary");
  return *bufferType;
}
BufferizationOptions::BufferizationOptions()
    : functionArgTypeConverterFn(defaultFunctionArgTypeConverter),
      unknownTypeConverterFn(defaultUnknownTypeConverter) {}

bool BufferizationOptions::isOpAllowed(Operation *op) const {
  // If function boundary bufferization is deactivated, do not allow ops from
  // the func dialect.
  bool isFuncBoundaryOp = isa_and_nonnull<func::FuncDialect>(op->getDialect());
  if (!bufferizeFunctionBoundaries && isFuncBoundaryOp)
    return false;

  return opFilter.isOpAllowed(op);
}
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
  if (!isOpAllowed(op))
    return nullptr;
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  if (!bufferizableOp)
    return nullptr;
  return bufferizableOp;
}

BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
  return dynCastBufferizableOp(getOwnerOfValue(value));
}

void BufferizationOptions::setFunctionBoundaryTypeConversion(
    LayoutMapOption layoutMapOption) {
  functionArgTypeConverterFn = [=](TensorLikeType type, Attribute memorySpace,
                                   func::FuncOp funcOp,
                                   const BufferizationOptions &options) {
    if (auto tensorType = mlir::dyn_cast<TensorType>(type)) {
      if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
        return cast<BufferLikeType>(
            getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace));
      return cast<BufferLikeType>(
          getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace));
    }
    auto bufferType =
        type.getBufferType(options, [&]() { return funcOp->emitError(); });
    assert(succeeded(bufferType) &&
           "a valid buffer is always expected at function boundary");
    return *bufferType;
  };
  inferFunctionResultLayout =
      layoutMapOption == LayoutMapOption::InferLayoutMap;
}
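// Example (illustrative sketch, not part of this file): requesting identity
// layouts on function boundaries instead of the default fully dynamic layout.
BufferizationOptions opts;
opts.bufferizeFunctionBoundaries = true;
opts.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
// A tensor<4x8xf32> argument now bufferizes to memref<4x8xf32> instead of
// memref<4x8xf32, strided<[?, ?], offset: ?>>.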
/// Set the insertion point right after the definition of `value` (for block
/// arguments: at the start of the owning block).
static void setInsertionPointAfter(OpBuilder &b, Value value) {
  if (auto bbArg = llvm::dyn_cast<BlockArgument>(value)) {
    b.setInsertionPointToStart(bbArg.getOwner());
  } else {
    b.setInsertionPointAfter(value.getDefiningOp());
  }
}
AliasingOpOperandList AnalysisState::getAliasingOpOperands(Value value) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(getOwnerOfValue(value)))
    return bufferizableOp.getAliasingOpOperands(value, *this);
  // The op is not bufferizable.
  return detail::unknownGetAliasingOpOperands(value);
}

AliasingValueList AnalysisState::getAliasingValues(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.getAliasingValues(opOperand, *this);
  // The op is not bufferizable.
  return detail::unknownGetAliasingValues(opOperand);
}

bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
  // Unknown op: conservatively assume a read.
  return true;
}

bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
  // Unknown op: conservatively assume a write.
  return true;
}

bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
  if (auto bufferizableOp =
          getOptions().dynCastBufferizableOp(opOperand.getOwner()))
    return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
  // Unknown op: conservatively assume that it is not alias-only.
  return false;
}

bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
  auto opResult = llvm::dyn_cast<OpResult>(value);
  if (!opResult)
    return true;
  auto bufferizableOp = getOptions().dynCastBufferizableOp(value);
  if (!bufferizableOp)
    return true;
  return bufferizableOp.resultBufferizesToMemoryWrite(opResult, *this);
}
bool AnalysisState::isValueRead(Value value) const {
  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
  SmallVector<OpOperand *> workingSet;
  DenseSet<OpOperand *> visited;
  for (OpOperand &use : value.getUses())
    workingSet.push_back(&use);

  while (!workingSet.empty()) {
    OpOperand *uMaybeReading = workingSet.pop_back_val();
    if (!visited.insert(uMaybeReading).second)
      continue;

    // Skip over ops that neither read nor write but create an alias.
    if (bufferizesToAliasOnly(*uMaybeReading))
      for (AliasingValue alias : getAliasingValues(*uMaybeReading))
        for (OpOperand &use : alias.value.getUses())
          workingSet.push_back(&use);
    if (bufferizesToMemoryRead(*uMaybeReading))
      return true;
  }

  return false;
}
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
    OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
    TraversalConfig config,
    llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
  llvm::DenseSet<Value> visited;
  llvm::SetVector<Value> result, workingSet;
  workingSet.insert(opOperand->get());

  if (visitedOpOperands)
    visitedOpOperands->insert(opOperand);

  while (!workingSet.empty()) {
    Value value = workingSet.pop_back_val();

    if (!config.revisitAlreadyVisitedValues && visited.contains(value)) {
      // Stop the traversal if `value` was already visited.
      if (config.alwaysIncludeLeaves)
        result.insert(value);
      continue;
    }
    visited.insert(value);

    if (condition(value)) {
      result.insert(value);
      continue;
    }

    // ... (stop at non-bufferizable ops and at values without aliasing
    // OpOperands; such leaves are included only if `alwaysIncludeLeaves`
    // is set) ...

    for (AliasingOpOperand a : getAliasingOpOperands(value)) {
      if (config.followEquivalentOnly &&
          a.relation != BufferRelation::Equivalent) {
        // Stop if `followEquivalentOnly` is set but the alias is not
        // equivalent.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      // ... (similar check for `followInPlaceOnly`) ...

      if (config.followSameTypeOrCastsOnly &&
          a.opOperand->get().getType() != value.getType() &&
          !value.getDefiningOp<CastOpInterface>()) {
        // Stop if `followSameTypeOrCastsOnly` is set but the alias has a
        // different type and the defining op is not a cast.
        if (config.alwaysIncludeLeaves)
          result.insert(value);
        continue;
      }

      workingSet.insert(a.opOperand->get());
      if (visitedOpOperands)
        visitedOpOperands->insert(a.opOperand);
    }
  }

  return result;
}
llvm::SetVector<Value>
AnalysisState::findDefinitions(OpOperand *opOperand) const {
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  return findValueInReverseUseDefChain(
      opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
      config);
}
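// Example (illustrative sketch, not part of this file): walking the reverse
// use-def chain of an OpOperand to find the last values written to its buffer,
// following only equivalent aliases. `analysisState` and `opOperand` are
// assumed to exist in the caller.
TraversalConfig cfg;
cfg.followEquivalentOnly = true;
cfg.alwaysIncludeLeaves = false;
llvm::SetVector<Value> lastWrites = analysisState.findValueInReverseUseDefChain(
    &opOperand,
    [&](Value v) { return analysisState.bufferizesToMemoryWrite(v); }, cfg);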
// In AnalysisState::canOmitTensorCopy: the copy can be omitted, among other
// cases, when the operand is never read afterwards.
  // ...
  if (!bufferizesToMemoryRead(opOperand) &&
      llvm::none_of(aliases,
                    [&](AliasingValue a) { return isValueRead(a.value); }))
    return true;
  // ...

// In AnalysisState::isInPlace: in the absence of analysis information,
// to_buffer uses are in-place and writing OpOperands are out-of-place.
  if (isa<ToBufferOp>(opOperand.getOwner()))
    return true;
  // ...
static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
  auto rankedTensorType = llvm::dyn_cast<RankedTensorType>(tensor.getType());
  assert((!rankedTensorType || llvm::cast<MemRefType>(memrefType).getRank() ==
                                   rankedTensorType.getRank()) &&
         "to_buffer would be invalid: mismatching ranks");
}
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                          const BufferizationOptions &options,
                                          const BufferizationState &state) {
  auto tensorType = llvm::dyn_cast<TensorLikeType>(value.getType());
  assert(tensorType && "unexpected non-tensor type");

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getBuffer();

  // Insert a to_buffer op right after the definition of `value`.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
  FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
  if (failed(bufferType))
    return failure();
  ensureToBufferOpIsValid(value, *bufferType);
  return bufferization::ToBufferOp::create(rewriter, value.getLoc(),
                                           *bufferType, value)
      .getResult();
}
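// Example (illustrative sketch, not part of this file): inside a
// BufferizableOpInterface::bufferize() implementation, tensor operands are
// turned into buffers with getBuffer() and the op is replaced via
// replaceOpWithBufferizedValues(). `myOp` and `createLoweredOp` are
// placeholders for a concrete op and its bufferized lowering.
FailureOr<Value> srcBuffer =
    getBuffer(rewriter, myOp.getSource(), options, state);
if (failed(srcBuffer))
  return failure();
Value loweredResult = createLoweredOp(rewriter, myOp, *srcBuffer);
replaceOpWithBufferizedValues(rewriter, myOp, loweredResult);
return success();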
FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state) {
  SmallVector<Value> invocationStack;
  return getBufferType(value, options, state, invocationStack);
}

FailureOr<BufferLikeType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
                             const BufferizationState &state,
                             SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorLikeType>(value.getType()) &&
         "unexpected non-tensor type");
  invocationStack.push_back(value);
  auto popFromStack =
      llvm::make_scope_exit([&]() { invocationStack.pop_back(); });

  // Try querying BufferizableOpInterface.
  Operation *op = getOwnerOfValue(value);
  auto bufferizableOp = options.dynCastBufferizableOp(op);
  if (bufferizableOp)
    return bufferizableOp.getBufferType(value, options, state, invocationStack);

  // Op is not bufferizable: use the type interface directly.
  return cast<TensorLikeType>(value.getType()).getBufferType(options, [&]() {
    return op->emitError();
  });
}
bool bufferization::hasTensorSemantics(Operation *op) {
  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
    return bufferizableOp.hasTensorSemantics();
  return detail::defaultHasTensorSemantics(op);
}
758 "expected one value per OpResult");
764 Value replacement = values[opResult.getResultNumber()];
765 if (llvm::isa<TensorLikeType>(opResult.getType())) {
768 assert(llvm::isa<BufferLikeType>(replacement.
getType()) &&
769 "tensor op result should be replaced with a buffer value");
774 replacement = bufferization::ToTensorOp::create(
775 rewriter, replacement.
getLoc(), opResult.getType(), replacement);
777 replacements.push_back(replacement);
FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc,
                                                   MemRefType type,
                                                   ValueRange dynShape) const {
  if (allocationFn)
    return (*allocationFn)(b, loc, type, dynShape, bufferAlignment);

  // Default: allocate via memref.alloc (with the requested alignment, if any).
  if (bufferAlignment != 0)
    return memref::AllocOp::create(b, loc, type, dynShape,
                                   b.getI64IntegerAttr(bufferAlignment))
        .getResult();
  return memref::AllocOp::create(b, loc, type, dynShape).getResult();
}

LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
                                                 Value from, Value to) const {
  if (memCpyFn)
    return (*memCpyFn)(b, loc, from, to);

  memref::CopyOp::create(b, loc, from, to);
  return success();
}
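// Example (illustrative sketch, not part of this file): overriding the memcpy
// primitive used by createMemCpy(). The callback signature is inferred from
// the (*memCpyFn)(b, loc, from, to) call above.
BufferizationOptions copyOpts;
copyOpts.memCpyFn = [](OpBuilder &b, Location loc, Value from,
                       Value to) -> LogicalResult {
  // A custom copy lowering would be emitted here; this sketch reproduces the
  // default memref.copy behavior.
  memref::CopyOp::create(b, loc, from, to);
  return success();
};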
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                            const BufferizationOptions &options,
                                            MemRefLayoutAttrInterface layout,
                                            Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    assert(!layout && "UnrankedTensorType cannot have a layout map");
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with specified layout.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  if (layout) {
    return MemRefType::get(rankedTensorType.getShape(),
                           rankedTensorType.getElementType(), layout,
                           memorySpace);
  }

  // Case 3: Let the unknown type converter decide.
  return options.unknownTypeConverterFn(tensorType, memorySpace, options);
}
BaseMemRefType
bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
                                                   Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with fully dynamic strided layout.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  int64_t dynamicOffset = ShapedType::kDynamic;
  SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
                                      ShapedType::kDynamic);
  auto stridedLayout = StridedLayoutAttr::get(tensorType.getContext(),
                                              dynamicOffset, dynamicStrides);
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), stridedLayout,
                         memorySpace);
}
BaseMemRefType
bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
                                                     Attribute memorySpace) {
  // Case 1: Unranked memref type.
  if (auto unrankedTensorType =
          llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
    return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
                                   memorySpace);
  }

  // Case 2: Ranked memref type with identity (empty) layout.
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensorType);
  MemRefLayoutAttrInterface layout = {};
  return MemRefType::get(rankedTensorType.getShape(),
                         rankedTensorType.getElementType(), layout,
                         memorySpace);
}
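// Illustration (not part of this file): for tensor<4x?xf32>,
// getMemRefTypeWithFullyDynamicLayout yields
//   memref<4x?xf32, strided<[?, ?], offset: ?>>
// (in the given memory space), whereas getMemRefTypeWithStaticIdentityLayout
// yields the layout-free
//   memref<4x?xf32>
// which accepts fewer buffers but requires no stride/offset bookkeeping.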
bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
    OpResult opResult, const AnalysisState &state) {
  auto bufferizableOp = cast<BufferizableOpInterface>(opResult.getDefiningOp());
  AliasingOpOperandList opOperands =
      bufferizableOp.getAliasingOpOperands(opResult, state);

  // A result without aliasing OpOperands bufferizes to a memory write; so does
  // a result whose aliasing OpOperand itself bufferizes to a write.
  // ...
        return state.bufferizesToMemoryWrite(*alias.opOperand);
  // ...

  // Otherwise, check whether a value that is written inside the op appears on
  // the reverse use-def chain of an aliasing OpOperand.
  auto isMemoryWriteInsideOp = [&](Value v) {
    // Only values defined inside the op are relevant here.
    // ...
    return state.bufferizesToMemoryWrite(v);
  };
  TraversalConfig config;
  config.alwaysIncludeLeaves = false;
  for (AliasingOpOperand alias : opOperands) {
    if (!state
             .findValueInReverseUseDefChain(alias.opOperand,
                                            isMemoryWriteInsideOp, config)
             .empty())
      return true;
  }
  return false;
}
AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
    Value value, const AnalysisState &state) {
  // Collect all tensor OpOperands whose aliasing values contain `value`.
  Operation *op = getOwnerOfValue(value);
  SmallVector<AliasingOpOperand> result;
  for (OpOperand &opOperand : op->getOpOperands()) {
    if (!llvm::isa<TensorType>(opOperand.get().getType()))
      continue;
    AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
    for (const auto &it : aliasingValues)
      if (it.value == value)
        result.emplace_back(&opOperand, it.relation, it.isDefinite);
  }
  return AliasingOpOperandList(std::move(result));
}
FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
    Value value, const BufferizationOptions &options,
    const BufferizationState &state, SmallVector<Value> &invocationStack) {
  assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
  auto tensorType = cast<TensorType>(value.getType());

  // No further analysis is possible for a block argument.
  if (llvm::isa<BlockArgument>(value)) {
    return cast<BufferLikeType>(
        bufferization::getMemRefType(tensorType, options));
  }

  // The value is an OpResult: if it has an equivalent OpOperand, both
  // bufferize to the same buffer type.
  Operation *op = getOwnerOfValue(value);
  auto opResult = llvm::cast<OpResult>(value);
  AnalysisState analysisState(options);
  AliasingOpOperandList aliases = analysisState.getAliasingOpOperands(opResult);
  if (aliases.getNumAliases() > 0 &&
      aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
    Value equivalentOperand = aliases.getAliases().front().opOperand->get();
    return getBufferType(equivalentOperand, options, state, invocationStack);
  }

  // Otherwise, the memory space must be inferable.
  std::optional<Attribute> memSpace = options.defaultMemorySpaceFn(tensorType);
  if (!memSpace.has_value())
    return op->emitError("could not infer memory space");

  return cast<BufferLikeType>(
      getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
}
bool bufferization::detail::defaultIsRepetitiveRegion(
    BufferizableOpInterface bufferizableOp, unsigned index) {
  assert(index < bufferizableOp->getNumRegions() && "invalid region index");
  auto regionInterface =
      dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation());
  if (!regionInterface)
    return false;
  return regionInterface.isRepetitiveRegion(index);
}
AliasingOpOperandList
bufferization::detail::unknownGetAliasingOpOperands(Value value) {
  // Non-entry block arguments of unknown ops alias with nothing.
  if (auto bbArg = dyn_cast<BlockArgument>(value))
    if (bbArg.getOwner() != &bbArg.getOwner()->getParent()->getBlocks().front())
      return {};

  // Otherwise, conservatively assume that the value may alias with every
  // tensor OpOperand of its owner op.
  AliasingOpOperandList r;
  for (OpOperand &operand : getOwnerOfValue(value)->getOpOperands())
    if (isa<TensorType>(operand.get().getType()))
      r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}

AliasingValueList
bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
  // Conservatively assume that the OpOperand may alias with every tensor
  // OpResult and with every tensor entry block argument of the op's regions.
  AliasingValueList r;
  for (OpResult result : opOperand.getOwner()->getOpResults())
    if (llvm::isa<TensorType>(result.getType()))
      r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
  for (Region &region : opOperand.getOwner()->getRegions())
    if (!region.getBlocks().empty())
      for (BlockArgument bbArg : region.getBlocks().front().getArguments())
        if (isa<TensorType>(bbArg.getType()))
          r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
  return r;
}
bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
  auto isaTensor = [](Type t) { return isa<TensorLikeType>(t); };
  bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
    return any_of(r.getBlocks(), [&](Block &b) {
      return any_of(b.getArguments(), [&](BlockArgument bbArg) {
        return isaTensor(bbArg.getType());
      });
    });
  });
  if (hasTensorBlockArgument)
    return true;

  if (any_of(op->getResultTypes(), isaTensor))
    return true;
  return any_of(op->getOperandTypes(), isaTensor);
}
FailureOr<BaseMemRefType>
bufferization::detail::asMemRefType(FailureOr<BufferLikeType> bufferType) {
  if (failed(bufferType))
    return failure();
  return cast<BaseMemRefType>(*bufferType);
}
bool bufferization::detail::typesMatchAfterBufferization(Operation &op,
                                                         Value tensor,
                                                         Value buffer) {
  return mlir::succeeded(
      cast<TensorLikeType>(tensor.getType())
          .verifyCompatibleBufferType(cast<BufferLikeType>(buffer.getType()),
                                      [&]() { return op.emitError(); }));
}
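// Example (illustrative sketch, not part of this file): a minimal external
// model implementing BufferizableOpInterface for a hypothetical
// "mydialect.foo" op that only reads its single tensor operand. The op and
// its lowering are placeholders; the interface method signatures follow the
// interface version this file belongs to and may differ across MLIR revisions.
struct FooOpInterface
    : public BufferizableOpInterface::ExternalModel<FooOpInterface,
                                                    mydialect::FooOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true; // The tensor operand is read.
  }
  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false; // The operand buffer is never written in place.
  }
  AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                      const AnalysisState &state) const {
    return {}; // The result does not alias the operand.
  }
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options,
                          BufferizationState &state) const {
    FailureOr<Value> buffer =
        getBuffer(rewriter, op->getOperand(0), options, state);
    if (failed(buffer))
      return failure();
    // ... create the memref-based form of the op from *buffer ...
    replaceOpWithBufferizedValues(rewriter, op, /*values=*/*buffer);
    return success();
  }
};
// Registration (sketch): mydialect::FooOp::attachInterface<FooOpInterface>(*ctx);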