92 static void annotateEquivalentReturnBbArg(
OpOperand &returnVal,
94 const char *kEquivalentArgsAttr =
"__equivalent_func_args__";
98 if (op->
hasAttr(kEquivalentArgsAttr)) {
99 auto attr = cast<ArrayAttr>(op->
getAttr(kEquivalentArgsAttr));
100 equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](
Attribute a) {
101 return cast<IntegerAttr>(a).getValue().getSExtValue();
109 op->
setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
117 if (funcOp.getBody().empty()) {
120 FunctionType type = funcOp.getFunctionType();
122 if (!isa<TensorType>(inputIt.value()))
125 if (!isa<TensorType>(resultIt.value()))
127 int64_t returnIdx = resultIt.index();
128 int64_t bbArgIdx = inputIt.index();
137 assert(!returnOps.empty() &&
"expected at least one ReturnOp");
141 if (isa<RankedTensorType>(bbArg.
getType())) {
145 for (func::ReturnOp returnOp : returnOps) {
146 for (
OpOperand &returnVal : returnOp->getOpOperands()) {
147 if (isa<RankedTensorType>(returnVal.
get().
getType())) {
149 if (state.areAliasingBufferizedValues(returnVal.
get(), bbArg))
150 aliases.insert(returnIdx);
154 for (int64_t alias : aliases)
163 auto findEquivalentBlockArgIdx =
164 [&](
OpOperand &opOperand) -> std::optional<int64_t> {
165 Value v = opOperand.get();
166 if (!isa<TensorType>(v.
getType()))
169 if (isa<RankedTensorType>(bbArg.
getType())) {
170 if (state.areEquivalentBufferizedValues(v, bbArg)) {
171 if (state.getOptions().testAnalysisOnly)
172 annotateEquivalentReturnBbArg(opOperand, bbArg);
180 int64_t numResults = returnOps.front()->getNumOperands();
181 for (int64_t i = 0; i < numResults; ++i) {
184 std::optional<int64_t> maybeEquiv =
185 findEquivalentBlockArgIdx(returnOps.front()->getOpOperand(i));
186 if (!maybeEquiv.has_value())
188 int64_t bbArgIdx = *maybeEquiv;
189 bool allEquiv =
true;
195 for (func::ReturnOp returnOp :
ArrayRef(returnOps).drop_front()) {
196 std::optional<int64_t> maybeEquiv =
197 findEquivalentBlockArgIdx(returnOp->getOpOperand(i));
198 if (maybeEquiv != bbArgIdx) {
213 static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx,
bool isRead,
217 if (isRead && isWritten) {
218 accessType = b.getStringAttr(
"read-write");
220 accessType = b.getStringAttr(
"read");
221 }
else if (isWritten) {
222 accessType = b.getStringAttr(
"write");
224 accessType = b.getStringAttr(
"none");
226 funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
236 for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
239 if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
243 if (
auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
244 idx, BufferizationDialect::kBufferAccessAttrName)) {
246 StringRef str = accessAttr.getValue();
247 isRead = str ==
"read" || str ==
"read-write";
248 isWritten = str ==
"write" || str ==
"read-write";
249 }
else if (funcOp.getBody().empty()) {
257 isRead = state.isValueRead(bbArg);
258 isWritten = state.isValueWritten(bbArg);
261 if (state.getOptions().testAnalysisOnly)
262 annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
277 BufferizationDialect::kBufferLayoutAttrName);
279 BufferizationDialect::kWritableAttrName);
285 llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
288 return dyn_cast_or_null<func::FuncOp>(
299 funcOp->walk([&](func::CallOp callOp) {
301 assert(calledFunction &&
"could not retrieved called func::FuncOp");
308 int64_t returnIdx = it.first;
309 int64_t bbargIdx = it.second;
310 if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
312 Value returnVal = callOp.getResult(returnIdx);
313 Value argVal = callOp->getOperand(bbargIdx);
314 state.unionEquivalenceClasses(returnVal, argVal);
323 return llvm::any_of(funcOp.getFunctionType().getInputs(),
324 llvm::IsaPred<TensorType>) ||
325 llvm::any_of(funcOp.getFunctionType().getResults(),
326 llvm::IsaPred<TensorType>);
348 numberCallOpsContainedInFuncOp[funcOp] = 0;
349 return funcOp.walk([&](func::CallOp callOp) ->
WalkResult {
351 assert(calledFunction &&
"could not retrieved called func::FuncOp");
357 callerMap[calledFunction].insert(callOp);
358 if (calledBy[calledFunction].insert(funcOp).second) {
359 numberCallOpsContainedInFuncOp[funcOp]++;
369 while (!numberCallOpsContainedInFuncOp.empty()) {
370 auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
371 [](
auto entry) {
return entry.getSecond() == 0; });
372 if (it == numberCallOpsContainedInFuncOp.end())
374 orderedFuncOps.push_back(it->getFirst());
375 for (
auto callee : calledBy[it->getFirst()])
376 numberCallOpsContainedInFuncOp[callee]--;
377 numberCallOpsContainedInFuncOp.erase(it);
382 for (
auto it : numberCallOpsContainedInFuncOp)
383 remainingFuncOps.push_back(it.first);
394 return castOp.getSource();
402 assert(!returnOps.empty() &&
"expected at least one ReturnOp");
403 int numOperands = returnOps.front()->getNumOperands();
409 for (
int i = 0; i < numOperands; ++i) {
411 Type t = getSourceType(returnOps.front()->getOperand(i));
415 if (getSourceType(returnOps[
j]->getOperand(i)) != t)
433 if (funcOp.getBody().empty())
441 for (func::ReturnOp returnOp : returnOps) {
442 for (
OpOperand &operand : returnOp->getOpOperands()) {
444 if (resultTypes[operand.getOperandNumber()]) {
452 for (
int i = 0; i < static_cast<int>(resultTypes.size()); ++i) {
455 resultTypes[i] = funcOp.getFunctionType().getResult(i);
460 funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
461 funcOp.setType(newFuncType);
468 assert(state.getOptions().bufferizeFunctionBoundaries &&
469 "expected that function boundary bufferization is activated");
484 remainingFuncOps, callerMap)))
489 for (func::FuncOp funcOp : orderedFuncOps) {
490 if (!state.getOptions().isOpAllowed(funcOp))
500 if (failed(
analyzeOp(funcOp, state, statistics)))
504 if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
505 failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
513 for (func::FuncOp funcOp : remainingFuncOps) {
514 if (!state.getOptions().isOpAllowed(funcOp))
521 if (failed(
analyzeOp(funcOp, state, statistics)))
536 moduleOp.walk([&](func::FuncOp op) {
545 assert(
options.bufferizeFunctionBoundaries &&
546 "expected that function boundary bufferization is activated");
566 remainingFuncOps, callerMap)))
568 llvm::append_range(orderedFuncOps, remainingFuncOps);
571 for (func::FuncOp funcOp : orderedFuncOps) {
575 if (llvm::is_contained(
options.noAnalysisFuncFilter, funcOp.getSymName())) {
580 if (failed(
bufferizeOp(funcOp, updatedOptions, statistics)))
588 if (
options.inferFunctionResultLayout)
593 for (
Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
595 if (isa<func::FuncOp>(&op))
610 assert(
options.bufferizeFunctionBoundaries &&
611 "expected that function boundary bufferization is activated");
613 "invalid combination of bufferization flags");
614 if (!
options.copyBeforeWrite) {
615 if (
options.noAnalysisFuncFilter.empty()) {
622 auto func = dyn_cast<func::FuncOp>(op);
626 return llvm::is_contained(
options.noAnalysisFuncFilter,
static bool hasTensorSignature(func::FuncOp funcOp)
Return "true" if the given function signature has tensor semantics.
static FuncAnalysisState & getOrCreateFuncAnalysisState(OneShotAnalysisState &state)
Get or create FuncAnalysisState.
static void removeBufferizationAttributes(BlockArgument bbArg)
Remove bufferization attributes on FuncOp arguments.
static LogicalResult getFuncOpsOrderedByCalls(ModuleOp moduleOp, SmallVectorImpl< func::FuncOp > &orderedFuncOps, SmallVectorImpl< func::FuncOp > &remainingFuncOps, FuncCallerMap &callerMap)
Store all functions of the moduleOp in orderedFuncOps, sorted by callee-caller order (i....
static void foldMemRefCasts(func::FuncOp funcOp)
Fold return values that are memref casts and update function return types.
static Value unpackCast(Value v)
Helper function that extracts the source from a memref.cast.
static SmallVector< Type > getReturnTypes(SmallVector< func::ReturnOp > returnOps)
Helper function that returns the return types (skipping casts) of the given func.return ops.
static void equivalenceAnalysis(func::FuncOp funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState)
Gather equivalence info of CallOps.
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
State for analysis-enabled bufferization.
void denyOperation()
Deny the given ops.
Operation * getOwner() const
Return the owner of this operand.
static FuncOp getCalledFunction(CallOpInterface callOp)
Return the FuncOp called by callOp.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
SmallVector< func::ReturnOp > getReturnOps(func::FuncOp funcOp)
Helper function that returns all func.return ops in the given function.
llvm::LogicalResult runOneShotModuleBufferize(ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given module.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
llvm::LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
void removeBufferizationAttributesInModule(ModuleOp moduleOp)
Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
llvm::LogicalResult bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool copyBeforeWrite
If set to true, the analysis is skipped.
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
Bufferization statistics for debugging.
Options for analysis-enabled bufferization.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.