17 namespace bufferization {
18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
31 if (!llvm::all_of(strides, ShapedType::isDynamic))
33 if (!ShapedType::isDynamic(offset))
41 return type.getLayout().isIdentity();
50 auto functionType = func.getFunctionType();
54 BitVector erasedResultIndices(functionType.getNumResults());
55 for (
const auto &resultType :
llvm::enumerate(functionType.getResults())) {
56 if (
auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
62 return func->emitError()
63 <<
"cannot create out param for result with unsupported layout";
65 erasedResultIndices.set(resultType.index());
66 erasedResultTypes.push_back(memrefType);
71 auto newArgTypes = llvm::to_vector<6>(
72 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
74 functionType.getResults());
75 func.setType(newFunctionType);
78 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
79 for (
int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
80 func.setArgAttrs(functionType.getNumInputs() + i,
81 func.getResultAttrs(*erasedIndicesIt));
85 func.eraseResults(erasedResultIndices);
88 if (func.isExternal())
91 for (
Type type : erasedResultTypes)
92 appendedEntryArgs.push_back(func.front().addArgument(type, loc));
102 func.walk([&](func::ReturnOp op) {
106 if (isa<MemRefType>(operand.getType()))
107 copyIntoOutParams.push_back(operand);
109 keepAsReturnOperands.push_back(operand);
112 for (
auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
113 builder.
create<memref::CopyOp>(op.
getLoc(), std::get<0>(t),
115 builder.
create<func::ReturnOp>(op.
getLoc(), keepAsReturnOperands);
125 bool didFail =
false;
127 module.walk([&](func::CallOp op) {
128 auto callee = symtab.
lookup<func::FuncOp>(op.getCallee());
130 op.
emitError() <<
"cannot find callee '" << op.getCallee() <<
"' in "
135 if (!
options.filterFn(&callee))
140 if (isa<MemRefType>(result.getType()))
141 replaceWithOutParams.push_back(result);
143 replaceWithNewCallResults.push_back(result);
147 for (
Value memref : replaceWithOutParams) {
148 if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
150 <<
"cannot create out param for dynamically shaped result";
154 auto memrefType = cast<MemRefType>(memref.getType());
157 AffineMap(), memrefType.getMemorySpace());
162 "layout map not supported");
164 builder.
create<memref::CastOp>(op.
getLoc(), memrefType, outParam);
167 outParams.push_back(outParam);
170 auto newOperands = llvm::to_vector<6>(op.
getOperands());
171 newOperands.append(outParams.begin(), outParams.end());
172 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
173 replaceWithNewCallResults, [](
Value v) {
return v.
getType(); }));
174 auto newCall = builder.
create<func::CallOp>(op.
getLoc(), op.getCalleeAttr(),
175 newResultTypes, newOperands);
176 for (
auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
187 for (
auto func : module.getOps<func::FuncOp>()) {
193 if (func.isExternal())
203 struct BufferResultsToOutParamsPass
204 : bufferization::impl::BufferResultsToOutParamsBase<
205 BufferResultsToOutParamsPass> {
206 explicit BufferResultsToOutParamsPass(
210 void runOnOperation()
override {
213 return signalPassFailure();
223 return std::make_unique<BufferResultsToOutParamsPass>(
options);
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs)
static LogicalResult updateCalls(ModuleOp module, const bufferization::BufferResultsToOutParamsOptions &options)
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static void updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs)
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
Location getLoc()
The source location the operation was defined or derived from.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_range getOperands()
Returns an iterator on the underlying Value's.
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
result_range getResults()
void erase()
Remove this operation from its parent block and delete it.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOptions &options)
Replace buffers that are returned from a function with an out parameter.
std::unique_ptr< Pass > createBufferResultsToOutParamsPass(const BufferResultsToOutParamsOptions &options={})
Creates a pass that converts memref function results to out-params.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.