22 namespace bufferization {
28 auto createdAliasingResults =
33 (void)createdAliasingResults;
37 assert(createdEquiv.second &&
"equivalence info exists already");
38 assert(createdAliasingResults.second &&
"aliasing info exists already");
39 assert(createdRead.second &&
"bbarg access info exists already");
40 assert(createdWritten.second &&
"bbarg access info exists already");
47 func::ReturnOp returnOp;
48 for (
Block &b : funcOp.getBody()) {
49 if (
auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
52 returnOp = candidateOp;
65 dyn_cast<TensorType>(funcOp.getFunctionType().getInput(index));
66 assert(tensorType &&
"expected TensorType");
69 tensorType, *
options.defaultMemorySpaceFn(tensorType), funcOp,
options);
71 auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
72 index, BufferizationDialect::kBufferLayoutAttrName);
76 auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
77 assert(rankedMemrefType &&
"buffer layout not supported on unranked tensors");
79 rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
80 layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
85 SymbolRefAttr sym = llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
88 return dyn_cast_or_null<FuncOp>(
93 static const FuncAnalysisState &
95 assert(isa<OneShotAnalysisState>(state) &&
"expected OneShotAnalysisState");
97 .getExtension<FuncAnalysisState>();
98 assert(result &&
"FuncAnalysisState does not exist");
105 if (!isa<OneShotAnalysisState>(state))
108 .getExtension<FuncAnalysisState>();
111 const auto &analyzedFuncOps = funcState->analyzedFuncOps;
112 auto it = analyzedFuncOps.find(funcOp);
113 if (it == analyzedFuncOps.end())
120 static std::optional<int64_t>
122 int64_t returnValIdx) {
123 auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
124 if (funcOpIt == state.equivalentFuncArgs.end())
128 auto retValIt = funcOpIt->getSecond().find(returnValIdx);
129 if (retValIt == funcOpIt->getSecond().end())
133 return retValIt->getSecond();
137 :
public BufferizableOpInterface::ExternalModel<CallOpInterface,
141 func::CallOp callOp = cast<func::CallOp>(op);
143 assert(funcOp &&
"expected CallOp to a FuncOp");
150 return funcState.
readBbArgs.lookup(funcOp).contains(
156 func::CallOp callOp = cast<func::CallOp>(op);
158 assert(funcOp &&
"expected CallOp to a FuncOp");
171 func::CallOp callOp = cast<func::CallOp>(op);
173 assert(funcOp &&
"expected CallOp to a FuncOp");
180 auto aliasingReturnVals =
185 std::optional<int64_t> equivalent = {};
186 if (aliasingReturnVals.size() == 1) {
188 aliasingReturnVals.front());
189 assert((!equivalent.has_value() ||
191 "inconsistent analysis state");
194 for (int64_t resultIdx : aliasingReturnVals)
195 result.
addAlias({callOp->getOpResult(resultIdx),
198 equivalent.has_value()});
202 FailureOr<BaseMemRefType>
205 auto callOp = cast<func::CallOp>(op);
207 assert(funcOp &&
"expected CallOp to a FuncOp");
211 FunctionType funcType = funcOp.getFunctionType();
212 return cast<BaseMemRefType>(
213 funcType.getResult(cast<OpResult>(value).getResultNumber()));
220 func::CallOp callOp = cast<func::CallOp>(op);
224 for (
Value result : callOp.getResults()) {
225 Type returnType = result.getType();
226 if (!isa<TensorType>(returnType)) {
228 resultTypes.push_back(returnType);
233 FailureOr<BaseMemRefType> resultType =
235 if (failed(resultType))
237 resultTypes.push_back(*resultType);
244 assert(funcOp &&
"expected CallOp to a FuncOp");
245 FunctionType funcType = funcOp.getFunctionType();
247 for (
OpOperand &opOperand : callOp->getOpOperands()) {
249 if (!isa<TensorType>(opOperand.get().getType())) {
250 newOperands.push_back(opOperand.get());
255 FailureOr<Value> maybeBuffer =
257 if (failed(maybeBuffer))
259 Value buffer = *maybeBuffer;
262 auto memRefType = funcType.getInput(opOperand.getOperandNumber());
269 if (buffer.
getType() != memRefType) {
270 auto memrefDstType = dyn_cast<MemRefType>(memRefType);
271 assert(memrefDstType &&
272 "buffer layout not supported on unranked tensors");
274 rewriter, buffer, memrefDstType,
options);
275 if (failed(replacement))
277 buffer = *replacement;
279 newOperands.push_back(buffer);
284 callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
285 newCallOp->
setAttrs(callOp->getAttrs());
295 :
public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
315 auto returnOp = cast<func::ReturnOp>(op);
316 assert(isa<FuncOp>(returnOp->getParentOp()) &&
317 "only support FuncOp parent for ReturnOp");
327 FuncOpInterface, FuncOp> {
332 auto isaTensor = llvm::IsaPred<TensorType>;
335 auto funcOp = cast<FuncOp>(op);
336 bool hasTensorArg = any_of(funcOp.getArgumentTypes(),
isaTensor);
337 bool hasTensorResult = any_of(funcOp.getResultTypes(),
isaTensor);
338 if (hasTensorArg || hasTensorResult)
346 for (
Block &block : funcOp.getBody())
347 if (any_of(block.getArgumentTypes(),
isaTensor))
359 FailureOr<BaseMemRefType>
362 auto funcOp = cast<FuncOp>(op);
363 auto bbArg = cast<BlockArgument>(value);
366 if (bbArg.getOwner() == &funcOp.getBody().front())
376 auto funcOp = cast<func::FuncOp>(op);
379 return op->emitOpError(
"op without unique func.return is not supported");
392 auto funcOp = cast<FuncOp>(op);
393 FunctionType funcType = funcOp.getFunctionType();
398 Type argType = it.value();
399 if (dyn_cast<TensorType>(argType)) {
404 argTypes.push_back(argType);
410 if (funcOp.isExternal()) {
412 for (
Type resultType : funcType.getResults()) {
413 if (isa<TensorType>(resultType))
414 return funcOp->emitError() <<
"cannot bufferize bodiless function "
415 <<
"that returns a tensor";
416 retTypes.push_back(resultType);
424 assert(returnOp &&
"expected func with single return op");
428 for (
Block &block : funcOp.getBody())
435 for (
OpOperand &returnOperand : returnOp->getOpOperands()) {
436 Value returnVal = returnOperand.get();
437 auto tensorType = dyn_cast<TensorType>(returnVal.
getType());
442 returnValues.push_back(returnVal);
449 tensorType, *
options.defaultMemorySpaceFn(tensorType), funcOp,
451 Value toMemrefOp = rewriter.
create<bufferization::ToMemrefOp>(
452 loc, resultType, returnVal);
453 returnValues.push_back(toMemrefOp);
457 returnOp.getOperandsMutable().assign(returnValues);
469 auto funcOp = cast<FuncOp>(op);
471 assert(bbArg &&
"expected BlockArgument");
481 bbArg.
getArgNumber(), BufferizationDialect::kWritableAttrName))
482 return writable.getValue();
496 func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
497 func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
498 func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
static bool isaTensor(Type t)
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
result_range getResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
AnalysisState provides a variety of helper functions for dealing with tensor values.
State for analysis-enabled bufferization.
AliasingValueList unknownGetAliasingValues(OpOperand &opOperand)
This is the default implementation of getAliasingValues in case the owner op does not implement the B...
static std::optional< int64_t > getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state, int64_t returnValIdx)
Return the index of the bbArg in the given FuncOp that is equivalent to the specified return value (i...
FuncOpAnalysisState
The state of analysis of a FuncOp.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp)
Return the state (phase) of analysis of the FuncOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
static const FuncAnalysisState & getFuncAnalysisState(const AnalysisState &state)
Get FuncAnalysisState.
static FuncOp getCalledFunction(CallOpInterface callOp)
Return the FuncOp called by callOp.
static BaseMemRefType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options)
Return the index-th bufferized function argument type.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options)
Bufferize the signature of block and its callers (i.e., ops that have the given block as a successor)...
FailureOr< BaseMemRefType > getBufferType(Value value, const BufferizationOptions &options)
Return the buffer type for a given Value (tensor) after bufferization without bufferizing any IR.
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
Options for BufferizableOpInterface-based bufferization.
A template that provides a default implementation of getAliasingOpOperands for ops that support unstr...
FailureOr< BaseMemRefType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack) const
AliasingOpOperandList getAliasingBranchOpOperands(Operation *op, BlockArgument bbArg, const AnalysisState &state) const
Assuming that bbArg is a block argument of a block that belongs to the given op, return all OpOperand...
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
All function arguments are writable.
FailureOr< BaseMemRefType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Extra analysis state that is required for bufferization of function boundaries.
void startFunctionAnalysis(FunctionOpInterface funcOp)
This function is called right before analyzing the given FuncOp.
DenseMap< FunctionOpInterface, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap< FunctionOpInterface, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
DenseMap< int64_t, SmallVector< int64_t > > IndexToIndexListMapping
A mapping of indices to a list of indices.
DenseMap< FunctionOpInterface, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FunctionOpInterface, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FunctionOpInterface, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseSet< int64_t > BbArgIndexSet
A set of block argument indices.
DenseMap< int64_t, int64_t > IndexMapping
A mapping of indices to indices.
AliasingOpOperandList getAliasingOpOperands(Operation *op, Value value, const AnalysisState &state) const
LogicalResult verifyAnalysis(Operation *op, const AnalysisState &state) const
FailureOr< BaseMemRefType > getBufferType(Operation *op, Value value, const BufferizationOptions &options, SmallVector< Value > &invocationStack) const
bool isWritable(Operation *op, Value value, const AnalysisState &state) const
Return true if the given function argument is writable.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
Rewrite function bbArgs and return values into buffer form.
static bool supportsUnstructuredControlFlow()
bool hasTensorSemantics(Operation *op) const
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const