22 namespace bufferization {
28 auto createdAliasingResults =
33 (void)createdAliasingResults;
37 assert(createdEquiv.second &&
"equivalence info exists already");
38 assert(createdAliasingResults.second &&
"aliasing info exists already");
39 assert(createdRead.second &&
"bbarg access info exists already");
40 assert(createdWritten.second &&
"bbarg access info exists already");
47 func::ReturnOp returnOp;
48 for (
Block &b : funcOp.getBody()) {
49 if (
auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
52 returnOp = candidateOp;
65 dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
66 assert(tensorType &&
"expected TensorType");
71 auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
72 index, BufferizationDialect::kBufferLayoutAttrName);
76 auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
77 assert(rankedMemrefType &&
"buffer layout not supported on unranked tensors");
79 rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
80 layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
85 SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
88 return dyn_cast_or_null<FuncOp>(
93 static const FuncAnalysisState &
95 assert(isa<OneShotAnalysisState>(state) &&
"expected OneShotAnalysisState");
97 .getExtension<FuncAnalysisState>();
98 assert(result &&
"FuncAnalysisState does not exist");
105 if (!isa<OneShotAnalysisState>(state))
108 .getExtension<FuncAnalysisState>();
111 const auto &analyzedFuncOps = funcState->analyzedFuncOps;
112 auto it = analyzedFuncOps.find(funcOp);
113 if (it == analyzedFuncOps.end())
120 static std::optional<int64_t>
122 int64_t returnValIdx) {
123 auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
124 if (funcOpIt == state.equivalentFuncArgs.end())
128 auto retValIt = funcOpIt->getSecond().find(returnValIdx);
129 if (retValIt == funcOpIt->getSecond().end())
133 return retValIt->getSecond();
137 :
public BufferizableOpInterface::ExternalModel<CallOpInterface,
141 func::CallOp callOp = cast<func::CallOp>(op);
143 assert(funcOp &&
"expected CallOp to a FuncOp");
150 return funcState.
readBbArgs.lookup(funcOp).contains(
156 func::CallOp callOp = cast<func::CallOp>(op);
158 assert(funcOp &&
"expected CallOp to a FuncOp");
171 func::CallOp callOp = cast<func::CallOp>(op);
173 assert(funcOp &&
"expected CallOp to a FuncOp");
180 auto aliasingReturnVals =
185 std::optional<int64_t> equivalent = {};
186 if (aliasingReturnVals.size() == 1) {
188 aliasingReturnVals.front());
189 assert((!equivalent.has_value() ||
191 "inconsistent analysis state");
194 for (int64_t resultIdx : aliasingReturnVals)
195 result.
addAlias({callOp->getOpResult(resultIdx),
198 equivalent.has_value()});
205 auto callOp = cast<func::CallOp>(op);
207 assert(funcOp &&
"expected CallOp to a FuncOp");
211 FunctionType funcType = funcOp.getFunctionType();
212 return cast<BaseMemRefType>(
213 funcType.getResult(cast<OpResult>(value).getResultNumber()));
220 func::CallOp callOp = cast<func::CallOp>(op);
224 for (
Value result : callOp.getResults()) {
225 Type returnType = result.getType();
226 if (!isa<TensorType>(returnType)) {
228 resultTypes.push_back(returnType);
237 resultTypes.push_back(*resultType);
244 assert(funcOp &&
"expected CallOp to a FuncOp");
245 FunctionType funcType = funcOp.getFunctionType();
247 for (
OpOperand &opOperand : callOp->getOpOperands()) {
249 if (!isa<TensorType>(opOperand.get().getType())) {
250 newOperands.push_back(opOperand.get());
259 Value buffer = *maybeBuffer;
262 auto memRefType = funcType.getInput(opOperand.getOperandNumber());
268 if (buffer.
getType() != memRefType) {
270 memref::CastOp::areCastCompatible(buffer.
getType(), memRefType) &&
271 "CallOp::bufferize: cast incompatible");
272 Value castBuffer = rewriter.
create<memref::CastOp>(callOp.getLoc(),
276 newOperands.push_back(buffer);
281 callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
282 newCallOp->
setAttrs(callOp->getAttrs());
292 :
public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
312 auto returnOp = cast<func::ReturnOp>(op);
313 assert(isa<FuncOp>(returnOp->getParentOp()) &&
314 "only support FuncOp parent for ReturnOp");
324 FuncOpInterface, FuncOp> {
337 auto funcOp = cast<FuncOp>(op);
338 auto bbArg = cast<BlockArgument>(value);
341 if (bbArg.getOwner() == &funcOp.getBody().front())
351 auto funcOp = cast<func::FuncOp>(op);
354 return op->emitOpError(
"op without unique func.return is not supported");
367 auto funcOp = cast<FuncOp>(op);
368 FunctionType funcType = funcOp.getFunctionType();
373 Type argType = it.value();
374 if (dyn_cast<TensorType>(argType)) {
379 argTypes.push_back(argType);
385 if (funcOp.isExternal()) {
387 for (
Type resultType : funcType.getResults()) {
388 if (isa<TensorType>(resultType))
389 return funcOp->emitError() <<
"cannot bufferize bodiless function "
390 <<
"that returns a tensor";
391 retTypes.push_back(resultType);
399 assert(returnOp &&
"expected func with single return op");
403 for (
Block &block : funcOp.getBody())
410 for (
OpOperand &returnOperand : returnOp->getOpOperands()) {
411 Value returnVal = returnOperand.get();
412 auto tensorType = dyn_cast<TensorType>(returnVal.
getType());
417 returnValues.push_back(returnVal);
425 Value toMemrefOp = rewriter.
create<bufferization::ToMemrefOp>(
426 loc, resultType, returnVal);
427 returnValues.push_back(toMemrefOp);
431 returnOp.getOperandsMutable().assign(returnValues);
443 auto funcOp = cast<FuncOp>(op);
445 assert(bbArg &&
"expected BlockArgument");
455 bbArg.
getArgNumber(), BufferizationDialect::kWritableAttrName))
456 return writable.getValue();
470 func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
471 func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
472 func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
AnalysisState provides a variety of helper functions for dealing with tensor values.
State for analysis-enabled bufferization.
AliasingValueList unknownGetAliasingValues(OpOperand &opOperand)
This is the default implementation of getAliasingValues in case the owner op does not implement the BufferizableOpInterface.
static std::optional< int64_t > getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state, int64_t returnValIdx)
Return the index of the bbArg in the given FuncOp that is equivalent to the specified return value (if any).
FuncOpAnalysisState
The state of analysis of a FuncOp.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp)
Return the state (phase) of analysis of the FuncOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
static const FuncAnalysisState & getFuncAnalysisState(const AnalysisState &state)
Get FuncAnalysisState.
static FuncOp getCalledFunction(CallOpInterface callOp)
Return the FuncOp called by callOp.
static BaseMemRefType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options)
Return the index-th bufferized function argument type.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for BufferizableOpInterface-based bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstructured control flow within their regions.
FailureOr< BaseMemRefType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack) const
AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const
Assuming that bbArg is a block argument of a block that belongs to the given op, return all OpOperands of users of this block that are aliasing with the given block argument.
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
All function arguments are writable.
FailureOr< BaseMemRefType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap< int64_t, SmallVector< int64_t > > IndexToIndexListMapping
A mapping of indices to a list of indices.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseSet< int64_t > BbArgIndexSet
A set of block argument indices.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
DenseMap< int64_t, int64_t > IndexMapping
A mapping of indices to indices.
AliasingOpOperandList getAliasingOpOperands(Operation *op, Value value, const AnalysisState &state) const
LogicalResult verifyAnalysis(Operation *op, const AnalysisState &state) const
FailureOr< BaseMemRefType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack) const
bool isWritable(Operation *op, Value value, const AnalysisState &state) const
Return true if the given function argument is writable.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
Rewrite function bbArgs and return values into buffer form.
static bool supportsUnstructuredControlFlow()
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const