17 namespace bufferization {
18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
32 if (!llvm::all_of(strides, ShapedType::isDynamic))
34 if (!ShapedType::isDynamic(offset))
42 return type.getLayout().isIdentity();
53 bool addResultAttribute) {
54 auto functionType = func.getFunctionType();
58 BitVector erasedResultIndices(functionType.getNumResults());
59 for (
const auto &resultType :
llvm::enumerate(functionType.getResults())) {
60 if (
auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
66 return func->emitError()
67 <<
"cannot create out param for result with unsupported layout";
69 erasedResultIndices.set(resultType.index());
70 erasedResultTypes.push_back(memrefType);
75 auto newArgTypes = llvm::to_vector<6>(
76 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
78 functionType.getResults());
79 func.setType(newFunctionType);
82 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
83 for (
int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
84 func.setArgAttrs(functionType.getNumInputs() + i,
85 func.getResultAttrs(*erasedIndicesIt));
86 if (addResultAttribute)
87 func.setArgAttr(functionType.getNumInputs() + i,
93 func.eraseResults(erasedResultIndices);
96 if (func.isExternal())
99 for (
Type type : erasedResultTypes)
100 appendedEntryArgs.push_back(func.front().addArgument(type, loc));
111 bool hoistStaticAllocs) {
112 auto res = func.walk([&](func::ReturnOp op) {
116 if (isa<MemRefType>(operand.getType()))
117 copyIntoOutParams.push_back(operand);
119 keepAsReturnOperands.push_back(operand);
122 for (
auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
123 if (hoistStaticAllocs && isa<memref::AllocOp>(orig.getDefiningOp()) &&
124 mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
125 orig.replaceAllUsesWith(arg);
126 orig.getDefiningOp()->erase();
128 if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
129 return WalkResult::interrupt();
132 builder.
create<func::ReturnOp>(op.
getLoc(), keepAsReturnOperands);
136 return failure(res.wasInterrupted());
144 bool didFail =
false;
146 module.walk([&](func::CallOp op) {
147 auto callee = symtab.
lookup<func::FuncOp>(op.getCallee());
149 op.
emitError() <<
"cannot find callee '" << op.getCallee() <<
"' in "
154 if (!
options.filterFn(&callee))
159 if (isa<MemRefType>(result.getType()))
160 replaceWithOutParams.push_back(result);
162 replaceWithNewCallResults.push_back(result);
166 for (
Value memref : replaceWithOutParams) {
167 if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
169 <<
"cannot create out param for dynamically shaped result";
173 auto memrefType = cast<MemRefType>(memref.getType());
176 AffineMap(), memrefType.getMemorySpace());
181 "layout map not supported");
183 builder.
create<memref::CastOp>(op.
getLoc(), memrefType, outParam);
186 outParams.push_back(outParam);
189 auto newOperands = llvm::to_vector<6>(op.
getOperands());
190 newOperands.append(outParams.begin(), outParams.end());
191 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
192 replaceWithNewCallResults, [](
Value v) {
return v.
getType(); }));
193 auto newCall = builder.
create<func::CallOp>(op.
getLoc(), op.getCalleeAttr(),
194 newResultTypes, newOperands);
195 for (
auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
200 return failure(didFail);
206 for (
auto func : module.getOps<func::FuncOp>()) {
213 if (func.isExternal())
217 builder.
create<memref::CopyOp>(loc, from, to);
221 options.memCpyFn.value_or(defaultMemCpyFn),
232 struct BufferResultsToOutParamsPass
233 : bufferization::impl::BufferResultsToOutParamsBase<
234 BufferResultsToOutParamsPass> {
235 explicit BufferResultsToOutParamsPass(
239 void runOnOperation()
override {
241 if (addResultAttribute)
242 options.addResultAttribute =
true;
243 if (hoistStaticAllocs)
244 options.hoistStaticAllocs =
true;
248 return signalPassFailure();
258 return std::make_unique<BufferResultsToOutParamsPass>(
options);
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static LogicalResult updateCalls(ModuleOp module, const bufferization::BufferResultsToOutParamsOpts &options)
static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, MemCpyFn memCpyFn, bool hoistStaticAllocs)
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)
Replace buffers that are returned from a function with an out parameter.
std::unique_ptr< Pass > createBufferResultsToOutParamsPass(const BufferResultsToOutParamsOpts &options={})
Creates a pass that converts memref function results to out-params.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two memrefs.