18#define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS
19#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
33 if (failed(type.getStridesAndOffset(strides, offset)))
35 if (!llvm::all_of(strides, ShapedType::isDynamic))
37 if (ShapedType::isStatic(offset))
45 return type.getLayout().isIdentity();
57 for (
Value size : operands) {
58 if (!isa<IndexType>(size.getType()))
64 auto arguments = funcOp.getArguments();
65 auto iter = llvm::find(arguments, sizeSrc);
66 if (iter == arguments.end())
68 dynamicSizes.push_back(*iter);
79 for (
Value size : dynamicSizes) {
80 for (
auto [src, dst] :
81 llvm::zip_first(call.getOperands(), callee.getArguments())) {
84 mappedDynamicSizes.push_back(src);
87 assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
88 "could not find all dynamic sizes");
89 return mappedDynamicSizes;
100 bool addResultAttribute) {
101 auto functionType =
func.getFunctionType();
105 BitVector erasedResultIndices(functionType.getNumResults());
106 for (
const auto &resultType : llvm::enumerate(functionType.getResults())) {
107 if (
auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
113 return func->emitError()
114 <<
"cannot create out param for result with unsupported layout";
116 erasedResultIndices.set(resultType.index());
117 erasedResultTypes.push_back(memrefType);
122 auto newArgTypes = llvm::to_vector<6>(
123 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
124 auto newFunctionType = FunctionType::get(
func.getContext(), newArgTypes,
125 functionType.getResults());
126 func.setType(newFunctionType);
129 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
130 for (
int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
131 func.setArgAttrs(functionType.getNumInputs() + i,
132 func.getResultAttrs(*erasedIndicesIt));
133 if (addResultAttribute)
134 func.setArgAttr(functionType.getNumInputs() + i,
135 StringAttr::get(
func.getContext(),
"bufferize.result"),
136 UnitAttr::get(
func.getContext()));
140 if (failed(
func.eraseResults(erasedResultIndices)))
144 if (
func.isExternal())
147 for (
Type type : erasedResultTypes)
148 appendedEntryArgs.push_back(
func.front().addArgument(type, loc));
160 auto res =
func.walk([&](func::ReturnOp op) {
163 for (
Value operand : op.getOperands()) {
164 if (isa<MemRefType>(operand.getType()))
165 copyIntoOutParams.push_back(operand);
167 keepAsReturnOperands.push_back(operand);
171 for (
auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
172 bool hoistStaticAllocs =
174 cast<MemRefType>(orig.getType()).hasStaticShape();
175 bool hoistDynamicAllocs =
177 !cast<MemRefType>(orig.getType()).hasStaticShape();
178 if ((hoistStaticAllocs || hoistDynamicAllocs) &&
179 isa_and_nonnull<bufferization::AllocationOpInterface>(
180 orig.getDefiningOp())) {
181 orig.replaceAllUsesWith(arg);
182 if (hoistDynamicAllocs) {
184 dynamicSizes.push_back(dynamicSize);
186 orig.getDefiningOp()->erase();
188 if (failed(
options.memCpyFn(builder, op.getLoc(), orig, arg)))
192 func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
194 auto dynamicSizePair =
195 std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(
func,
197 map.insert(dynamicSizePair);
200 return failure(res.wasInterrupted());
208 bool didFail =
false;
210 module.walk([&](func::CallOp op) {
211 auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
213 op.emitError() <<
"cannot find callee '" << op.getCallee() <<
"' in "
218 if (!
options.filterFn(&callee))
220 if (callee.isExternal() || callee.isPublic())
226 if (isa<MemRefType>(
result.getType()))
227 replaceWithOutParams.push_back(
result);
229 replaceWithNewCallResults.push_back(
result);
234 size_t dynamicSizesIndex = 0;
237 ? dynamicSizes[dynamicSizesIndex]
239 bool memrefStaticShape =
240 cast<MemRefType>(
memref.getType()).hasStaticShape();
241 if (!memrefStaticShape && dynamicSize.empty()) {
243 <<
"cannot create out param for dynamically shaped result";
247 auto memrefType = cast<MemRefType>(
memref.getType());
249 MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
250 AffineMap(), memrefType.getMemorySpace());
252 if (memrefStaticShape) {
259 options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
260 if (failed(maybeOutParam)) {
261 op.emitError() <<
"failed to create allocation op";
265 Value outParam = maybeOutParam.value();
269 "layout map not supported");
271 memref::CastOp::create(builder, op.getLoc(), memrefType, outParam);
274 outParams.push_back(outParam);
277 auto newOperands = llvm::to_vector<6>(op.getOperands());
278 newOperands.append(outParams.begin(), outParams.end());
279 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
281 auto newCall = func::CallOp::create(
282 builder, op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands);
283 for (
auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
284 std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
288 return failure(didFail);
297 for (
auto func : module.getOps<func::FuncOp>()) {
298 if (
func.isExternal() ||
func.isPublic())
316struct BufferResultsToOutParamsPass
318 BufferResultsToOutParamsPass> {
328 options.hoistDynamicAllocs =
true;
static SmallVector< Value > getDynamicSize(Value memref, func::FuncOp funcOp)
Return the dynamic shapes of the memref based on the defining op.
static LogicalResult updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options)
bufferization::BufferResultsToOutParamsOpts::AllocationFn AllocationFn
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static SmallVector< Value > mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee, ValueRange dynamicSizes)
Returns the dynamic sizes at the callee, through the call relationship between the caller and callee.
llvm::DenseMap< func::FuncOp, SmallVector< SmallVector< Value > > > AllocDynamicSizesMap
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options)
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
This class represents an argument of a Block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
operand_range getOperands()
Returns an iterator on the underlying Values.
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
void signalPassFailure()
Signal that some invariant was broken when running.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
::mlir::Pass::Option< bool > hoistDynamicAllocs
BufferResultsToOutParamsPassBase Base
::mlir::Pass::Option< bool > hoistStaticAllocs
::mlir::Pass::Option< bool > addResultAttribute
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)
Replace buffers that are returned from a function with an out parameter.
Include the generated interface declarations.
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two memrefs.
std::function< FailureOr< Value >(OpBuilder &, Location, MemRefType, ValueRange)> AllocationFn
Allocator function: Generate a memref allocation with the given type.