73#include "llvm/ADT/MapVector.h"
74#include "llvm/ADT/SmallVectorExtras.h"
95static void annotateEquivalentReturnBbArg(
OpOperand &returnVal,
97 const char *kEquivalentArgsAttr =
"__equivalent_func_args__";
101 if (op->
hasAttr(kEquivalentArgsAttr)) {
102 auto attr = cast<ArrayAttr>(op->
getAttr(kEquivalentArgsAttr));
103 equivBbArgs = llvm::map_to_vector<4>(attr, [](
Attribute a) {
104 return cast<IntegerAttr>(a).getValue().getSExtValue();
112 op->
setAttr(kEquivalentArgsAttr,
b.getI64ArrayAttr(equivBbArgs));
120 if (funcOp.getBody().empty()) {
123 FunctionType type = funcOp.getFunctionType();
124 for (
const auto &inputIt : llvm::enumerate(type.getInputs())) {
125 if (!isa<TensorType>(inputIt.value()))
127 for (
const auto &resultIt : llvm::enumerate(type.getResults())) {
128 if (!isa<TensorType>(resultIt.value()))
130 int64_t returnIdx = resultIt.index();
131 int64_t bbArgIdx = inputIt.index();
142 if (returnOps.empty()) {
143 return funcOp.emitError(
"cannot bufferize func.func without func.return");
148 if (isa<RankedTensorType>(bbArg.
getType())) {
152 for (func::ReturnOp returnOp : returnOps) {
153 for (
OpOperand &returnVal : returnOp->getOpOperands()) {
154 if (isa<RankedTensorType>(returnVal.
get().
getType())) {
157 aliases.insert(returnIdx);
170 auto findEquivalentBlockArgIdx =
171 [&](
OpOperand &opOperand) -> std::optional<int64_t> {
172 Value v = opOperand.get();
173 if (!isa<TensorType>(v.
getType()))
176 if (isa<RankedTensorType>(bbArg.
getType())) {
179 annotateEquivalentReturnBbArg(opOperand, bbArg);
187 int64_t numResults = returnOps.front()->getNumOperands();
188 for (
int64_t i = 0; i < numResults; ++i) {
191 std::optional<int64_t> maybeEquiv =
192 findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
193 if (!maybeEquiv.has_value())
195 int64_t bbArgIdx = *maybeEquiv;
196 bool allEquiv =
true;
202 for (func::ReturnOp returnOp :
ArrayRef(returnOps).drop_front()) {
203 std::optional<int64_t> maybeEquiv =
204 findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
205 if (maybeEquiv != bbArgIdx) {
220static void annotateFuncArgAccess(func::FuncOp funcOp,
int64_t idx,
bool isRead,
224 if (isRead && isWritten) {
225 accessType =
b.getStringAttr(
"read-write");
227 accessType =
b.getStringAttr(
"read");
228 }
else if (isWritten) {
229 accessType =
b.getStringAttr(
"write");
231 accessType =
b.getStringAttr(
"none");
233 funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
243 for (
int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
246 if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
250 if (
auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
251 idx, BufferizationDialect::kBufferAccessAttrName)) {
253 StringRef str = accessAttr.getValue();
254 isRead = str ==
"read" || str ==
"read-write";
255 isWritten = str ==
"write" || str ==
"read-write";
256 }
else if (funcOp.getBody().empty()) {
264 isRead = state.isValueRead(bbArg);
269 annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
284 BufferizationDialect::kBufferLayoutAttrName);
286 BufferizationDialect::kWritableAttrName);
293 return dyn_cast_or_null<func::FuncOp>(
294 callOp.resolveCallableInTable(&symbolTable));
299 return llvm::any_of(funcOp.getFunctionType().getInputs(),
300 llvm::IsaPred<TensorType>) ||
301 llvm::any_of(funcOp.getFunctionType().getResults(),
302 llvm::IsaPred<TensorType>);
322 llvm::MapVector<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
325 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
327 numberCallOpsContainedInFuncOp[funcOp] = 0;
330 assert(calledFunction &&
"could not retrieved called func::FuncOp");
336 callerMap[calledFunction].insert(callOp);
337 if (calledBy[calledFunction].insert(funcOp)) {
338 numberCallOpsContainedInFuncOp[funcOp]++;
352 for (
const auto &entry : numberCallOpsContainedInFuncOp) {
353 if (entry.second == 0)
354 worklist.push_back(entry.first);
357 while (!worklist.empty()) {
358 func::FuncOp
func = worklist.pop_back_val();
359 orderedFuncOps.push_back(
func);
361 for (func::FuncOp caller : calledBy[
func]) {
362 auto &count = numberCallOpsContainedInFuncOp[caller];
365 worklist.push_back(caller);
368 numberCallOpsContainedInFuncOp.erase(
func);
373 for (
auto it : numberCallOpsContainedInFuncOp)
374 remainingFuncOps.push_back(it.first);
385 return castOp.getSource();
393 assert(!returnOps.empty() &&
"expected at least one ReturnOp");
394 int numOperands = returnOps.front()->getNumOperands();
400 for (
int i = 0; i < numOperands; ++i) {
402 Type t = getSourceType(returnOps.front()->getOperand(i));
405 for (
int j = 1; j < static_cast<int>(returnOps.size()); ++
j)
406 if (getSourceType(returnOps[
j]->getOperand(i)) != t)
424 if (funcOp.getBody().empty())
432 for (func::ReturnOp returnOp : returnOps) {
433 for (
OpOperand &operand : returnOp->getOpOperands()) {
435 if (resultTypes[operand.getOperandNumber()]) {
443 for (
int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
446 resultTypes[i] = funcOp.getFunctionType().getResult(i);
450 auto newFuncType = FunctionType::get(
451 funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
452 funcOp.setType(newFuncType);
459 assert(state.
getOptions().bufferizeFunctionBoundaries &&
460 "expected that function boundary bufferization is activated");
475 remainingFuncOps, callerMap,
481 for (func::FuncOp funcOp : orderedFuncOps) {
489 if (failed(
analyzeOp(funcOp, state, statistics)))
493 if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
494 failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
502 for (func::FuncOp funcOp : remainingFuncOps) {
507 if (failed(
analyzeOp(funcOp, state, statistics)))
524 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
535 assert(
options.bufferizeFunctionBoundaries &&
536 "expected that function boundary bufferization is activated");
556 remainingFuncOps, callerMap,
557 state.getSymbolTables())))
559 llvm::append_range(orderedFuncOps, remainingFuncOps);
562 for (func::FuncOp funcOp : orderedFuncOps) {
566 if (llvm::is_contained(
options.noAnalysisFuncFilter, funcOp.getSymName())) {
570 updatedOptions.copyBeforeWrite =
true;
571 if (failed(
bufferizeOp(funcOp, updatedOptions, state, statistics)))
579 if (
options.inferFunctionResultLayout)
587 llvm::make_early_inc_range(block.getOperations())) {
606 assert(
options.bufferizeFunctionBoundaries &&
607 "expected that function boundary bufferization is activated");
609 "invalid combination of bufferization flags");
610 if (!
options.copyBeforeWrite) {
611 if (
options.noAnalysisFuncFilter.empty()) {
617 OpFilter::Entry::FilterFn analysisFilterFn = [=](
Operation *op) {
618 auto func = dyn_cast<func::FuncOp>(op);
622 return llvm::is_contained(
options.noAnalysisFuncFilter,
627 updatedOptions.opFilter.denyOperation(analysisFilterFn);
DenseMap< func::FuncOp, DenseSet< Operation * > > FuncCallerMap
A mapping of FuncOps to their callers.
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static FuncAnalysisState & getOrCreateFuncAnalysisState(OneShotAnalysisState &state)
Get or create FuncAnalysisState.
static bool hasTensorSignature(func::FuncOp funcOp)
Return "true" if the given function signature has tensor semantics.
static void removeBufferizationAttributes(BlockArgument bbArg)
Remove bufferization attributes on FuncOp arguments.
static void foldMemRefCasts(func::FuncOp funcOp)
Fold return values that are memref casts and update function return types.
static Value unpackCast(Value v)
Helper function that extracts the source from a memref.cast.
static LogicalResult getFuncOpsOrderedByCalls(Operation *moduleOp, SmallVectorImpl< func::FuncOp > &orderedFuncOps, SmallVectorImpl< func::FuncOp > &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables)
Store all functions of the moduleOp in orderedFuncOps, sorted by callee-caller order (i.e., callees before their callers).
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
State for analysis-enabled bufferization.
bool isValueWritten(Value value) const
Return true if the buffer of the given tensor value is written to.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
const OneShotBufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
Ty * getExtension()
Returns the extension of the specified type.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 bufferize to equivalent buffers.
bool areAliasingBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 may bufferize to aliasing buffers.
Operation * getOwner() const
Return the owner of this operand.
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
llvm::LogicalResult bufferizeModuleOp(Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Bufferize an op's nested ops that implement BufferizableOpInterface.
void removeBufferizationAttributesInModule(Operation *moduleOp)
Remove bufferization attributes from the arguments of every FuncOp in the SymbolTable op.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
llvm::LogicalResult analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
llvm::LogicalResult runOneShotModuleBufferize(Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given SymbolTable.
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Bufferization statistics for debugging.
Options for analysis-enabled bufferization.
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
SymbolTableCollection symbolTables
A collection of cached SymbolTables used for faster function lookup.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.