static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp) {
  func::ReturnOp returnOp;
  for (Block &b : funcOp.getBody()) {
    if (auto candidateOp = dyn_cast<func::ReturnOp>(b.getTerminator())) {
      if (returnOp)
        return nullptr;
      returnOp = candidateOp;
    }
  }
  return returnOp;
}
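Callers of this helper must handle the nullptr case, e.g. a function whose body has more than one return-terminated block. A minimal usage sketch, assuming the helper above is in scope (the wrapper name checkSingleReturn is illustrative only):

// Sketch: reject functions that do not have a unique ReturnOp.
static LogicalResult checkSingleReturn(func::FuncOp funcOp) {
  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  if (!returnOp)
    return funcOp->emitError() << "expected a single ReturnOp";
  return success();
}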
// Annotate IR with the results of the analysis (for testing purposes only).
static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
                                          BlockArgument bbArg) {
  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
  Operation *op = returnVal.getOwner();

  SmallVector<int64_t> equivBbArgs;
  if (op->hasAttr(kEquivalentArgsAttr)) {
    auto attr = cast<ArrayAttr>(op->getAttr(kEquivalentArgsAttr));
    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
      return cast<IntegerAttr>(a).getValue().getSExtValue();
    }));
  } else {
    equivBbArgs.append(op->getNumOperands(), -1);
  }
  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();

  OpBuilder b(op->getContext());
  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}
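The attribute written here is an ArrayAttr of i64 entries, one per return operand, where a negative entry means that no equivalent bbArg was recorded for that operand. A sketch of decoding it again, assuming the attribute layout shown above (the helper name readEquivalentArgs is made up for illustration):

// Hypothetical helper: decode __equivalent_func_args__ into a map from
// return operand index to equivalent bbArg index.
static DenseMap<int64_t, int64_t> readEquivalentArgs(func::ReturnOp returnOp) {
  DenseMap<int64_t, int64_t> result;
  auto attr = returnOp->getAttrOfType<ArrayAttr>("__equivalent_func_args__");
  if (!attr)
    return result;
  for (const auto &it : llvm::enumerate(attr)) {
    int64_t bbArgIdx = cast<IntegerAttr>(it.value()).getValue().getSExtValue();
    if (bbArgIdx >= 0) // negative entries mean "no equivalent bbArg"
      result[it.index()] = bbArgIdx;
  }
  return result;
}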
// In aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState): record which
// tensor return values are equivalent to / aliasing which tensor bbArgs.
  if (funcOp.getBody().empty()) {
    // No body available: conservatively assume that every tensor return value
    // may alias every tensor bbArg.
    FunctionType type = funcOp.getFunctionType();
    for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
      if (!isa<TensorType>(inputIt.value()))
        continue;
      for (const auto &resultIt : llvm::enumerate(type.getResults())) {
        if (!isa<TensorType>(resultIt.value()))
          continue;
        int64_t returnIdx = resultIt.index();
        int64_t bbArgIdx = inputIt.index();
        funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
      }
    }
    return success();
  }

  // Only functions terminated by a single ReturnOp are supported.
  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  assert(returnOp && "expected func with single return op");

  for (OpOperand &returnVal : returnOp->getOpOperands())
    if (isa<RankedTensorType>(returnVal.get().getType()))
      for (BlockArgument bbArg : funcOp.getArguments())
        if (isa<RankedTensorType>(bbArg.getType())) {
          int64_t returnIdx = returnVal.getOperandNumber();
          int64_t bbArgIdx = bbArg.getArgNumber();
          if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
            funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx;
            if (state.getOptions().testAnalysisOnly)
              annotateEquivalentReturnBbArg(returnVal, bbArg);
          }
          if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
            funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx);
        }

  return success();
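As an illustration of the state recorded here: for a function that simply returns one of its tensor arguments, return operand #0 ends up equivalent to that argument's bbArg index. A sketch of querying the recorded mapping afterwards, assuming funcState was populated by this analysis and that IndexMapping behaves like a DenseMap<int64_t, int64_t> (as the field descriptions further below suggest; standard headers omitted):

// Sketch: return the bbArg index (if any) that return operand `returnIdx`
// of `funcOp` was found to be equivalent to.
static std::optional<int64_t>
getEquivalentBbArg(const FuncAnalysisState &funcState, func::FuncOp funcOp,
                   int64_t returnIdx) {
  auto funcIt = funcState.equivalentFuncArgs.find(funcOp);
  if (funcIt == funcState.equivalentFuncArgs.end())
    return std::nullopt;
  auto argIt = funcIt->second.find(returnIdx);
  if (argIt == funcIt->second.end())
    return std::nullopt;
  return argIt->second;
}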
// Annotate a function argument with the inferred access kind (testing only).
static void annotateFuncArgAccess(func::FuncOp funcOp, int64_t idx, bool isRead,
                                  bool isWritten) {
  OpBuilder b(funcOp.getContext());
  Attribute accessType;
  if (isRead && isWritten) {
    accessType = b.getStringAttr("read-write");
  } else if (isRead) {
    accessType = b.getStringAttr("read");
  } else if (isWritten) {
    accessType = b.getStringAttr("write");
  } else {
    accessType = b.getStringAttr("none");
  }
  funcOp.setArgAttr(idx, BufferizationDialect::kBufferAccessAttrName,
                    accessType);
}
// In funcOpBbArgReadWriteAnalysis(funcOp, state, funcState): determine which
// tensor bbArgs are read and/or written.
  for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
       ++idx) {
    // Skip non-tensor arguments.
    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
      continue;
    bool isRead;
    bool isWritten;
    if (auto accessAttr = funcOp.getArgAttrOfType<StringAttr>(
            idx, BufferizationDialect::kBufferAccessAttrName)) {
      // The access behavior is annotated on the argument: skip the analysis.
      StringRef str = accessAttr.getValue();
      isRead = str == "read" || str == "read-write";
      isWritten = str == "write" || str == "read-write";
    } else if (funcOp.getBody().empty()) {
      // The function has no body: conservatively assume read and written.
      isRead = true;
      isWritten = true;
    } else {
      // Analyze the function body.
      BlockArgument bbArg = funcOp.getArgument(idx);
      isRead = state.isValueRead(bbArg);
      isWritten = state.isValueWritten(bbArg);
    }

    if (state.getOptions().testAnalysisOnly)
      annotateFuncArgAccess(funcOp, idx, isRead, isWritten);
    if (isRead)
      funcState.readBbArgs[funcOp].insert(idx);
    if (isWritten)
      funcState.writtenBbArgs[funcOp].insert(idx);
  }
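Because the attribute check runs before the body-based analysis, a client can pre-annotate arguments (for instance of a body-less external function) and avoid the conservative read-plus-write assumption. A minimal sketch, assuming the function op is at hand (the wrapper name markArgReadOnly is illustrative only):

// Sketch: mark argument 0 as read-only so funcOpBbArgReadWriteAnalysis
// uses the annotation instead of the conservative default.
static void markArgReadOnly(func::FuncOp externalFunc) {
  OpBuilder b(externalFunc.getContext());
  externalFunc.setArgAttr(/*argIndex=*/0,
                          BufferizationDialect::kBufferAccessAttrName,
                          b.getStringAttr("read"));
}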
// In removeBufferizationAttributes(bbArg): drop the bufferization-specific
// argument attributes from the owning FuncOp.
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kBufferLayoutAttrName);
  funcOp.removeArgAttr(bbArg.getArgNumber(),
                       BufferizationDialect::kWritableAttrName);
// In getCalledFunction(callOp): resolve the callee symbol to a FuncOp.
  SymbolRefAttr sym =
      llvm::dyn_cast_if_present<SymbolRefAttr>(callOp.getCallableForCallee());
  if (!sym)
    return nullptr;
  return dyn_cast_or_null<func::FuncOp>(
      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
// In equivalenceAnalysis(funcOp, state, funcState): propagate equivalence
// info from already-analyzed callees to the call sites in funcOp.
  funcOp->walk([&](func::CallOp callOp) {
    func::FuncOp calledFunction = getCalledFunction(callOp);
    assert(calledFunction && "could not retrieved called func::FuncOp");
    // ...
    for (auto it : funcState.equivalentFuncArgs[calledFunction]) {
      int64_t returnIdx = it.first;
      int64_t bbargIdx = it.second;
      if (!state.isInPlace(callOp->getOpOperand(bbargIdx)))
        continue;
      Value returnVal = callOp.getResult(returnIdx);
      Value argVal = callOp->getOperand(bbargIdx);
      state.unionEquivalenceClasses(returnVal, argVal);
    }
    return WalkResult::advance();
  });
static bool hasTensorSignature(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      llvm::IsaPred<TensorType>) ||
         llvm::any_of(funcOp.getFunctionType().getResults(),
                      llvm::IsaPred<TensorType>);
}
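llvm::IsaPred<TensorType> is a predicate object that applies isa<TensorType> to each element, so the input half of the check could equivalently be written with an explicit lambda (the name hasTensorInput is illustrative only):

// Equivalent spelling of the first llvm::any_of above, without IsaPred.
static bool hasTensorInput(func::FuncOp funcOp) {
  return llvm::any_of(funcOp.getFunctionType().getInputs(),
                      [](Type t) { return isa<TensorType>(t); });
}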
static LogicalResult
getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                         SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                         FuncCallerMap &callerMap) {
  // Maps a callee to the set of functions that call it.
  DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
  // Maps a function to the number of distinct callees not yet ordered.
  DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
  WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
    if (!funcOp.getBody().empty()) {
      func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
      if (!returnOp)
        return funcOp->emitError()
               << "cannot bufferize a FuncOp with tensors and "
                  "without a unique ReturnOp";
    }

    // Collect function calls and populate the caller map.
    numberCallOpsContainedInFuncOp[funcOp] = 0;
    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
      func::FuncOp calledFunction = getCalledFunction(callOp);
      assert(calledFunction && "could not retrieved called func::FuncOp");
      // Callees without tensors in their signature do not have to be
      // bufferized before their callers.
      if (!hasTensorSignature(calledFunction))
        return WalkResult::skip();

      callerMap[calledFunction].insert(callOp);
      if (calledBy[calledFunction].insert(funcOp).second) {
        numberCallOpsContainedInFuncOp[funcOp]++;
      }
      return WalkResult::advance();
    });
  });
  if (res.wasInterrupted())
    return failure();

  // Repeatedly pop a function that does not call any function still in the
  // map (callees first) and decrement the counters of its callers.
  while (!numberCallOpsContainedInFuncOp.empty()) {
    auto it = llvm::find_if(numberCallOpsContainedInFuncOp,
                            [](auto entry) { return entry.getSecond() == 0; });
    if (it == numberCallOpsContainedInFuncOp.end())
      return moduleOp.emitOpError(
          "expected callgraph to be free of circular dependencies.");
    orderedFuncOps.push_back(it->getFirst());
    for (auto callee : calledBy[it->getFirst()])
      numberCallOpsContainedInFuncOp[callee]--;
    numberCallOpsContainedInFuncOp.erase(it);
  }
  return success();
}
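The while-loop is a counting-based topological sort of the call graph (Kahn's algorithm): repeatedly emit a function whose counter of not-yet-ordered callees is zero, then decrement the counters of its callers. The same idea in a self-contained, dialect-independent sketch (plain standard containers, illustrative names; every caller that appears in the sets must itself be a key of numCallsIn):

#include <algorithm>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

// `callers[f]` is the set of functions that call `f`; `numCallsIn[f]` counts
// the callees of `f` that have not been ordered yet. Returns the functions in
// callee-before-caller order, or std::nullopt on a cyclic call graph.
std::optional<std::vector<std::string>>
orderCalleesFirst(std::map<std::string, std::set<std::string>> callers,
                  std::map<std::string, unsigned> numCallsIn) {
  std::vector<std::string> ordered;
  while (!numCallsIn.empty()) {
    auto it = std::find_if(numCallsIn.begin(), numCallsIn.end(),
                           [](const auto &entry) { return entry.second == 0; });
    if (it == numCallsIn.end())
      return std::nullopt; // circular dependency
    ordered.push_back(it->first);
    for (const std::string &caller : callers[it->first])
      --numCallsIn[caller]; // one fewer unresolved callee for this caller
    numCallsIn.erase(it);
  }
  return ordered;
}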
static void foldMemRefCasts(func::FuncOp funcOp) {
  if (funcOp.getBody().empty())
    return;

  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
  SmallVector<Type> resultTypes;

  for (OpOperand &operand : returnOp->getOpOperands()) {
    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
      operand.set(castOp.getSource());
      resultTypes.push_back(castOp.getSource().getType());
    } else {
      resultTypes.push_back(operand.get().getType());
    }
  }

  auto newFuncType = FunctionType::get(
      funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes);
  funcOp.setType(newFuncType);
}
// In analyzeModuleOp(moduleOp, state, statistics):
  assert(state.getOptions().bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  // ...
  // Analyze functions in callee-before-caller order.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (!state.getOptions().isOpAllowed(funcOp))
      continue;
    // ...
    if (failed(analyzeOp(funcOp, state, statistics)))
      return failure();
    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
      return failure();
    // ...
  }
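A hedged sketch of driving only this analysis from client code, mirroring what runOneShotModuleBufferize does when testAnalysisOnly is set; the OneShotAnalysisState constructor arguments are assumptions based on the declarations referenced further below, and mlir namespace qualifiers are omitted:

// Sketch: analyze a module and keep the results as IR annotations.
LogicalResult analyzeOnly(ModuleOp module) {
  bufferization::OneShotBufferizationOptions options;
  options.bufferizeFunctionBoundaries = true; // required by the assert above
  options.testAnalysisOnly = true;            // annotate instead of rewriting
  bufferization::OneShotAnalysisState state(module, options);
  return bufferization::analyzeModuleOp(module, state);
}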
// In removeBufferizationAttributesInModule(moduleOp), each FuncOp is visited
// and removeBufferizationAttributes() is applied to its arguments:
  moduleOp.walk([&](func::FuncOp op) { /* ... */ });
// In bufferizeModuleOp(moduleOp, options, statistics):
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  // ...
  // Bufferize functions in callee-before-caller order.
  for (func::FuncOp funcOp : orderedFuncOps) {
    if (llvm::is_contained(options.noAnalysisFuncFilter, funcOp.getSymName())) {
      // This function was not analyzed; insert a copy before every write.
      OneShotBufferizationOptions updatedOptions = options;
      updatedOptions.copyBeforeWrite = true;
      if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
        return failure();
    } else {
      if (failed(bufferizeOp(funcOp, options, statistics)))
        return failure();
    }
    if (options.inferFunctionResultLayout)
      foldMemRefCasts(funcOp);
  }

  // Bufferize all remaining (non-function) ops of the module.
  for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
    if (isa<func::FuncOp>(&op))
      continue;
    // ...
  }
// In runOneShotModuleBufferize(moduleOp, options, statistics):
  assert(options.bufferizeFunctionBoundaries &&
         "expected that function boundary bufferization is activated");
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");
  if (!options.copyBeforeWrite) {
    if (options.noAnalysisFuncFilter.empty()) {
      if (failed(insertTensorCopies(moduleOp, options, statistics)))
        return failure();
    } else {
      // Functions in noAnalysisFuncFilter (and ops nested in them) are
      // excluded from the analysis.
      OpFilter::Entry::FilterFn analysisFilterFn = [=](Operation *op) {
        auto func = dyn_cast<func::FuncOp>(op);
        if (!func)
          func = op->getParentOfType<func::FuncOp>();
        if (func)
          return llvm::is_contained(options.noAnalysisFuncFilter,
                                    func.getSymName());
        return false;
      };
      // ...
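For completeness, a hedged sketch of how client code might invoke this entry point; the header path is an assumption, and only options that appear in the listing above are set:

#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" // assumed path
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Sketch: run One-Shot Module Bufferization with function boundary
// bufferization enabled (required by the asserts shown above).
LogicalResult bufferizeWholeModule(ModuleOp module) {
  bufferization::OneShotBufferizationOptions options;
  options.bufferizeFunctionBoundaries = true;
  return bufferization::runOneShotModuleBufferize(module, options);
}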
static bool hasTensorSignature(func::FuncOp funcOp)
Return "true" if the given function signature has tensor semantics.
static LogicalResult getFuncOpsOrderedByCalls(ModuleOp moduleOp, SmallVectorImpl< func::FuncOp > &orderedFuncOps, FuncCallerMap &callerMap)
Store all functions of the moduleOp in orderedFuncOps, sorted by callee-caller order (i.e. callees without callers first).
static FuncAnalysisState & getOrCreateFuncAnalysisState(OneShotAnalysisState &state)
Get or create FuncAnalysisState.
static void removeBufferizationAttributes(BlockArgument bbArg)
Remove bufferization attributes on FuncOp arguments.
static void foldMemRefCasts(func::FuncOp funcOp)
Fold return values that are memref casts and update function return types.
static void equivalenceAnalysis(func::FuncOp funcOp, OneShotAnalysisState &state, FuncAnalysisState &funcState)
Gather equivalence info of CallOps.
static func::ReturnOp getAssumedUniqueReturnOp(func::FuncOp funcOp)
Return the unique ReturnOp that terminates funcOp.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
unsigned getArgNumber() const
Returns the number of this argument.
Block represents an ordered list of Operations.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
IRValueT get() const
Return the current value being used by this operand.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class helps build Operations.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
State for analysis-enabled bufferization.
void denyOperation()
Deny the given ops.
Operation * getOwner() const
Return the owner of this operand.
static FuncOp getCalledFunction(CallOpInterface callOp)
Return the FuncOp called by callOp.
LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze op and its nested ops.
llvm::LogicalResult runOneShotModuleBufferize(ModuleOp moduleOp, const bufferization::OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Run One-Shot Module Bufferization on the given module.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
llvm::LogicalResult analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics=nullptr)
Analyze moduleOp and its nested ops.
void removeBufferizationAttributesInModule(ModuleOp moduleOp)
Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
llvm::LogicalResult bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, BufferizationStatistics *statistics=nullptr)
Bufferize op and its nested ops that implement BufferizableOpInterface.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool copyBeforeWrite
If set to true, the analysis is skipped.
OpFilter opFilter
A filter that specifies which ops should be bufferized and which ops should be ignored.
Bufferization statistics for debugging.
Options for analysis-enabled bufferization.
std::function< bool(Operation *)> FilterFn
If the filter function evaluates to true, the filter matches.
Extra analysis state that is required for bufferization of function boundaries.
DenseMap< FuncOp, IndexMapping > equivalentFuncArgs
A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg indices.
DenseMap< FuncOp, IndexToIndexListMapping > aliasingReturnVals
A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
DenseMap< FuncOp, BbArgIndexSet > readBbArgs
A set of all read BlockArguments of FuncOps.
DenseMap< FuncOp, BbArgIndexSet > writtenBbArgs
A set of all written-to BlockArguments of FuncOps.
DenseMap< FuncOp, FuncOpAnalysisState > analyzedFuncOps
Keep track of which FuncOps are fully analyzed or currently being analyzed.
void startFunctionAnalysis(FuncOp funcOp)
This function is called right before analyzing the given FuncOp.