// Everything that follows lives in the bufferization namespace (the closing
// brace is not visible in this view).
18 namespace bufferization {
// Request the tablegen-generated definition of the pass base class
// (bufferization::impl::BufferResultsToOutParamsPassBase, which the pass
// struct further down derives from) from Passes.h.inc.
19 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS
20 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
// hasFullyDynamicLayoutMap(MemRefType type) -- fragment; the signature and the
// bodies of the checks below are not visible in this view.
// Reports whether `type` has a fully dynamic layout: a strided layout whose
// strides and offset are all dynamic. Each check below guards an early exit
// that is not visible here (presumably `return false;` -- confirm against the
// full file).
//
// The layout must be expressible as strides + offset at all.
32 if (failed(type.getStridesAndOffset(strides, offset)))
// Every stride must be dynamic ...
34 if (!llvm::all_of(strides, ShapedType::isDynamic))
// ... and so must the offset.
36 if (!ShapedType::isDynamic(offset))
// hasStaticIdentityLayout(MemRefType type) -- fragment.
// True iff the memref uses the identity layout, i.e. carries no layout map.
44 return type.getLayout().isIdentity();
// updateFuncOp(func::FuncOp func, SmallVectorImpl<BlockArgument>
// &appendedEntryArgs, bool addResultAttribute) -- fragment; the start of the
// signature and several body lines are not visible in this view.
// Rewrites `func` so that every memref result becomes an extra trailing
// argument, copying each result's attributes onto its new argument. For
// functions with a body, the newly created entry-block arguments are appended
// to `appendedEntryArgs` in result order.
55 bool addResultAttribute) {
56 auto functionType = func.getFunctionType();
// Record which results are memrefs (as a bitmask over result indices) and
// their types, in result order; these are the results turned into out params.
60 BitVector erasedResultIndices(functionType.getNumResults());
61 for (
const auto &resultType :
llvm::enumerate(functionType.getResults())) {
62 if (
auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
// Memref results whose layout cannot be turned into an out param are
// rejected. The guarding condition is not visible here -- presumably it
// accepts only static-identity or fully dynamic layouts (see the two helpers
// above); confirm against the full file.
68 return func->emitError()
69 <<
"cannot create out param for result with unsupported layout";
71 erasedResultIndices.set(resultType.index());
72 erasedResultTypes.push_back(memrefType);
// New signature: original inputs followed by one argument per memref result.
// The result list is deliberately left unchanged for now so that the result
// attributes can still be read below, before the results are erased.
77 auto newArgTypes = llvm::to_vector<6>(
78 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
80 functionType.getResults());
81 func.setType(newFunctionType);
// Move each erased result's attributes onto its new argument. set_bits_begin()
// visits the erased result indices in increasing order, which matches the order
// in which erasedResultTypes was filled above.
84 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
85 for (
int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
86 func.setArgAttrs(functionType.getNumInputs() + i,
87 func.getResultAttrs(*erasedIndicesIt));
88 if (addResultAttribute)
89 func.setArgAttr(functionType.getNumInputs() + i,
// Only now, after their attributes were copied, drop the memref results.
95 func.eraseResults(erasedResultIndices);
// External functions (declarations) have no entry block to extend; the body of
// this branch is not visible here (presumably an early return).
98 if (func.isExternal())
// Materialize the new arguments in the entry block, in signature order, and
// report them to the caller through appendedEntryArgs.
101 for (
Type type : erasedResultTypes)
102 appendedEntryArgs.push_back(func.front().addArgument(type, loc));
// updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
// const BufferResultsToOutParamsOpts &options) -- fragment; several lines are
// not visible in this view.
// For every func.return in `func`, each memref operand is stored into the
// matching out-param block argument, and the return is rebuilt with only the
// remaining operands. Returns failure if options.memCpyFn fails for any operand.
113 auto res = func.walk([&](func::ReturnOp op) {
// Split the operands: memrefs are written to out params; the remaining
// operands are kept as return values (the `else` line is not visible here).
116 for (
Value operand : op.getOperands()) {
117 if (isa<MemRefType>(operand.getType()))
118 copyIntoOutParams.push_back(operand);
120 keepAsReturnOperands.push_back(operand);
// llvm::zip pairs values positionally and stops at the shorter range. This
// relies on copyIntoOutParams following the same (result) order as the entry
// arguments appended for the out params.
123 for (
auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
// Hoisting: when the returned buffer comes directly from a statically shaped
// allocation op, redirect all its uses to the out param and drop the allocation
// instead of copying.
// NOTE(review): if the same allocation value is returned more than once, the
// first iteration erases its defining op and a later iteration calls
// orig.getDefiningOp() on the erased op -- confirm duplicates cannot reach this
// point, or guard against them.
124 if (
options.hoistStaticAllocs &&
125 isa_and_nonnull<bufferization::AllocationOpInterface>(
126 orig.getDefiningOp()) &&
127 mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
128 orig.replaceAllUsesWith(arg);
129 orig.getDefiningOp()->erase();
// Otherwise copy into the out param with the user-supplied memcpy callback; a
// failure aborts the walk.
131 if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
132 return WalkResult::interrupt();
// Rebuild the terminator so it returns only the kept operands (removal of the
// old return is not visible here).
135 builder.
create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
// An interrupted walk means at least one copy failed.
139 return failure(res.wasInterrupted());
// updateCalls(ModuleOp module, const BufferResultsToOutParamsOpts &options) --
// fragment; many lines are not visible in this view.
// Rewrites func.call ops in `module` so that memref results are passed in as
// freshly allocated out-param operands instead of being returned. Errors are
// recorded in `didFail` and reported through the return value.
147 bool didFail =
false;
149 module.walk([&](func::CallOp op) {
// Resolve the callee through the symbol table; a missing callee is reported as
// an error.
150 auto callee = symtab.
lookup<func::FuncOp>(op.getCallee());
152 op.emitError() <<
"cannot find callee '" << op.getCallee() <<
"' in "
// Calls whose callee is rejected by options.filterFn are presumably left
// untouched (the body of this branch is not visible here).
157 if (!
options.filterFn(&callee))
// Partition the results: memrefs become out params, the rest stay call results.
161 for (
OpResult result : op.getResults()) {
162 if (isa<MemRefType>(result.getType()))
163 replaceWithOutParams.push_back(result);
165 replaceWithNewCallResults.push_back(result);
// Allocate one out-param buffer per memref result.
169 for (
Value memref : replaceWithOutParams) {
// The caller can only allocate statically shaped buffers up front.
170 if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
172 <<
"cannot create out param for dynamically shaped result";
// The buffer type uses an identity layout (`AffineMap()`) in the same memory
// space as the result; allocation goes through options.allocationFn.
176 auto memrefType = cast<MemRefType>(memref.getType());
179 AffineMap(), memrefType.getMemorySpace());
181 options.allocationFn(builder, op.getLoc(), allocType);
182 if (failed(maybeOutParam)) {
183 op.emitError() <<
"failed to create allocation op";
187 Value outParam = maybeOutParam.value();
// When the result type's layout differs from the allocated identity layout, the
// buffer is cast to the result type. The guarding condition and the assertion
// this message belongs to are only partly visible here.
191 "layout map not supported");
193 builder.
create<memref::CastOp>(op.getLoc(), memrefType, outParam);
196 outParams.push_back(outParam);
// The new call forwards the original operands followed by the out params and
// produces only the non-memref results.
199 auto newOperands = llvm::to_vector<6>(op.getOperands());
200 newOperands.append(outParams.begin(), outParams.end());
201 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
202 replaceWithNewCallResults, [](
Value v) {
return v.
getType(); }));
203 auto newCall = builder.
create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
204 newResultTypes, newOperands);
// Pair each kept old result with the corresponding new call result (the loop
// body is not visible here).
205 for (
auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
// Fail if any call in the module could not be rewritten.
210 return failure(didFail);
// promoteBufferResultsToOutParams(ModuleOp, const BufferResultsToOutParamsOpts &)
// -- fragment; most of the body is not visible in this view.
// Visits every func.func in the module. External functions are handled
// specially by the check below (its body is not visible here).
216 for (
auto func : module.getOps<func::FuncOp>()) {
223 if (func.isExternal())
// Pass wrapper for the buffer-results-to-out-params rewrite, built on the
// tablegen-generated base class. Fragment; several members are not visible in
// this view.
235 struct BufferResultsToOutParamsPass
236 : bufferization::impl::BufferResultsToOutParamsPassBase<
237 BufferResultsToOutParamsPass> {
// Copies the enabled pass options into the rewrite options, then runs the
// rewrite (the call itself is not visible here) and signals pass failure if it
// fails.
240 void runOnOperation()
override {
242 if (addResultAttribute)
243 options.addResultAttribute =
true;
244 if (hoistStaticAllocs)
245 options.hoistStaticAllocs =
true;
249 return signalPassFailure();
static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, const bufferization::BufferResultsToOutParamsOpts &options)
bufferization::BufferResultsToOutParamsOpts::AllocationFn AllocationFn
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static LogicalResult updateCalls(ModuleOp module, const bufferization::BufferResultsToOutParamsOpts &options)
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)
Replace buffers that are returned from a function with an out parameter.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two memrefs.
std::function< FailureOr< Value >(OpBuilder &, Location, MemRefType)> AllocationFn
Allocator function: Generate a memref allocation with the given type.