24 for (
Block &
b : funcOp.getBody())
25 if (
auto returnOp = dyn_cast<func::ReturnOp>(
b.getTerminator()))
26 result.push_back(returnOp);
36 auto createdAliasingResults =
41 (
void)createdAliasingResults;
45 assert(createdEquiv.second &&
"equivalence info exists already");
46 assert(createdAliasingResults.second &&
"aliasing info exists already");
47 assert(createdRead.second &&
"bbarg access info exists already");
48 assert(createdWritten.second &&
"bbarg access info exists already");
56 TensorLikeType type) {
57 if (
auto tensorType = dyn_cast<TensorType>(type)) {
58 return *
options.defaultMemorySpaceFn(tensorType);
70 dyn_cast<TensorLikeType>(funcOp.getFunctionType().getInput(
index));
71 assert(type &&
"expected TensorLikeType");
74 if (
auto tensorType = dyn_cast<TensorType>(type)) {
75 BufferLikeType memrefType =
options.functionArgTypeConverterFn(
78 auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
79 index, BufferizationDialect::kBufferLayoutAttrName);
83 auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
84 assert(rankedMemrefType &&
85 "buffer layout not supported on unranked tensors");
86 return cast<BufferLikeType>(MemRefType::get(
87 rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
88 layoutAttr, rankedMemrefType.getMemorySpace()));
91 return options.functionArgTypeConverterFn(type,
nullptr, funcOp,
98 return dyn_cast_or_null<FuncOp>(callOp.resolveCallableInTable(&symbolTables));
106 if (
auto *funcAnalysisState =
117static const FuncAnalysisState &
119 assert(isa<OneShotAnalysisState>(state) &&
"expected OneShotAnalysisState");
121 .getExtension<FuncAnalysisState>();
122 assert(
result &&
"FuncAnalysisState does not exist");
129 if (!isa<OneShotAnalysisState>(state))
132 .getExtension<FuncAnalysisState>();
135 const auto &analyzedFuncOps = funcState->analyzedFuncOps;
136 auto it = analyzedFuncOps.find(funcOp);
137 if (it == analyzedFuncOps.end())
144static std::optional<int64_t>
152 auto retValIt = funcOpIt->getSecond().find(returnValIdx);
153 if (retValIt == funcOpIt->getSecond().end())
157 return retValIt->getSecond();
161 :
public BufferizableOpInterface::ExternalModel<CallOpInterface,
165 func::CallOp callOp = cast<func::CallOp>(op);
167 assert(funcOp &&
"expected CallOp to a FuncOp");
174 return funcState.
readBbArgs.lookup(funcOp).contains(
180 func::CallOp callOp = cast<func::CallOp>(op);
182 assert(funcOp &&
"expected CallOp to a FuncOp");
195 func::CallOp callOp = cast<func::CallOp>(op);
197 assert(funcOp &&
"expected CallOp to a FuncOp");
200 return detail::unknownGetAliasingValues(opOperand);
204 auto aliasingReturnVals =
209 std::optional<int64_t> equivalent = {};
210 if (aliasingReturnVals.size() == 1) {
212 aliasingReturnVals.front());
213 assert((!equivalent.has_value() ||
215 "inconsistent analysis state");
218 for (
int64_t resultIdx : aliasingReturnVals)
219 result.addAlias({callOp->getOpResult(resultIdx),
220 equivalent.has_value() ? BufferRelation::Equivalent
221 : BufferRelation::Unknown,
222 equivalent.has_value()});
226 FailureOr<BufferLikeType>
228 const BufferizationState &state,
230 auto callOp = cast<func::CallOp>(op);
236 assert(funcOp &&
"expected CallOp to a FuncOp");
240 FunctionType funcType = funcOp.getFunctionType();
242 funcType.getResult(cast<OpResult>(value).getResultNumber());
243 if (
auto bufferizedType = dyn_cast<BufferLikeType>(resultType))
244 return bufferizedType;
247 auto tensorType = cast<TensorLikeType>(resultType);
248 return cast<BufferLikeType>(
options.functionArgTypeConverterFn(
257 BufferizationState &state)
const {
258 func::CallOp callOp = cast<func::CallOp>(op);
264 if (!isa<TensorLikeType>(returnType)) {
266 resultTypes.push_back(returnType);
271 FailureOr<BufferLikeType> resultType =
273 if (failed(resultType))
275 resultTypes.push_back(*resultType);
283 assert(funcOp &&
"expected CallOp to a FuncOp");
284 FunctionType funcType = funcOp.getFunctionType();
286 for (
OpOperand &opOperand : callOp->getOpOperands()) {
288 if (!isa<TensorLikeType>(opOperand.get().getType())) {
289 newOperands.push_back(opOperand.get());
294 FailureOr<Value> maybeBuffer =
295 getBuffer(rewriter, opOperand.get(),
options, state);
296 if (failed(maybeBuffer))
298 Value buffer = *maybeBuffer;
301 auto bufferType = funcType.getInput(opOperand.getOperandNumber());
302 if (!isa<BufferLikeType>(bufferType)) {
306 FailureOr<BufferLikeType> maybeBufferType =
307 bufferization::getBufferType(
308 funcOp.getArgument(opOperand.getOperandNumber()),
options,
310 if (failed(maybeBufferType))
312 bufferType = *maybeBufferType;
321 if (buffer.
getType() != bufferType) {
322 auto memrefDstType = dyn_cast<MemRefType>(bufferType);
323 assert(memrefDstType &&
324 "buffer layout not supported on unranked tensors");
326 rewriter, buffer, memrefDstType,
options);
331 newOperands.push_back(buffer);
336 func::CallOp::create(rewriter, callOp.getLoc(), funcOp.getSymName(),
337 resultTypes, newOperands);
338 newCallOp->
setAttrs(callOp->getAttrs());
341 replaceOpWithBufferizedValues(rewriter, callOp, newCallOp->
getResults());
348 :
public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
367 BufferizationState &state)
const {
369 auto returnOp = cast<func::ReturnOp>(op);
370 assert(isa<FuncOp>(returnOp->getParentOp()) &&
371 "only support FuncOp parent for ReturnOp");
381 FuncOpInterface, FuncOp> {
386 auto isaTensor = llvm::IsaPred<TensorLikeType>;
389 auto funcOp = cast<FuncOp>(op);
390 bool hasTensorArg = any_of(funcOp.getArgumentTypes(),
isaTensor);
391 bool hasTensorResult = any_of(funcOp.getResultTypes(),
isaTensor);
392 if (hasTensorArg || hasTensorResult)
400 for (
Block &block : funcOp.getBody())
401 if (any_of(block.getArgumentTypes(),
isaTensor))
407 AliasingOpOperandList
413 FailureOr<BufferLikeType>
415 const BufferizationState &state,
417 auto funcOp = cast<FuncOp>(op);
418 auto bbArg = cast<BlockArgument>(value);
421 if (bbArg.getOwner() == &funcOp.getBody().front())
438 BufferizationState &state)
const {
439 auto funcOp = cast<FuncOp>(op);
440 FunctionType funcType = funcOp.getFunctionType();
444 for (
const auto &it : llvm::enumerate(funcType.getInputs())) {
445 Type argType = it.value();
446 if (isa<TensorLikeType>(argType)) {
451 argTypes.push_back(argType);
456 for (
Type resultType : funcType.getResults()) {
457 if (
auto tensorType = dyn_cast<TensorLikeType>(resultType)) {
458 BufferLikeType resultType =
options.functionArgTypeConverterFn(
461 retTypes.push_back(resultType);
464 retTypes.push_back(resultType);
468 auto newFuncType = FunctionType::get(op->
getContext(), argTypes, retTypes);
471 if (funcOp.isExternal()) {
472 funcOp.setType(newFuncType);
477 for (
Block &block : funcOp.getBody())
484 assert(returnOp->getNumOperands() == retTypes.size() &&
485 "incorrect number of return values");
487 for (
auto [returnVal, bufferizedType] :
488 llvm::zip_equal(returnOp->getOperands(), retTypes)) {
489 auto tensorType = dyn_cast<TensorLikeType>(returnVal.getType());
494 returnValues.push_back(returnVal);
500 Value toBufferOp = bufferization::ToBufferOp::create(
501 rewriter, returnOp.getLoc(), bufferizedType, returnVal);
502 returnValues.push_back(toBufferOp);
505 returnOp.getOperandsMutable().assign(returnValues);
509 funcOp.setType(newFuncType);
516 auto funcOp = cast<FuncOp>(op);
518 assert(bbArg &&
"expected BlockArgument");
528 bbArg.
getArgNumber(), BufferizationDialect::kWritableAttrName))
529 return writable.getValue();
543 func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
544 func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
545 func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
… if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies, respectively, should be inserted. […] The output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, […]
static bool isaTensor(Type t)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
State for analysis-enabled bufferization.
static BufferLikeType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options)
Return the index-th bufferized function argument type.
FuncOpAnalysisState
The state of analysis of a FuncOp.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
static std::optional< int64_t > getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state, int64_t returnValIdx)
Return the index of the bbArg in the given FuncOp that is equivalent to the specified return value (if any).
static mlir::Attribute getDefaultMemorySpace(const BufferizationOptions &options, TensorLikeType type)
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp)
Return the state (phase) of analysis of the FuncOp.
static const FuncAnalysisState & getFuncAnalysisState(const AnalysisState &state)
Get FuncAnalysisState.
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of `block` and its callers (i.e., ops that have the given block as a successor).
Include the generated interface declarations.
A template that provides a default implementation of `getAliasingOpOperands` for ops that support unstructured control flow within their regions.
AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const
All function arguments are writable.
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< int64_t, SmallVector< int64_t > > IndexToIndexListMapping
A mapping of indices to a list of indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseSet< int64_t > BbArgIndexSet
A set of block argument indices.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
DenseMap< int64_t, int64_t > IndexMapping
A mapping of indices to indices.
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
AliasingOpOperandList getAliasingOpOperands(Operation *op, Value value, const AnalysisState &state) const
bool isWritable(Operation *op, Value value, const AnalysisState &state) const
Return true if the given function argument is writable.
static bool supportsUnstructuredControlFlow()
bool hasTensorSemantics(Operation *op) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const
Rewrite function bbArgs and return values into buffer form.
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const