17 namespace bufferization {
18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
32 if (!llvm::all_of(strides, ShapedType::isDynamic))
34 if (!ShapedType::isDynamic(offset))
42 return type.getLayout().isIdentity();
53 bool addResultAttribute) {
54 auto functionType = func.getFunctionType();
58 BitVector erasedResultIndices(functionType.getNumResults());
59 for (
const auto &resultType :
llvm::enumerate(functionType.getResults())) {
60 if (
auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
66 return func->emitError()
67 <<
"cannot create out param for result with unsupported layout";
69 erasedResultIndices.set(resultType.index());
70 erasedResultTypes.push_back(memrefType);
75 auto newArgTypes = llvm::to_vector<6>(
76 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
78 functionType.getResults());
79 func.setType(newFunctionType);
82 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
83 for (
int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
84 func.setArgAttrs(functionType.getNumInputs() + i,
85 func.getResultAttrs(*erasedIndicesIt));
86 if (addResultAttribute)
87 func.setArgAttr(functionType.getNumInputs() + i,
93 func.eraseResults(erasedResultIndices);
96 if (func.isExternal())
99 for (
Type type : erasedResultTypes)
100 appendedEntryArgs.push_back(func.front().addArgument(type, loc));
111 auto res = func.walk([&](func::ReturnOp op) {
115 if (isa<MemRefType>(operand.getType()))
116 copyIntoOutParams.push_back(operand);
118 keepAsReturnOperands.push_back(operand);
121 for (
auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
123 memCpyFn(builder, op.
getLoc(), std::get<0>(t), std::get<1>(t))))
126 builder.create<func::ReturnOp>(op.
getLoc(), keepAsReturnOperands);
130 return failure(res.wasInterrupted());
138 bool didFail =
false;
140 module.walk([&](func::CallOp op) {
141 auto callee = symtab.
lookup<func::FuncOp>(op.getCallee());
143 op.
emitError() <<
"cannot find callee '" << op.getCallee() <<
"' in "
148 if (!
options.filterFn(&callee))
153 if (isa<MemRefType>(result.getType()))
154 replaceWithOutParams.push_back(result);
156 replaceWithNewCallResults.push_back(result);
160 for (
Value memref : replaceWithOutParams) {
161 if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
163 <<
"cannot create out param for dynamically shaped result";
167 auto memrefType = cast<MemRefType>(memref.getType());
170 AffineMap(), memrefType.getMemorySpace());
175 "layout map not supported");
177 builder.
create<memref::CastOp>(op.
getLoc(), memrefType, outParam);
180 outParams.push_back(outParam);
183 auto newOperands = llvm::to_vector<6>(op.
getOperands());
184 newOperands.append(outParams.begin(), outParams.end());
185 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
186 replaceWithNewCallResults, [](
Value v) {
return v.
getType(); }));
187 auto newCall = builder.
create<func::CallOp>(op.
getLoc(), op.getCalleeAttr(),
188 newResultTypes, newOperands);
189 for (
auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
200 for (
auto func : module.getOps<func::FuncOp>()) {
207 if (func.isExternal())
211 builder.
create<memref::CopyOp>(loc, from, to);
215 options.memCpyFn.value_or(defaultMemCpyFn)))) {
225 struct BufferResultsToOutParamsPass
226 : bufferization::impl::BufferResultsToOutParamsBase<
227 BufferResultsToOutParamsPass> {
228 explicit BufferResultsToOutParamsPass(
232 void runOnOperation()
override {
234 if (addResultAttribute)
235 options.addResultAttribute =
true;
239 return signalPassFailure();
249 return std::make_unique<BufferResultsToOutParamsPass>(
options);
static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, MemCpyFn memCpyFn)
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static LogicalResult updateCalls(ModuleOp module, const bufferization::BufferResultsToOutParamsOpts &options)
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)
Replace buffers that are returned from a function with an out parameter.
std::unique_ptr< Pass > createBufferResultsToOutParamsPass(const BufferResultsToOutParamsOpts &options={})
Creates a pass that converts memref function results to out-params.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two memrefs.