24 for (
Block &
b : funcOp.getBody())
25 if (
auto returnOp = dyn_cast<func::ReturnOp>(
b.getTerminator()))
26 result.push_back(returnOp);
36 auto createdAliasingResults =
41 (
void)createdAliasingResults;
45 assert(createdEquiv.second &&
"equivalence info exists already");
46 assert(createdAliasingResults.second &&
"aliasing info exists already");
47 assert(createdRead.second &&
"bbarg access info exists already");
48 assert(createdWritten.second &&
"bbarg access info exists already");
56 TensorLikeType type) {
57 if (
auto tensorType = dyn_cast<TensorType>(type)) {
58 return *
options.defaultMemorySpaceFn(tensorType);
70 dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(
index));
71 assert(type &&
"expected TensorLikeType");
74 if (
auto tensorType = dyn_cast<TensorType>(type)) {
75 BufferLikeType memrefType =
options.functionArgTypeConverterFn(
78 auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
79 index, BufferizationDialect::kBufferLayoutAttrName);
83 auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
84 assert(rankedMemrefType &&
85 "buffer layout not supported on unranked tensors");
86 return cast<BufferLikeType>(MemRefType::get(
87 rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
88 layoutAttr, rankedMemrefType.getMemorySpace()));
91 return options.functionArgTypeConverterFn(type,
nullptr, funcOp,
98 return dyn_cast_or_null<FuncOp>(callOp.resolveCallableInTable(&symbolTables));
106 if (
auto *funcAnalysisState =
117static const FuncAnalysisState &
119 assert(isa<OneShotAnalysisState>(state) &&
"expected OneShotAnalysisState");
121 .getExtension<FuncAnalysisState>();
122 assert(
result &&
"FuncAnalysisState does not exist");
129 if (!isa<OneShotAnalysisState>(state))
132 .getExtension<FuncAnalysisState>();
135 const auto &analyzedFuncOps = funcState->analyzedFuncOps;
136 auto it = analyzedFuncOps.find(funcOp);
137 if (it == analyzedFuncOps.end())
144static std::optional<int64_t>
152 auto retValIt = funcOpIt->getSecond().find(returnValIdx);
153 if (retValIt == funcOpIt->getSecond().end())
157 return retValIt->getSecond();
161 :
public BufferizableOpInterface::ExternalModel<CallOpInterface,
165 func::CallOp callOp = cast<func::CallOp>(op);
167 assert(funcOp &&
"expected CallOp to a FuncOp");
174 return funcState.
readBbArgs.lookup(funcOp).contains(
180 func::CallOp callOp = cast<func::CallOp>(op);
182 assert(funcOp &&
"expected CallOp to a FuncOp");
195 func::CallOp callOp = cast<func::CallOp>(op);
197 assert(funcOp &&
"expected CallOp to a FuncOp");
200 return detail::unknownGetAliasingValues(opOperand);
204 auto aliasingReturnVals =
209 std::optional<int64_t> equivalent = {};
210 if (aliasingReturnVals.size() == 1) {
212 aliasingReturnVals.front());
213 assert((!equivalent.has_value() ||
215 "inconsistent analysis state");
218 for (
int64_t resultIdx : aliasingReturnVals)
219 result.addAlias({callOp->getOpResult(resultIdx),
220 equivalent.has_value() ? BufferRelation::Equivalent
221 : BufferRelation::Unknown,
222 equivalent.has_value()});
226 FailureOr<BufferLikeType>
228 const BufferizationState &state,
230 auto callOp = cast<func::CallOp>(op);
234 assert(funcOp &&
"expected CallOp to a FuncOp");
238 FunctionType funcType = funcOp.getFunctionType();
240 funcType.getResult(cast<OpResult>(value).getResultNumber());
241 if (
auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
242 return bufferizedType;
245 auto tensorType = cast<TensorLikeType>(resultType);
246 return cast<BufferLikeType>(
options.functionArgTypeConverterFn(
255 BufferizationState &state)
const {
256 func::CallOp callOp = cast<func::CallOp>(op);
262 if (!isa<TensorLikeType>(returnType)) {
264 resultTypes.push_back(returnType);
269 FailureOr<BufferLikeType> resultType =
271 if (failed(resultType))
273 resultTypes.push_back(*resultType);
281 assert(funcOp &&
"expected CallOp to a FuncOp");
282 FunctionType funcType = funcOp.getFunctionType();
284 for (
OpOperand &opOperand : callOp->getOpOperands()) {
286 if (!isa<TensorLikeType>(opOperand.get().getType())) {
287 newOperands.push_back(opOperand.get());
292 FailureOr<Value> maybeBuffer =
293 getBuffer(rewriter, opOperand.get(),
options, state);
294 if (failed(maybeBuffer))
296 Value buffer = *maybeBuffer;
299 auto bufferType = funcType.getInput(opOperand.getOperandNumber());
300 if (!isa<BufferLikeType>(bufferType)) {
304 FailureOr<BufferLikeType> maybeBufferType =
305 bufferization::getBufferType(
306 funcOp.getArgument(opOperand.getOperandNumber()),
options,
308 if (failed(maybeBufferType))
310 bufferType = *maybeBufferType;
319 if (buffer.
getType() != bufferType) {
320 auto memrefDstType = dyn_cast<MemRefType>(bufferType);
321 assert(memrefDstType &&
322 "buffer layout not supported on unranked tensors");
324 rewriter, buffer, memrefDstType,
options);
329 newOperands.push_back(buffer);
334 func::CallOp::create(rewriter, callOp.getLoc(), funcOp.getSymName(),
335 resultTypes, newOperands);
336 newCallOp->
setAttrs(callOp->getAttrs());
339 replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->
getResults());
346 :
public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
365 BufferizationState &state)
const {
367 auto returnOp = cast<func::ReturnOp>(op);
368 assert(isa<FuncOp>(returnOp->getParentOp()) &&
369 "only support FuncOp parent for ReturnOp");
379 FuncOpInterface, FuncOp> {
384 auto isaTensor = llvm::IsaPred<TensorLikeType>;
387 auto funcOp = cast<FuncOp>(op);
388 bool hasTensorArg = any_of(funcOp.getArgumentTypes(),
isaTensor);
389 bool hasTensorResult = any_of(funcOp.getResultTypes(),
isaTensor);
390 if (hasTensorArg || hasTensorResult)
398 for (
Block &block : funcOp.getBody())
399 if (any_of(block.getArgumentTypes(),
isaTensor))
405 AliasingOpOperandList
411 FailureOr<BufferLikeType>
413 const BufferizationState &state,
415 auto funcOp = cast<FuncOp>(op);
416 auto bbArg = cast<BlockArgument>(value);
419 if (bbArg.getOwner() == &funcOp.getBody().front())
436 BufferizationState &state)
const {
437 auto funcOp = cast<FuncOp>(op);
438 FunctionType funcType = funcOp.getFunctionType();
442 for (
const auto &it : llvm::enumerate(funcType.getInputs())) {
443 Type argType = it.value();
444 if (isa<TensorLikeType>(argType)) {
449 argTypes.push_back(argType);
454 for (
Type resultType : funcType.getResults()) {
455 if (
auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
456 BufferLikeType resultType =
options.functionArgTypeConverterFn(
459 retTypes.push_back(resultType);
462 retTypes.push_back(resultType);
466 auto newFuncType = FunctionType::get(op->
getContext(), argTypes, retTypes);
469 if (funcOp.isExternal()) {
470 funcOp.setType(newFuncType);
475 for (
Block &block : funcOp.getBody())
482 assert(returnOp->getNumOperands() == retTypes.size() &&
483 "incorrect number of return values");
485 for (
auto [returnVal, bufferizedType] :
486 llvm::zip_equal(returnOp->getOperands(), retTypes)) {
487 auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
492 returnValues.push_back(returnVal);
498 Value toBufferOp = bufferization::ToBufferOp::create(
499 rewriter, returnOp.getLoc(), bufferizedType, returnVal);
500 returnValues.push_back(toBufferOp);
503 returnOp.getOperandsMutable().assign(returnValues);
507 funcOp.setType(newFuncType);
514 auto funcOp = cast<FuncOp>(op);
516 assert(bbArg &&
"expected BlockArgument");
526 bbArg.
getArgNumber(), BufferizationDialect::kWritableAttrName))
527 return writable.getValue();
541 func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
542 func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
543 func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
Returns failure if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted. Because `begin` may itself be invalidated by the rewriting, the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`,
static bool isaTensor(Type t)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
State for analysis-enabled bufferization.
static BufferLikeType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options)
Return the index-th bufferized function argument type.
FuncOpAnalysisState
The state of analysis of a FuncOp.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
static std::optional< int64_t > getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state, int64_t returnValIdx)
Return the index of the bbArg in the given FuncOp that is equivalent to the specified return value (if any).
static mlir::Attribute getDefaultMemorySpace(const BufferizationOptions &options, TensorLikeType type)
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp)
Return the state (phase) of analysis of the FuncOp.
static const FuncAnalysisState & getFuncAnalysisState(const AnalysisState &state)
Get FuncAnalysisState.
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
Include the generated interface declarations.
A template that provides a default implementation of `getAliasingOpOperands` for ops that support unstructured control flow within their regions.
AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const
All function arguments are writable.
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< int64_t, SmallVector< int64_t > > IndexToIndexListMapping
A mapping of indices to a list of indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseSet< int64_t > BbArgIndexSet
A set of block argument indices.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
DenseMap< int64_t, int64_t > IndexMapping
A mapping of indices to indices.
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
AliasingOpOperandList getAliasingOpOperands(Operation *op, Value value, const AnalysisState &state) const
bool isWritable(Operation *op, Value value, const AnalysisState &state) const
Return true if the given function argument is writable.
static bool supportsUnstructuredControlFlow()
bool hasTensorSemantics(Operation *op) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const
Rewrite function bbArgs and return values into buffer form.
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const