24 for (
Block &b : funcOp.getBody())
25 if (
auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
26 result.push_back(returnOp);
30 namespace bufferization {
36 auto createdAliasingResults =
41 (void)createdAliasingResults;
45 assert(createdEquiv.second &&
"equivalence info exists already");
46 assert(createdAliasingResults.second &&
"aliasing info exists already");
47 assert(createdRead.second &&
"bbarg access info exists already");
48 assert(createdWritten.second &&
"bbarg access info exists already");
59 dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
60 assert(tensorType &&
"expected TensorType");
63 tensorType, *
options.defaultMemorySpaceFn(tensorType), funcOp,
options);
65 auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
66 index, BufferizationDialect::kBufferLayoutAttrName);
70 auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
71 assert(rankedMemrefType &&
"buffer layout not supported on unranked tensors");
73 rankedMemrefType.getElementType(), layoutAttr,
74 rankedMemrefType.getMemorySpace());
81 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
84 return dyn_cast_or_null<FuncOp>(
93 if (
auto *funcAnalysisState =
104 static const FuncAnalysisState &
106 assert(isa<OneShotAnalysisState>(state) &&
"expected OneShotAnalysisState");
108 .getExtension<FuncAnalysisState>();
109 assert(result &&
"FuncAnalysisState does not exist");
116 if (!isa<OneShotAnalysisState>(state))
119 .getExtension<FuncAnalysisState>();
122 const auto &analyzedFuncOps = funcState->analyzedFuncOps;
123 auto it = analyzedFuncOps.find(funcOp);
124 if (it == analyzedFuncOps.end())
131 static std::optional<int64_t>
133 int64_t returnValIdx) {
134 auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
135 if (funcOpIt == state.equivalentFuncArgs.end())
139 auto retValIt = funcOpIt->getSecond().find(returnValIdx);
140 if (retValIt == funcOpIt->getSecond().end())
144 return retValIt->getSecond();
148 :
public BufferizableOpInterface::ExternalModel<CallOpInterface,
152 func::CallOp callOp = cast<func::CallOp>(op);
154 assert(funcOp &&
"expected CallOp to a FuncOp");
161 return funcState.
readBbArgs.lookup(funcOp).contains(
167 func::CallOp callOp = cast<func::CallOp>(op);
169 assert(funcOp &&
"expected CallOp to a FuncOp");
182 func::CallOp callOp = cast<func::CallOp>(op);
184 assert(funcOp &&
"expected CallOp to a FuncOp");
191 auto aliasingReturnVals =
196 std::optional<int64_t> equivalent = {};
197 if (aliasingReturnVals.size() == 1) {
199 aliasingReturnVals.front());
200 assert((!equivalent.has_value() ||
202 "inconsistent analysis state");
205 for (int64_t resultIdx : aliasingReturnVals)
206 result.
addAlias({callOp->getOpResult(resultIdx),
209 equivalent.has_value()});
213 FailureOr<BufferLikeType>
217 auto callOp = cast<func::CallOp>(op);
223 assert(funcOp &&
"expected CallOp to a FuncOp");
227 FunctionType funcType = funcOp.getFunctionType();
229 funcType.getResult(cast<OpResult>(value).getResultNumber());
230 if (
auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
231 return cast<BufferLikeType>(bufferizedType);
234 auto tensorType = cast<TensorType>(resultType);
235 return cast<BufferLikeType>(
options.functionArgTypeConverterFn(
236 tensorType, *
options.defaultMemorySpaceFn(tensorType), funcOp,
245 func::CallOp callOp = cast<func::CallOp>(op);
249 for (
Value result : callOp.getResults()) {
250 Type returnType = result.getType();
251 if (!isa<TensorType>(returnType)) {
253 resultTypes.push_back(returnType);
258 FailureOr<BufferLikeType> resultType =
260 if (failed(resultType))
262 resultTypes.push_back(*resultType);
270 assert(funcOp &&
"expected CallOp to a FuncOp");
271 FunctionType funcType = funcOp.getFunctionType();
273 for (
OpOperand &opOperand : callOp->getOpOperands()) {
275 if (!isa<TensorType>(opOperand.get().getType())) {
276 newOperands.push_back(opOperand.get());
281 FailureOr<Value> maybeBuffer =
283 if (failed(maybeBuffer))
285 Value buffer = *maybeBuffer;
288 auto memRefType = funcType.getInput(opOperand.getOperandNumber());
289 if (!isa<BaseMemRefType>(memRefType)) {
293 FailureOr<BufferLikeType> maybeBufferType =
295 funcOp.getArgument(opOperand.getOperandNumber()),
options,
297 if (failed(maybeBufferType))
299 memRefType = *maybeBufferType;
308 if (buffer.
getType() != memRefType) {
309 auto memrefDstType = dyn_cast<MemRefType>(memRefType);
310 assert(memrefDstType &&
311 "buffer layout not supported on unranked tensors");
313 rewriter, buffer, memrefDstType,
options);
314 if (failed(replacement))
316 buffer = *replacement;
318 newOperands.push_back(buffer);
323 callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
324 newCallOp->
setAttrs(callOp->getAttrs());
334 :
public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
355 auto returnOp = cast<func::ReturnOp>(op);
356 assert(isa<FuncOp>(returnOp->getParentOp()) &&
357 "only support FuncOp parent for ReturnOp");
367 FuncOpInterface, FuncOp> {
372 auto isaTensor = llvm::IsaPred<TensorType>;
375 auto funcOp = cast<FuncOp>(op);
376 bool hasTensorArg = any_of(funcOp.getArgumentTypes(),
isaTensor);
377 bool hasTensorResult = any_of(funcOp.getResultTypes(),
isaTensor);
378 if (hasTensorArg || hasTensorResult)
386 for (
Block &block : funcOp.getBody())
387 if (any_of(block.getArgumentTypes(),
isaTensor))
399 FailureOr<BufferLikeType>
403 auto funcOp = cast<FuncOp>(op);
404 auto bbArg = cast<BlockArgument>(value);
407 if (bbArg.getOwner() == &funcOp.getBody().front())
408 return cast<BufferLikeType>(
425 auto funcOp = cast<FuncOp>(op);
426 FunctionType funcType = funcOp.getFunctionType();
431 Type argType = it.value();
432 if (isa<TensorType>(argType)) {
437 argTypes.push_back(argType);
442 for (
Type resultType : funcType.getResults()) {
443 if (
auto tensorType = dyn_cast<TensorType>(resultType)) {
445 tensorType, *
options.defaultMemorySpaceFn(tensorType), funcOp,
447 retTypes.push_back(resultType);
450 retTypes.push_back(resultType);
457 if (funcOp.isExternal()) {
458 funcOp.setType(newFuncType);
463 for (
Block &block : funcOp.getBody())
470 assert(returnOp->getNumOperands() == retTypes.size() &&
471 "incorrect number of return values");
473 for (
auto [returnVal, bufferizedType] :
474 llvm::zip_equal(returnOp->getOperands(), retTypes)) {
475 auto tensorType = dyn_cast<TensorType>(returnVal.getType());
480 returnValues.push_back(returnVal);
486 Value toBufferOp = rewriter.
create<bufferization::ToBufferOp>(
487 returnOp.getLoc(), bufferizedType, returnVal);
488 returnValues.push_back(toBufferOp);
491 returnOp.getOperandsMutable().assign(returnValues);
495 funcOp.setType(newFuncType);
502 auto funcOp = cast<FuncOp>(op);
504 assert(bbArg &&
"expected BlockArgument");
514 bbArg.
getArgNumber(), BufferizationDialect::kWritableAttrName))
515 return writable.getValue();
529 func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
530 func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
531 func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
static bool isaTensor(Type t)
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
AnalysisState provides a variety of helper functions for dealing with tensor values.
BufferizationState provides information about the state of the IR during the bufferization process.
State for analysis-enabled bufferization.
AliasingValueList unknownGetAliasingValues(OpOperand &opOperand)
This is the default implementation of getAliasingValues in case the owner op does not implement the B...
static std::optional< int64_t > getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state, int64_t returnValIdx)
Return the index of the bbArg in the given FuncOp that is equivalent to the specified return value (i...
FuncOpAnalysisState
The state of analysis of a FuncOp.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry)
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp)
Return the state (phase) of analysis of the FuncOp.
static const FuncAnalysisState & getFuncAnalysisState(const AnalysisState &state)
Get FuncAnalysisState.
static BaseMemRefType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options)
Return the index-th bufferized function argument type.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options, const BufferizationState &state)
Lookup the buffer for the given value.
FailureOr< BufferLikeType > getBufferType(Value value, const BufferizationOptions &options, const BufferizationState &state)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Options for BufferizableOpInterface-based bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstr...
AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const
Assuming that bbArg is a block argument of a block that belongs to the given op, return all OpOperand...
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const
All function arguments are writable.
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap< int64_t, SmallVector< int64_t > > IndexToIndexListMapping
A mapping of indices to a list of indices.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseSet< int64_t > BbArgIndexSet
A set of block argument indices.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
DenseMap< int64_t, int64_t > IndexMapping
A mapping of indices to indices.
AliasingOpOperandList getAliasingOpOperands(Operation *op, Value value, const AnalysisState &state) const
bool isWritable(Operation *op, Value value, const AnalysisState &state) const
Return true if the given function argument is writable.
static bool supportsUnstructuredControlFlow()
bool hasTensorSemantics(Operation *op) const
FailureOr< BufferLikeType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, const BufferizationState &state, SmallVector< Value > &invocationStack) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const
Rewrite function bbArgs and return values into buffer form.
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) const