94 static void annotateEquivalentReturnBbArg(
OpOperand &returnVal,
96 const char *kEquivalentArgsAttr =
"__equivalent_func_args__";
100 if (op->
hasAttr(kEquivalentArgsAttr)) {
101 auto attr = cast<ArrayAttr>(op->
getAttr(kEquivalentArgsAttr));
102 equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](
Attribute a) {
103 return cast<IntegerAttr>(a).getValue().getSExtValue();
111 op->
setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
119 if (funcOp.getBody().empty()) {
122 FunctionType type = funcOp.getFunctionType();
124 if (!isa<TensorType>(inputIt.value()))
127 if (!isa<TensorType>(resultIt.value()))
129 int64_t returnIdx = resultIt.index();
130 int64_t bbArgIdx = inputIt.index();
139 assert(!returnOps.empty() &&
"expected at least one ReturnOp");
143 if (isa<RankedTensorType>(bbArg.
getType())) {
147 for (func::ReturnOp returnOp : returnOps) {
148 for (
OpOperand &returnVal : returnOp->getOpOperands()) {
149 if (isa<RankedTensorType>(returnVal.
get().
getType())) {
151 if (state.areAliasingBufferizedValues(returnVal.
get(), bbArg))
152 aliases.insert(returnIdx);
156 for (int64_t alias : aliases)
165 auto findEquivalentBlockArgIdx =
166 [&](
OpOperand &opOperand) -> std::optional<int64_t> {
167 Value v = opOperand.get();
168 if (!isa<TensorType>(v.
getType()))
171 if (isa<RankedTensorType>(bbArg.
getType())) {
172 if (state.areEquivalentBufferizedValues(v, bbArg)) {
173 if (state.getOptions().testAnalysisOnly)
174 annotateEquivalentReturnBbArg(opOperand, bbArg);
182 int64_t numResults = returnOps.front()->getNumOperands();
183 for (int64_t i = 0; i < numResults; ++i) {
186 std::optional<int64_t> maybeEquiv =
187 findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
188 if (!maybeEquiv.has_value())
190 int64_t bbArgIdx = *maybeEquiv;
191 bool allEquiv =
true;
197 for (func::ReturnOp returnOp :
ArrayRef(returnOps).drop_front()) {
198 std::optional<int64_t> maybeEquiv =
199 findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
200 if (maybeEquiv != bbArgIdx) {
215 static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx,
bool isRead,
219 if (isRead && isWritten) {
220 accessType = b.getStringAttr(
"read-write");
222 accessType = b.getStringAttr(
"read");
223 }
else if (isWritten) {
224 accessType = b.getStringAttr(
"write");
226 accessType = b.getStringAttr(
"none");
228 funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
238 for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
241 if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
245 if (
auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
246 idx, BufferizationDialect::kBufferAccessAttrName)) {
248 StringRef str = accessAttr.getValue();
249 isRead = str ==
"read" || str ==
"read-write";
250 isWritten = str ==
"write" || str ==
"read-write";
251 }
else if (funcOp.getBody().empty()) {
259 isRead = state.isValueRead(bbArg);
260 isWritten = state.isValueWritten(bbArg);
263 if (state.getOptions().testAnalysisOnly)
264 annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
279 BufferizationDialect::kBufferLayoutAttrName);
281 BufferizationDialect::kWritableAttrName);
289 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
292 return dyn_cast_or_null<func::FuncOp>(
298 return llvm::any_of(funcOp.getFunctionType().getInputs(),
299 llvm::IsaPred<TensorType>) ||
300 llvm::any_of(funcOp.getFunctionType().getResults(),
301 llvm::IsaPred<TensorType>);
324 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
326 numberCallOpsContainedInFuncOp[funcOp] = 0;
329 assert(calledFunction &&
"could not retrieved called func::FuncOp");
335 callerMap[calledFunction].insert(callOp);
336 if (calledBy[calledFunction].insert(funcOp).second) {
337 numberCallOpsContainedInFuncOp[funcOp]++;
351 for (
const auto &entry : numberCallOpsContainedInFuncOp) {
352 if (entry.second == 0)
353 worklist.push_back(entry.first);
356 while (!worklist.empty()) {
357 func::FuncOp func = worklist.pop_back_val();
358 orderedFuncOps.push_back(func);
360 for (func::FuncOp caller : calledBy[func]) {
361 auto &count = numberCallOpsContainedInFuncOp[caller];
364 worklist.push_back(caller);
367 numberCallOpsContainedInFuncOp.erase(func);
372 for (
auto it : numberCallOpsContainedInFuncOp)
373 remainingFuncOps.push_back(it.first);
384 return castOp.getSource();
392 assert(!returnOps.empty() &&
"expected at least one ReturnOp");
393 int numOperands = returnOps.front()->getNumOperands();
399 for (
int i = 0; i < numOperands; ++i) {
401 Type t = getSourceType(returnOps.front()->getOperand(i));
405 if (getSourceType(returnOps[
j]->getOperand(i)) != t)
423 if (funcOp.getBody().empty())
431 for (func::ReturnOp returnOp : returnOps) {
432 for (
OpOperand &operand : returnOp->getOpOperands()) {
434 if (resultTypes[operand.getOperandNumber()]) {
442 for (
int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
445 resultTypes[i] = funcOp.getFunctionType().getResult(i);
450 funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
451 funcOp.setType(newFuncType);
458 assert(state.getOptions().bufferizeFunctionBoundaries &&
459 "expected that function boundary bufferization is activated");
474 remainingFuncOps, callerMap,
480 for (func::FuncOp funcOp : orderedFuncOps) {
481 if (!state.getOptions().isOpAllowed(funcOp))
488 if (failed(
analyzeOp(funcOp, state, statistics)))
492 if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
493 failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
501 for (func::FuncOp funcOp : remainingFuncOps) {
502 if (!state.getOptions().isOpAllowed(funcOp))
506 if (failed(
analyzeOp(funcOp, state, statistics)))
523 for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
534 assert(
options.bufferizeFunctionBoundaries &&
535 "expected that function boundary bufferization is activated");
555 remainingFuncOps, callerMap,
556 state.getSymbolTables())))
558 llvm::append_range(orderedFuncOps, remainingFuncOps);
561 for (func::FuncOp funcOp : orderedFuncOps) {
565 if (llvm::is_contained(
options.noAnalysisFuncFilter, funcOp.getSymName())) {
570 if (failed(
bufferizeOp(funcOp, updatedOptions, state, statistics)))
578 if (
options.inferFunctionResultLayout)
586 llvm::make_early_inc_range(block.getOperations())) {
605 assert(
options.bufferizeFunctionBoundaries &&
606 "expected that function boundary bufferization is activated");
608 "invalid combination of bufferization flags");
609 if (!
options.copyBeforeWrite) {
610 if (
options.noAnalysisFuncFilter.empty()) {
617 auto func = dyn_cast<func::FuncOp>(op);
621 return llvm::is_contained(
options.noAnalysisFuncFilter,
static bool hasTensorSignature(func::FuncOp funcOp)
Return "true" if the given function signature has tensor semantics.
static FuncAnalysisState & getOrCreateFuncAnalysisState(OneShotAnalysisState &state)
Get or create FuncAnalysisState.
static void removeBufferizationAttributes(BlockArgument bbArg)
Remove bufferization attributes on FuncOp arguments.
static void foldMemRefCasts(func::FuncOp funcOp)
Fold return values that are memref casts and update function return types.
static Value unpackCast(Value v)
Helper function that extracts the source from a memref.cast.
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static LogicalResult getFuncOpsOrderedByCalls(Operation *moduleOp, SmallVectorImpl< func::FuncOp > &orderedFuncOps, SmallVectorImpl< func::FuncOp > &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables)
Store all functions of the moduleOp in orderedFuncOps, sorted by callee-caller order (i....
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
BufferizationState provides information about the state of the IR during the bufferization process.
State for analysis-enabled bufferization.
void denyOperation()
Deny the given ops.
Operation * getOwner() const
Return the owner of this operand.
static FuncOp getCalledFunction(CallOpInterface callOp, SymbolTableCollection &symbolTables)
Return the FuncOp called by callOp.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
llvm::LogicalResult bufferizeModuleOp(Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Bufferize an ops nested ops that implement BufferizableOpInterface.
void removeBufferizationAttributesInModule(Operation *moduleOp)
Remove bufferization attributes on every FuncOp arguments in the SymbolTable op.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, const BufferizationState &bufferizationState, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
llvm::LogicalResult analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
llvm::LogicalResult runOneShotModuleBufferize(Operation *moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given SymbolTable.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool copyBeforeWrite
If set to true, the analysis is skipped.
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
Bufferization statistics for debugging.
Options for analysis-enabled bufferization.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
SymbolTableCollection symbolTables
A collection of cached SymbolTables used for faster function lookup.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.