18#define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS
19#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
33 if (failed(type.getStridesAndOffset(strides, offset)))
35 if (!llvm::all_of(strides, ShapedType::isDynamic))
37 if (ShapedType::isStatic(offset))
45 return type.getLayout().isIdentity();
57 for (
Value size : operands) {
58 if (!isa<IndexType>(size.getType()))
64 auto arguments = funcOp.getArguments();
65 auto iter = llvm::find(arguments, sizeSrc);
66 if (iter == arguments.end())
68 dynamicSizes.push_back(*iter);
79 for (
Value size : dynamicSizes) {
80 for (
auto [src, dst] :
81 llvm::zip_first(call.getOperands(), callee.getArguments())) {
84 mappedDynamicSizes.push_back(src);
87 assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
88 "could not find all dynamic sizes");
89 return mappedDynamicSizes;
100 bool addResultAttribute) {
101 auto functionType =
func.getFunctionType();
105 BitVector erasedResultIndices(functionType.getNumResults());
106 for (
const auto &resultType : llvm::enumerate(functionType.getResults())) {
107 if (
auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
113 return func->emitError()
114 <<
"cannot create out param for result with unsupported layout";
116 erasedResultIndices.set(resultType.index());
117 erasedResultTypes.push_back(memrefType);
122 auto newArgTypes = llvm::to_vector<6>(
123 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
124 auto newFunctionType = FunctionType::get(
func.getContext(), newArgTypes,
125 functionType.getResults());
126 func.setType(newFunctionType);
129 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
130 for (
int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
131 func.setArgAttrs(functionType.getNumInputs() + i,
132 func.getResultAttrs(*erasedIndicesIt));
133 if (addResultAttribute)
134 func.setArgAttr(functionType.getNumInputs() + i,
135 StringAttr::get(
func.getContext(),
"bufferize.result"),
136 UnitAttr::get(
func.getContext()));
140 if (failed(
func.eraseResults(erasedResultIndices)))
144 if (
func.isExternal())
147 for (
Type type : erasedResultTypes)
148 appendedEntryArgs.push_back(
func.front().addArgument(type, loc));
160 auto res =
func.walk([&](func::ReturnOp op) {
163 for (
Value operand : op.getOperands()) {
164 if (isa<MemRefType>(operand.getType()))
165 copyIntoOutParams.push_back(operand);
167 keepAsReturnOperands.push_back(operand);
171 for (
auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
172 bool hoistStaticAllocs =
174 cast<MemRefType>(orig.getType()).hasStaticShape();
175 bool hoistDynamicAllocs =
177 !cast<MemRefType>(orig.getType()).hasStaticShape();
178 if ((hoistStaticAllocs || hoistDynamicAllocs) &&
179 isa_and_nonnull<bufferization::AllocationOpInterface>(
180 orig.getDefiningOp())) {
181 orig.replaceAllUsesWith(arg);
182 if (hoistDynamicAllocs) {
184 dynamicSizes.push_back(dynamicSize);
186 orig.getDefiningOp()->erase();
188 if (failed(
options.memCpyFn(builder, op.getLoc(), orig, arg)))
192 func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
194 auto dynamicSizePair =
195 std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(
func,
197 map.insert(dynamicSizePair);
200 return failure(res.wasInterrupted());
208 bool didFail =
false;
210 module.walk([&](func::CallOp op) {
211 auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
213 op.emitError() <<
"cannot find callee '" << op.getCallee() <<
"' in "
218 if (!
options.filterFn(&callee))
220 if (callee.isPublic() && !
options.modifyPublicFunctions)
222 if (callee.isExternal())
228 if (isa<MemRefType>(
result.getType()))
229 replaceWithOutParams.push_back(
result);
231 replaceWithNewCallResults.push_back(
result);
236 size_t dynamicSizesIndex = 0;
239 ? dynamicSizes[dynamicSizesIndex]
241 bool memrefStaticShape =
242 cast<MemRefType>(
memref.getType()).hasStaticShape();
243 if (!memrefStaticShape && dynamicSize.empty()) {
245 <<
"cannot create out param for dynamically shaped result";
249 auto memrefType = cast<MemRefType>(
memref.getType());
251 MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
252 AffineMap(), memrefType.getMemorySpace());
254 if (memrefStaticShape) {
261 options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
262 if (failed(maybeOutParam)) {
263 op.emitError() <<
"failed to create allocation op";
267 Value outParam = maybeOutParam.value();
271 "layout map not supported");
273 memref::CastOp::create(builder, op.getLoc(), memrefType, outParam);
275 memref.replaceAllUsesWith(outParam);
276 outParams.push_back(outParam);
279 auto newOperands = llvm::to_vector<6>(op.getOperands());
280 newOperands.append(outParams.begin(), outParams.end());
281 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
283 auto newCall = func::CallOp::create(
284 builder, op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands);
285 for (
auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
286 std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
290 return failure(didFail);
299 for (
auto func : module.getOps<func::FuncOp>()) {
302 if (
func.isExternal())
320struct BufferResultsToOutParamsPass
322 BufferResultsToOutParamsPass> {
330 options.hoistStaticAllocs =
true;
334 options.modifyPublicFunctions =
true;
static SmallVector< Value > getDynamicSize(Value memref, func::FuncOp funcOp)
Return the dynamic shapes of the memref based on the defining op.
static LogicalResult updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options)
bufferization::BufferResultsToOutParamsOpts::AllocationFn AllocationFn
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static SmallVector< Value > mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee, ValueRange dynamicSizes)
Returns the caller-side values that correspond to the callee's dynamic sizes, found through the call relationship between the caller and callee.
llvm::DenseMap< func::FuncOp, SmallVector< SmallVector< Value > > > AllocDynamicSizesMap
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options)
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
This class represents an argument of a Block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
operand_range getOperands()
Returns an iterator on the underlying Value's.
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
void signalPassFailure()
Signal that some invariant was broken when running.
This class allows for representing and managing the symbol table used by operations with the 'SymbolTable' trait.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
::mlir::Pass::Option< bool > modifyPublicFunctions
::mlir::Pass::Option< bool > hoistDynamicAllocs
BufferResultsToOutParamsPassBase Base
::mlir::Pass::Option< bool > hoistStaticAllocs
::mlir::Pass::Option< bool > addResultAttribute
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)
Replace buffers that are returned from a function with an out parameter.
Include the generated interface declarations.
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two memrefs.
std::function< FailureOr< Value >(OpBuilder &, Location, MemRefType, ValueRange)> AllocationFn
Allocator function: Generate a memref allocation with the given type.