74#include "llvm/ADT/MapVector.h"
75#include "llvm/ADT/SmallVectorExtras.h"
96static void annotateEquivalentReturnBbArg(
OpOperand &returnVal,
98 const char *kEquivalentArgsAttr =
"__equivalent_func_args__";
102 if (op->
hasAttr(kEquivalentArgsAttr)) {
103 auto attr = cast<ArrayAttr>(op->
getAttr(kEquivalentArgsAttr));
104 equivBbArgs = llvm::map_to_vector<4>(attr, [](
Attribute a) {
105 return cast<IntegerAttr>(a).getValue().getSExtValue();
113 op->
setAttr(kEquivalentArgsAttr,
b.getI64ArrayAttr(equivBbArgs));
121 if (funcOp.getBody().empty()) {
124 FunctionType type = funcOp.getFunctionType();
125 for (
const auto &inputIt : llvm::enumerate(type.getInputs())) {
126 if (!isa<TensorLikeType>(inputIt.value()))
128 for (
const auto &resultIt : llvm::enumerate(type.getResults())) {
129 if (!isa<TensorLikeType>(resultIt.value()))
131 int64_t returnIdx = resultIt.index();
132 int64_t bbArgIdx = inputIt.index();
143 if (returnOps.empty()) {
144 return funcOp.emitError(
"cannot bufferize func.func without func.return");
149 if (isa<TensorLikeType>(bbArg.
getType())) {
153 for (func::ReturnOp returnOp : returnOps) {
154 for (
OpOperand &returnVal : returnOp->getOpOperands()) {
155 if (isa<TensorLikeType>(returnVal.
get().
getType())) {
158 aliases.insert(returnIdx);
171 auto findEquivalentBlockArgIdx =
172 [&](
OpOperand &opOperand) -> std::optional<int64_t> {
173 Value v = opOperand.get();
174 if (!isa<TensorLikeType>(v.
getType()))
177 if (isa<TensorLikeType>(bbArg.
getType())) {
180 annotateEquivalentReturnBbArg(opOperand, bbArg);
188 int64_t numResults = returnOps.front()->getNumOperands();
189 for (
int64_t i = 0; i < numResults; ++i) {
192 std::optional<int64_t> maybeEquiv =
193 findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
194 if (!maybeEquiv.has_value())
196 int64_t bbArgIdx = *maybeEquiv;
197 bool allEquiv =
true;
203 for (func::ReturnOp returnOp :
ArrayRef(returnOps).drop_front()) {
204 std::optional<int64_t> maybeEquiv =
205 findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
206 if (maybeEquiv != bbArgIdx) {
221static void annotateFuncArgAccess(func::FuncOp funcOp,
int64_t idx,
bool isRead,
225 if (isRead && isWritten) {
226 accessType =
b.getStringAttr(
"read-write");
228 accessType =
b.getStringAttr(
"read");
229 }
else if (isWritten) {
230 accessType =
b.getStringAttr(
"write");
232 accessType =
b.getStringAttr(
"none");
234 funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
244 for (
int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
247 if (!isa<TensorLikeType>(funcOp.getFunctionType().getInput(idx)))
251 if (
auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
252 idx, BufferizationDialect::kBufferAccessAttrName)) {
254 StringRef str = accessAttr.getValue();
255 isRead = str ==
"read" || str ==
"read-write";
256 isWritten = str ==
"write" || str ==
"read-write";
257 }
else if (funcOp.getBody().empty()) {
265 isRead = state.isValueRead(bbArg);
270 annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
285 BufferizationDialect::kBufferLayoutAttrName);
287 BufferizationDialect::kWritableAttrName);
294 return dyn_cast_or_null<func::FuncOp>(
295 callOp.resolveCallableInTable(&symbolTable));
300 return llvm::any_of(funcOp.getFunctionType().getInputs(),
301 llvm::IsaPred<TensorLikeType>) ||
302 llvm::any_of(funcOp.getFunctionType().getResults(),
303 llvm::IsaPred<TensorLikeType>);
323 llvm::MapVector<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
326 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
328 numberCallOpsContainedInFuncOp[funcOp] = 0;
331 assert(calledFunction &&
"could not retrieved called func::FuncOp");
337 callerMap[calledFunction].insert(callOp);
338 if (calledBy[calledFunction].insert(funcOp)) {
339 numberCallOpsContainedInFuncOp[funcOp]++;
353 for (
const auto &entry : numberCallOpsContainedInFuncOp) {
354 if (entry.second == 0)
355 worklist.push_back(entry.first);
358 while (!worklist.empty()) {
359 func::FuncOp
func = worklist.pop_back_val();
360 orderedFuncOps.push_back(
func);
362 for (func::FuncOp caller : calledBy[
func]) {
363 auto &count = numberCallOpsContainedInFuncOp[caller];
366 worklist.push_back(caller);
369 numberCallOpsContainedInFuncOp.erase(
func);
374 for (
auto it : numberCallOpsContainedInFuncOp)
375 remainingFuncOps.push_back(it.first);
391 if (isa<UnrankedMemRefType>(castOp.getSource().getType()) &&
394 return castOp.getSource();
402 assert(!returnOps.empty() &&
"expected at least one ReturnOp");
403 int numOperands = returnOps.front()->getNumOperands();
409 for (
int i = 0; i < numOperands; ++i) {
411 Type t = getSourceType(returnOps.front()->getOperand(i));
414 for (
int j = 1; j < static_cast<int>(returnOps.size()); ++
j)
415 if (getSourceType(returnOps[
j]->getOperand(i)) != t)
433 if (funcOp.getBody().empty())
441 for (func::ReturnOp returnOp : returnOps) {
442 for (
OpOperand &operand : returnOp->getOpOperands()) {
444 if (resultTypes[operand.getOperandNumber()]) {
452 for (
int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
455 resultTypes[i] = funcOp.getFunctionType().getResult(i);
459 auto newFuncType = FunctionType::get(
460 funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
461 funcOp.setType(newFuncType);
468 assert(state.
getOptions().bufferizeFunctionBoundaries &&
469 "expected that function boundary bufferization is activated");
484 remainingFuncOps, callerMap,
490 for (func::FuncOp funcOp : orderedFuncOps) {
498 if (failed(
analyzeOp(funcOp, state, statistics)))
502 if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
503 failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
511 for (func::FuncOp funcOp : remainingFuncOps) {
516 if (failed(
analyzeOp(funcOp, state, statistics)))
533 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
544 assert(
options.bufferizeFunctionBoundaries &&
545 "expected that function boundary bufferization is activated");
565 remainingFuncOps, callerMap,
566 state.getSymbolTables())))
568 llvm::append_range(orderedFuncOps, remainingFuncOps);
571 for (func::FuncOp funcOp : orderedFuncOps) {
575 if (llvm::is_contained(
options.noAnalysisFuncFilter, funcOp.getSymName())) {
579 updatedOptions.copyBeforeWrite =
true;
580 if (failed(
bufferizeOp(funcOp, updatedOptions, state, statistics)))
588 if (
options.inferFunctionResultLayout)
596 llvm::make_early_inc_range(block.getOperations())) {
615 assert(
options.bufferizeFunctionBoundaries &&
616 "expected that function boundary bufferization is activated");
618 "invalid combination of bufferization flags");
619 if (!
options.copyBeforeWrite) {
620 if (
options.noAnalysisFuncFilter.empty()) {
626 OpFilter::Entry::FilterFn analysisFilterFn = [=](
Operation *op) {
627 auto func = dyn_cast<func::FuncOp>(op);
631 return llvm::is_contained(
options.noAnalysisFuncFilter,
636 updatedOptions.opFilter.denyOperation(analysisFilterFn);
DenseMap< func::FuncOp, DenseSet< Operation * > > FuncCallerMap
A mapping of FuncOps to their callers.
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static FuncAnalysisState & getOrCreateFuncAnalysisState(OneShotAnalysisState &state)
Get or create FuncAnalysisState.
static bool hasTensorSignature(func::FuncOp funcOp)
Return "true" if the given function signature has tensor semantics.
static void removeBufferizationAttributes(BlockArgument bbArg)
Remove bufferization attributes on FuncOp arguments.
static void foldMemRefCasts(func::FuncOp funcOp)
Fold return values that are memref casts and update function return types.
static Value unpackCast(Value v)
Helper function that extracts the source from a memref.cast.
static LogicalResult getFuncOpsOrderedByCalls(Operation *moduleOp, SmallVectorImpl< func::FuncOp > &orderedFuncOps, SmallVectorImpl< func::FuncOp > &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables)
Store all functions of the moduleOp in orderedFuncOps, sorted by callee-caller order (i....
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
unsigned getArgNumber() const
Returns the number of this argument.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
State for analysis-enabled bufferization.
bool isValueWritten(Value value) const
Return true if the buffer of the given tensor value is written to.
Ty & addExtension(Args &&...args)
Adds a new Extension of the type specified as template parameter, constructing it with the arguments ...
const OneShotBufferizationOptions & getOptions() const
Return a reference to the BufferizationOptions.
Ty * getExtension()
Returns the extension of the specified type.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 bufferize to equivalent buffers.
bool areAliasingBufferizedValues(Value v1, Value v2) const override
Return true if v1 and v2 may bufferize to aliasing buffers.
Operation * getOwner() const
Return the owner of this operand.
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
llvm::LogicalResult bufferizeModuleOp(Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Bufferize an ops nested ops that implement BufferizableOpInterface.
void removeBufferizationAttributesInModule(Operation *moduleOp)
Remove bufferization attributes on every FuncOp arguments in the SymbolTable op.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
llvm::LogicalResult analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
llvm::LogicalResult runOneShotModuleBufferize(Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given SymbolTable.
Include the generated interface declarations.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Bufferization statistics for debugging.
Options for analysis-enabled bufferization.
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
SymbolTableCollection symbolTables
A collection of cached SymbolTables used for faster function lookup.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.