25 for (
Block &b : funcOp.getBody())
26 if (
auto returnOp = dyn_cast<func::ReturnOp>(b.getTerminator()))
27 result.push_back(returnOp);
31 namespace bufferization {
37 auto createdAliasingResults =
42 (void)createdAliasingResults;
46 assert(createdEquiv.second &&
"equivalence info exists already");
47 assert(createdAliasingResults.second &&
"aliasing info exists already");
48 assert(createdRead.second &&
"bbarg access info exists already");
49 assert(createdWritten.second &&
"bbarg access info exists already");
60 dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
61 assert(tensorType &&
"expected TensorType");
64 tensorType, *
options.defaultMemorySpaceFn(tensorType), funcOp,
options);
66 auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
67 index, BufferizationDialect::kBufferLayoutAttrName);
71 auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
72 assert(rankedMemrefType &&
"buffer layout not supported on unranked tensors");
74 rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
75 layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
81 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
84 return dyn_cast_or_null<FuncOp>(
89 static const FuncAnalysisState &
91 assert(isa<OneShotAnalysisState>(state) &&
"expected OneShotAnalysisState");
93 .getExtension<FuncAnalysisState>();
94 assert(result &&
"FuncAnalysisState does not exist");
101 if (!isa<OneShotAnalysisState>(state))
104 .getExtension<FuncAnalysisState>();
107 const auto &analyzedFuncOps = funcState->analyzedFuncOps;
108 auto it = analyzedFuncOps.find(funcOp);
109 if (it == analyzedFuncOps.end())
116 static std::optional<int64_t>
118 int64_t returnValIdx) {
119 auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
120 if (funcOpIt == state.equivalentFuncArgs.end())
124 auto retValIt = funcOpIt->getSecond().find(returnValIdx);
125 if (retValIt == funcOpIt->getSecond().end())
129 return retValIt->getSecond();
133 :
public BufferizableOpInterface::ExternalModel<CallOpInterface,
137 func::CallOp callOp = cast<func::CallOp>(op);
139 assert(funcOp &&
"expected CallOp to a FuncOp");
146 return funcState.
readBbArgs.lookup(funcOp).contains(
152 func::CallOp callOp = cast<func::CallOp>(op);
154 assert(funcOp &&
"expected CallOp to a FuncOp");
167 func::CallOp callOp = cast<func::CallOp>(op);
169 assert(funcOp &&
"expected CallOp to a FuncOp");
176 auto aliasingReturnVals =
181 std::optional<int64_t> equivalent = {};
182 if (aliasingReturnVals.size() == 1) {
184 aliasingReturnVals.front());
185 assert((!equivalent.has_value() ||
187 "inconsistent analysis state");
190 for (int64_t resultIdx : aliasingReturnVals)
191 result.
addAlias({callOp->getOpResult(resultIdx),
194 equivalent.has_value()});
198 FailureOr<BaseMemRefType>
201 auto callOp = cast<func::CallOp>(op);
203 assert(funcOp &&
"expected CallOp to a FuncOp");
207 FunctionType funcType = funcOp.getFunctionType();
209 funcType.getResult(cast<OpResult>(value).getResultNumber());
210 if (
auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
211 return bufferizedType;
214 auto tensorType = cast<TensorType>(resultType);
215 return options.functionArgTypeConverterFn(
216 tensorType, *
options.defaultMemorySpaceFn(tensorType), funcOp,
options);
223 func::CallOp callOp = cast<func::CallOp>(op);
227 for (
Value result : callOp.getResults()) {
228 Type returnType = result.getType();
229 if (!isa<TensorType>(returnType)) {
231 resultTypes.push_back(returnType);
236 FailureOr<BaseMemRefType> resultType =
238 if (failed(resultType))
240 resultTypes.push_back(*resultType);
247 assert(funcOp &&
"expected CallOp to a FuncOp");
248 FunctionType funcType = funcOp.getFunctionType();
250 for (
OpOperand &opOperand : callOp->getOpOperands()) {
252 if (!isa<TensorType>(opOperand.get().getType())) {
253 newOperands.push_back(opOperand.get());
258 FailureOr<Value> maybeBuffer =
260 if (failed(maybeBuffer))
262 Value buffer = *maybeBuffer;
265 auto memRefType = funcType.getInput(opOperand.getOperandNumber());
266 if (!isa<BaseMemRefType>(memRefType)) {
270 FailureOr<BaseMemRefType> maybeMemRefType =
272 funcOp.getArgument(opOperand.getOperandNumber()),
options);
273 if (failed(maybeMemRefType))
275 memRefType = *maybeMemRefType;
284 if (buffer.
getType() != memRefType) {
285 auto memrefDstType = dyn_cast<MemRefType>(memRefType);
286 assert(memrefDstType &&
287 "buffer layout not supported on unranked tensors");
289 rewriter, buffer, memrefDstType,
options);
290 if (failed(replacement))
292 buffer = *replacement;
294 newOperands.push_back(buffer);
299 callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
300 newCallOp->
setAttrs(callOp->getAttrs());
310 :
public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
330 auto returnOp = cast<func::ReturnOp>(op);
331 assert(isa<FuncOp>(returnOp->getParentOp()) &&
332 "only support FuncOp parent for ReturnOp");
342 FuncOpInterface, FuncOp> {
347 auto isaTensor = llvm::IsaPred<TensorType>;
350 auto funcOp = cast<FuncOp>(op);
351 bool hasTensorArg = any_of(funcOp.getArgumentTypes(),
isaTensor);
352 bool hasTensorResult = any_of(funcOp.getResultTypes(),
isaTensor);
353 if (hasTensorArg || hasTensorResult)
361 for (
Block &block : funcOp.getBody())
362 if (any_of(block.getArgumentTypes(),
isaTensor))
374 FailureOr<BaseMemRefType>
377 auto funcOp = cast<FuncOp>(op);
378 auto bbArg = cast<BlockArgument>(value);
381 if (bbArg.getOwner() == &funcOp.getBody().front())
398 auto funcOp = cast<FuncOp>(op);
399 FunctionType funcType = funcOp.getFunctionType();
404 Type argType = it.value();
405 if (isa<TensorType>(argType)) {
410 argTypes.push_back(argType);
415 for (
Type resultType : funcType.getResults()) {
416 if (
auto tensorType = dyn_cast<TensorType>(resultType)) {
418 tensorType, *
options.defaultMemorySpaceFn(tensorType), funcOp,
420 retTypes.push_back(resultType);
423 retTypes.push_back(resultType);
430 if (funcOp.isExternal()) {
431 funcOp.setType(newFuncType);
436 for (
Block &block : funcOp.getBody())
443 assert(returnOp->getNumOperands() == retTypes.size() &&
444 "incorrect number of return values");
446 for (
auto [returnVal, bufferizedType] :
447 llvm::zip_equal(returnOp->getOperands(), retTypes)) {
448 auto tensorType = dyn_cast<TensorType>(returnVal.getType());
453 returnValues.push_back(returnVal);
459 Value toMemrefOp = rewriter.
create<bufferization::ToMemrefOp>(
460 returnOp.getLoc(), bufferizedType, returnVal);
461 returnValues.push_back(toMemrefOp);
464 returnOp.getOperandsMutable().assign(returnValues);
468 funcOp.setType(newFuncType);
475 auto funcOp = cast<FuncOp>(op);
477 assert(bbArg &&
"expected BlockArgument");
487 bbArg.
getArgNumber(), BufferizationDialect::kWritableAttrName))
488 return writable.getValue();
502 func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
503 func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
504 func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
static bool isaTensor(Type t)
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
AnalysisState provides a variety of helper functions for dealing with tensor values.
State for analysis-enabled bufferization.
AliasingValueList unknownGetAliasingValues(OpOperand &opOperand)
This is the default implementation of getAliasingValues in case the owner op does not implement the BufferizableOpInterface.
static std::optional< int64_t > getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state, int64_t returnValIdx)
Return the index of the bbArg in the given FuncOp that is equivalent to the specified return value (if any).
FuncOpAnalysisState
The state of analysis of a FuncOp.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp)
Return the state (phase) of analysis of the FuncOp.
static const FuncAnalysisState & getFuncAnalysisState(const AnalysisState &state)
Get FuncAnalysisState.
static FuncOp getCalledFunction(CallOpInterface callOp)
Return the FuncOp called by callOp.
static BaseMemRefType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options)
Return the index-th bufferized function argument type.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor).
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
Options for BufferizableOpInterface-based bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstructured control flow within their regions.
FailureOr< BaseMemRefType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack) const
AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const
Assuming that bbArg is a block argument of a block that belongs to the given op, return all OpOperands of users of this block that are aliasing with the given block argument.
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
All function arguments are writable.
FailureOr< BaseMemRefType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap< int64_t, SmallVector< int64_t > > IndexToIndexListMapping
A mapping of indices to a list of indices.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseSet< int64_t > BbArgIndexSet
A set of block argument indices.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
DenseMap< int64_t, int64_t > IndexMapping
A mapping of indices to indices.
AliasingOpOperandList getAliasingOpOperands(Operation *op, Value value, const AnalysisState &state) const
FailureOr< BaseMemRefType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack) const
bool isWritable(Operation *op, Value value, const AnalysisState &state) const
Return true if the given function argument is writable.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
Rewrite function bbArgs and return values into buffer form.
static bool supportsUnstructuredControlFlow()
bool hasTensorSemantics(Operation *op) const
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const