20 namespace bufferization {
26 auto createdAliasingOperands =
28 auto createdAliasingResults =
33 (void)createdAliasingOperands;
34 (void)createdAliasingResults;
38 assert(createdEquiv.second &&
"equivalence info exists already");
39 assert(createdAliasingOperands.second &&
"aliasing info exists already");
40 assert(createdAliasingResults.second &&
"aliasing info exists already");
41 assert(createdRead.second &&
"bbarg access info exists already");
42 assert(createdWritten.second &&
"bbarg access info exists already");
49 func::ReturnOp returnOp;
50 for (
Block &b : funcOp.getBody()) {
51 if (
auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
54 returnOp = candidateOp;
69 assert(tensorType &&
"expected TensorType");
72 if (
options.functionBoundaryTypeConversion ==
73 LayoutMapOption::IdentityLayoutMap) {
81 auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
82 index, BufferizationDialect::kBufferLayoutAttrName);
86 auto rankedMemrefType = memrefType.
dyn_cast<MemRefType>();
87 assert(rankedMemrefType &&
"buffer layout not supported on unranked tensors");
88 return MemRefType::get(
89 rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
90 layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
95 SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
98 return dyn_cast_or_null<FuncOp>(
103 static const FuncAnalysisState &
105 assert(isa<OneShotAnalysisState>(state) &&
"expected OneShotAnalysisState");
107 .getExtension<FuncAnalysisState>();
108 assert(result &&
"FuncAnalysisState does not exist");
115 if (!isa<OneShotAnalysisState>(state))
118 .getExtension<FuncAnalysisState>();
121 const auto &analyzedFuncOps = funcState->analyzedFuncOps;
122 auto it = analyzedFuncOps.find(funcOp);
123 if (it == analyzedFuncOps.end())
130 static std::optional<int64_t>
132 int64_t returnValIdx) {
138 auto retValIt = funcOpIt->getSecond().find(returnValIdx);
139 if (retValIt == funcOpIt->getSecond().end())
143 return retValIt->getSecond();
147 :
public BufferizableOpInterface::ExternalModel<CallOpInterface,
151 func::CallOp callOp = cast<func::CallOp>(op);
153 assert(funcOp &&
"expected CallOp to a FuncOp");
160 return funcState.
readBbArgs.lookup(funcOp).contains(
166 func::CallOp callOp = cast<func::CallOp>(op);
168 assert(funcOp &&
"expected CallOp to a FuncOp");
181 func::CallOp callOp = cast<func::CallOp>(op);
183 assert(funcOp &&
"expected CallOp to a FuncOp");
190 result.push_back(opResult);
196 auto aliasingReturnVals =
200 for (int64_t resultIdx : aliasingReturnVals)
201 result.push_back(callOp->getOpResult(resultIdx));
208 func::CallOp callOp = cast<func::CallOp>(op);
210 assert(funcOp &&
"expected CallOp to a FuncOp");
216 if (opOperand.get().getType().isa<
TensorType>())
217 result.push_back(&opOperand);
226 for (int64_t bbArgIdx : aliasingFuncArgs)
227 result.push_back(&callOp->getOpOperand(bbArgIdx));
233 func::CallOp callOp = cast<func::CallOp>(op);
235 assert(funcOp &&
"expected CallOp to a FuncOp");
243 std::optional<int64_t> maybeEquiv =
249 assert(aliasingOpOperands.size() == 1 &&
250 "expected exactly 1 aliasing OpOperand");
251 assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv &&
252 "inconsistent analysis state");
263 func::CallOp callOp = cast<func::CallOp>(op);
264 unsigned numResults = callOp.getNumResults();
265 unsigned numOperands = callOp->getNumOperands();
267 assert(funcOp &&
"expected CallOp to a FuncOp");
268 FunctionType funcType = funcOp.getFunctionType();
284 unsigned returnValIdx = it.index();
285 Type returnType = it.value();
288 retValMapping[returnValIdx] = resultTypes.size();
289 resultTypes.push_back(returnType);
294 retValMapping[returnValIdx] = resultTypes.size();
295 resultTypes.push_back(funcType.getResult(resultTypes.size()));
299 for (
OpOperand &opOperand : callOp->getOpOperands()) {
300 unsigned idx = opOperand.getOperandNumber();
301 Value tensorOperand = opOperand.get();
305 newOperands[idx] = tensorOperand;
310 Value buffer = newOperands[idx];
316 buffer = *maybeBuffer;
320 auto memRefType = funcType.getInput(idx);
326 if (buffer.
getType() != memRefType) {
328 memref::CastOp::areCastCompatible(buffer.
getType(), memRefType) &&
329 "CallOp::bufferize: cast incompatible");
330 Value castBuffer = rewriter.
create<memref::CastOp>(callOp.getLoc(),
334 newOperands[idx] = buffer;
339 callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands);
340 newCallOp->
setAttrs(callOp->getAttrs());
342 for (
unsigned i = 0; i < replacementValues.size(); ++i) {
343 if (replacementValues[i])
345 replacementValues[i] = newCallOp->
getResult(*retValMapping[i]);
356 :
public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
376 auto returnOp = cast<func::ReturnOp>(op);
377 assert(isa<FuncOp>(returnOp->getParentOp()) &&
378 "only support FuncOp parent for ReturnOp");
387 :
public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
397 auto funcOp = cast<FuncOp>(op);
398 FunctionType funcType = funcOp.getFunctionType();
403 Type argType = it.value();
409 argTypes.push_back(argType);
415 if (funcOp.getBody().empty()) {
417 for (
Type resultType : funcType.getResults()) {
419 return funcOp->emitError() <<
"cannot bufferize bodiless function "
420 <<
"that returns a tensor";
421 retTypes.push_back(resultType);
423 funcOp.setType(FunctionType::get(op->
getContext(), argTypes, retTypes));
429 assert(returnOp &&
"expected func with single return op");
435 auto tensorType = bbArg.getType().dyn_cast<
TensorType>();
443 bbArgUses.push_back(&use);
448 bbArg.setType(memrefType);
452 if (!bbArgUses.empty()) {
456 rewriter.
create<bufferization::ToTensorOp>(funcOp.getLoc(), bbArg);
458 use->set(toTensorOp);
464 for (
OpOperand &returnOperand : returnOp->getOpOperands()) {
465 Value returnVal = returnOperand.get();
471 returnValues.push_back(returnVal);
476 if (
options.functionBoundaryTypeConversion ==
477 LayoutMapOption::IdentityLayoutMap) {
483 Value toMemrefOp = rewriter.
create<bufferization::ToMemrefOp>(
484 loc, resultType, returnVal);
485 returnValues.push_back(toMemrefOp);
489 returnOp.getOperandsMutable().assign(returnValues);
492 funcOp.setType(FunctionType::get(op->
getContext(), argTypes,
501 auto funcOp = cast<FuncOp>(op);
503 assert(bbArg &&
"expected BlockArgument");
508 bbArg.
getArgNumber(), BufferizationDialect::kWritableAttrName))
509 return writable.getValue();
523 func::CallOp::attachInterface<func_ext::CallOpInterface>(*ctx);
524 func::FuncOp::attachInterface<func_ext::FuncOpInterface>(*ctx);
525 func::ReturnOp::attachInterface<func_ext::ReturnOpInterface>(*ctx);
static llvm::ManagedStatic< PassManagerOptions > options
This class provides a shared interface for ranked and unranked memref types.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
BlockArgListType getArguments()
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
void setAttrs(DictionaryAttr newAttrs)
Set the attribute dictionary on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
MutableArrayRef< OpOperand > getOpOperands()
result_range getOpResults()
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
AnalysisState provides a variety of helper functions for dealing with tensor values.
State for analysis-enabled bufferization.
static std::optional< int64_t > getEquivalentFuncArgIdx(FuncOp funcOp, const FuncAnalysisState &state, int64_t returnValIdx)
Return the index of the bbArg in the given FuncOp that is equivalent to the specified return value (i...
FuncOpAnalysisState
The state of analysis of a FuncOp.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry)
static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp)
Return the state (phase) of analysis of the FuncOp.
static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
static const FuncAnalysisState & getFuncAnalysisState(const AnalysisState &state)
Get FuncAnalysisState.
static FuncOp getCalledFunction(CallOpInterface callOp)
Return the FuncOp called by callOp.
static BaseMemRefType getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, const BufferizationOptions &options)
Return the index-th bufferized function argument type.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values)
Replace an op with replacement values.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
FailureOr< Value > getBuffer(RewriterBase &rewriter, Value value, const BufferizationOptions &options)
Lookup the buffer for the given value.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
BufferRelation
Specify fine-grain relationship between buffers to enable more analysis.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
Options for BufferizableOpInterface-based bufferization.
SmallVector< OpOperand * > getAliasingOpOperand(Operation *op, OpResult opResult, const AnalysisState &state) const
BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
All function arguments are writable.
SmallVector< OpResult > getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap< int64_t, SmallVector< int64_t > > IndexToIndexListMapping
A mapping of indices to a list of indices.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseSet< int64_t > BbArgIndexSet
A set of block argument indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingFuncArgs
A mapping of ReturnOp OpOperand indices to aliasing FuncOp BBArg indices.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
DenseMap< int64_t, int64_t > IndexMapping
A mapping of indices to indices.
bool isWritable(Operation *op, Value value, const AnalysisState &state) const
Return true if the given function argument is writable.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
Rewrite function bbArgs and return values into buffer form.
SmallVector< OpResult > getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const