17 namespace bufferization {
18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
31 if (!llvm::all_of(strides, ShapedType::isDynamic))
33 if (!ShapedType::isDynamic(offset))
41 return type.getLayout().isIdentity();
50 auto functionType = func.getFunctionType();
54 BitVector erasedResultIndices(functionType.getNumResults());
55 for (
const auto &resultType :
llvm::enumerate(functionType.getResults())) {
56 if (
auto memrefType = resultType.value().dyn_cast<MemRefType>()) {
62 return func->emitError()
63 <<
"cannot create out param for result with unsupported layout";
65 erasedResultIndices.set(resultType.index());
66 erasedResultTypes.push_back(memrefType);
71 auto newArgTypes = llvm::to_vector<6>(
72 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
73 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
74 functionType.getResults());
75 func.setType(newFunctionType);
78 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
79 for (
int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
80 func.setArgAttrs(functionType.getNumInputs() + i,
81 func.getResultAttrs(*erasedIndicesIt));
85 func.eraseResults(erasedResultIndices);
88 if (func.isExternal())
91 for (
Type type : erasedResultTypes)
92 appendedEntryArgs.push_back(func.front().addArgument(type, loc));
102 func.walk([&](func::ReturnOp op) {
105 for (
Value operand : op.getOperands()) {
106 if (operand.getType().isa<MemRefType>())
107 copyIntoOutParams.push_back(operand);
109 keepAsReturnOperands.push_back(operand);
112 for (
auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
113 builder.
create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
115 builder.
create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
125 bool didFail =
false;
127 module.walk([&](func::CallOp op) {
128 auto callee = symtab.
lookup<func::FuncOp>(op.getCallee());
130 op.emitError() <<
"cannot find callee '" << op.getCallee() <<
"' in "
135 if (!
options.filterFn(&callee))
139 for (
OpResult result : op.getResults()) {
140 if (result.getType().isa<MemRefType>())
141 replaceWithOutParams.push_back(result);
143 replaceWithNewCallResults.push_back(result);
147 for (
Value memref : replaceWithOutParams) {
148 if (!memref.getType().cast<MemRefType>().hasStaticShape()) {
150 <<
"cannot create out param for dynamically shaped result";
154 auto memrefType = memref.getType().cast<MemRefType>();
156 MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
157 AffineMap(), memrefType.getMemorySpace());
158 Value outParam = builder.
create<memref::AllocOp>(op.getLoc(), allocType);
162 "layout map not supported");
164 builder.
create<memref::CastOp>(op.getLoc(), memrefType, outParam);
167 outParams.push_back(outParam);
170 auto newOperands = llvm::to_vector<6>(op.getOperands());
171 newOperands.append(outParams.begin(), outParams.end());
172 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
173 replaceWithNewCallResults, [](
Value v) {
return v.
getType(); }));
174 auto newCall = builder.
create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
175 newResultTypes, newOperands);
176 for (
auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
187 for (
auto func : module.getOps<func::FuncOp>()) {
193 if (func.isExternal())
203 struct BufferResultsToOutParamsPass
204 : bufferization::impl::BufferResultsToOutParamsBase<
205 BufferResultsToOutParamsPass> {
206 explicit BufferResultsToOutParamsPass(
210 void runOnOperation()
override {
213 return signalPassFailure();
223 return std::make_unique<BufferResultsToOutParamsPass>(
options);
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs)
static LogicalResult updateCalls(ModuleOp module, const bufferization::BufferResultsToOutParamsOptions &options)
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static void updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs)
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
void erase()
Remove this operation from its parent block and delete it.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOptions &options)
Replace buffers that are returned from a function with an out parameter.
std::unique_ptr< Pass > createBufferResultsToOutParamsPass(const BufferResultsToOutParamsOptions &options={})
Creates a pass that converts memref function results to out-params.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.