17 namespace bufferization {
18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
33 if (
failed(type.getStridesAndOffset(strides, offset)))
35 if (!llvm::all_of(strides, ShapedType::isDynamic))
37 if (ShapedType::isStatic(offset))
45 return type.getLayout().isIdentity();
57 for (
Value size : operands) {
58 if (!isa<IndexType>(size.getType()))
64 auto arguments = funcOp.getArguments();
65 auto iter = llvm::find(arguments, sizeSrc);
66 if (iter == arguments.end())
68 dynamicSizes.push_back(*iter);
79 for (
Value size : dynamicSizes) {
80 for (
auto [src, dst] :
81 llvm::zip_first(call.getOperands(), callee.getArguments())) {
84 mappedDynamicSizes.push_back(src);
87 assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
88 "could not find all dynamic sizes");
89 return mappedDynamicSizes;
100 bool addResultAttribute) {
101 auto functionType = func.getFunctionType();
105 BitVector erasedResultIndices(functionType.getNumResults());
106 for (
const auto &resultType :
llvm::enumerate(functionType.getResults())) {
107 if (
auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
113 return func->emitError()
114 <<
"cannot create out param for result with unsupported layout";
116 erasedResultIndices.set(resultType.index());
117 erasedResultTypes.push_back(memrefType);
122 auto newArgTypes = llvm::to_vector<6>(
123 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
125 functionType.getResults());
126 func.setType(newFunctionType);
129 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
130 for (
int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
131 func.setArgAttrs(functionType.getNumInputs() + i,
132 func.getResultAttrs(*erasedIndicesIt));
133 if (addResultAttribute)
134 func.setArgAttr(functionType.getNumInputs() + i,
140 if (
failed(func.eraseResults(erasedResultIndices)))
144 if (func.isExternal())
147 for (
Type type : erasedResultTypes)
148 appendedEntryArgs.push_back(func.front().addArgument(type, loc));
160 auto res = func.walk([&](func::ReturnOp op) {
163 for (
Value operand : op.getOperands()) {
164 if (isa<MemRefType>(operand.getType()))
165 copyIntoOutParams.push_back(operand);
167 keepAsReturnOperands.push_back(operand);
171 for (
auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
172 bool hoistStaticAllocs =
174 cast<MemRefType>(orig.getType()).hasStaticShape();
175 bool hoistDynamicAllocs =
177 !cast<MemRefType>(orig.getType()).hasStaticShape();
178 if ((hoistStaticAllocs || hoistDynamicAllocs) &&
179 isa_and_nonnull<bufferization::AllocationOpInterface>(
180 orig.getDefiningOp())) {
181 orig.replaceAllUsesWith(arg);
182 if (hoistDynamicAllocs) {
184 dynamicSizes.push_back(dynamicSize);
186 orig.getDefiningOp()->erase();
188 if (
failed(
options.memCpyFn(builder, op.getLoc(), orig, arg)))
192 func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
194 auto dynamicSizePair =
195 std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
197 map.insert(dynamicSizePair);
200 return failure(res.wasInterrupted());
208 bool didFail =
false;
210 module.walk([&](func::CallOp op) {
211 auto callee = symtab.
lookup<func::FuncOp>(op.getCallee());
213 op.emitError() <<
"cannot find callee '" << op.getCallee() <<
"' in "
218 if (!
options.filterFn(&callee))
220 if (callee.isExternal() || callee.isPublic())
225 for (
OpResult result : op.getResults()) {
226 if (isa<MemRefType>(result.getType()))
227 replaceWithOutParams.push_back(result);
229 replaceWithNewCallResults.push_back(result);
234 size_t dynamicSizesIndex = 0;
235 for (
Value memref : replaceWithOutParams) {
237 ? dynamicSizes[dynamicSizesIndex]
239 bool memrefStaticShape =
240 cast<MemRefType>(memref.getType()).hasStaticShape();
241 if (!memrefStaticShape && dynamicSize.empty()) {
243 <<
"cannot create out param for dynamically shaped result";
247 auto memrefType = cast<MemRefType>(memref.getType());
250 AffineMap(), memrefType.getMemorySpace());
252 if (memrefStaticShape) {
259 options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
260 if (
failed(maybeOutParam)) {
261 op.emitError() <<
"failed to create allocation op";
265 Value outParam = maybeOutParam.value();
269 "layout map not supported");
271 memref::CastOp::create(builder, op.getLoc(), memrefType, outParam);
273 memref.replaceAllUsesWith(outParam);
274 outParams.push_back(outParam);
277 auto newOperands = llvm::to_vector<6>(op.getOperands());
278 newOperands.append(outParams.begin(), outParams.end());
279 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
280 replaceWithNewCallResults, [](
Value v) {
return v.
getType(); }));
281 auto newCall = func::CallOp::create(
282 builder, op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands);
283 for (
auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
284 std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
288 return failure(didFail);
297 for (
auto func : module.getOps<func::FuncOp>()) {
298 if (func.isExternal() || func.isPublic())
316 struct BufferResultsToOutParamsPass
317 : bufferization::impl::BufferResultsToOutParamsPassBase<
318 BufferResultsToOutParamsPass> {
321 void runOnOperation()
override {
323 if (addResultAttribute)
324 options.addResultAttribute =
true;
325 if (hoistStaticAllocs)
326 options.hoistStaticAllocs =
true;
327 if (hoistDynamicAllocs)
328 options.hoistDynamicAllocs =
true;
332 return signalPassFailure();
static LogicalResult updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options)
bufferization::BufferResultsToOutParamsOpts::AllocationFn AllocationFn
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static SmallVector< Value > mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee, ValueRange dynamicSizes)
Returns the dynamic sizes at the callee, through the call relationship between the caller and callee.
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static SmallVector< Value > getDynamicSize(Value memref, func::FuncOp funcOp)
Return the dynamic shapes of the memref based on the defining op.
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options)
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
This class represents an argument of a Block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
operand_range getOperands()
Returns an iterator on the underlying Value's.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)
Replace buffers that are returned from a function with an out parameter.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two memrefs.
std::function< FailureOr< Value >(OpBuilder &, Location, MemRefType, ValueRange)> AllocationFn
Allocator function: Generate a memref allocation with the given type.