MLIR 22.0.0git
BufferResultsToOutParams.cpp
Go to the documentation of this file.
1//===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11
14#include "mlir/IR/Operation.h"
15
16namespace mlir {
17namespace bufferization {
18#define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS
19#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
20} // namespace bufferization
21} // namespace mlir
22
23using namespace mlir;
28
29/// Return `true` if the given MemRef type has a fully dynamic layout.
30static bool hasFullyDynamicLayoutMap(MemRefType type) {
31 int64_t offset;
33 if (failed(type.getStridesAndOffset(strides, offset)))
34 return false;
35 if (!llvm::all_of(strides, ShapedType::isDynamic))
36 return false;
37 if (ShapedType::isStatic(offset))
38 return false;
39 return true;
40}
41
42/// Return `true` if the given MemRef type has a static identity layout (i.e.,
43/// no layout).
44static bool hasStaticIdentityLayout(MemRefType type) {
45 return type.getLayout().isIdentity();
46}
47
48/// Return the dynamic shapes of the `memref` based on the defining op. If the
49/// complete dynamic shape fails to be captured, return an empty value.
50/// Currently, only function block arguments are supported for capturing.
51static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
52 Operation *defOp = memref.getDefiningOp();
53 if (!defOp)
54 return {};
55 auto operands = defOp->getOperands();
56 SmallVector<Value> dynamicSizes;
57 for (Value size : operands) {
58 if (!isa<IndexType>(size.getType()))
59 continue;
60
61 BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
62 if (!sizeSrc)
63 return {};
64 auto arguments = funcOp.getArguments();
65 auto iter = llvm::find(arguments, sizeSrc);
66 if (iter == arguments.end())
67 return {};
68 dynamicSizes.push_back(*iter);
69 }
70 return dynamicSizes;
71}
72
73/// Returns the dynamic sizes at the callee, through the call relationship
74/// between the caller and callee.
76 func::FuncOp callee,
77 ValueRange dynamicSizes) {
78 SmallVector<Value> mappedDynamicSizes;
79 for (Value size : dynamicSizes) {
80 for (auto [src, dst] :
81 llvm::zip_first(call.getOperands(), callee.getArguments())) {
82 if (size != dst)
83 continue;
84 mappedDynamicSizes.push_back(src);
85 }
86 }
87 assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
88 "could not find all dynamic sizes");
89 return mappedDynamicSizes;
90}
91
92// Updates the func op and entry block.
93//
94// Any args appended to the entry block are added to `appendedEntryArgs`.
95// If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
96// to each newly created function argument.
97static LogicalResult
98updateFuncOp(func::FuncOp func,
99 SmallVectorImpl<BlockArgument> &appendedEntryArgs,
100 bool addResultAttribute) {
101 auto functionType = func.getFunctionType();
102
103 // Collect information about the results will become appended arguments.
104 SmallVector<Type, 6> erasedResultTypes;
105 BitVector erasedResultIndices(functionType.getNumResults());
106 for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
107 if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
108 if (!hasStaticIdentityLayout(memrefType) &&
109 !hasFullyDynamicLayoutMap(memrefType)) {
110 // Only buffers with static identity layout can be allocated. These can
111 // be casted to memrefs with fully dynamic layout map. Other layout maps
112 // are not supported.
113 return func->emitError()
114 << "cannot create out param for result with unsupported layout";
115 }
116 erasedResultIndices.set(resultType.index());
117 erasedResultTypes.push_back(memrefType);
118 }
119 }
120
121 // Add the new arguments to the function type.
122 auto newArgTypes = llvm::to_vector<6>(
123 llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
124 auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
125 functionType.getResults());
126 func.setType(newFunctionType);
127
128 // Transfer the result attributes to arg attributes.
129 auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
130 for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
131 func.setArgAttrs(functionType.getNumInputs() + i,
132 func.getResultAttrs(*erasedIndicesIt));
133 if (addResultAttribute)
134 func.setArgAttr(functionType.getNumInputs() + i,
135 StringAttr::get(func.getContext(), "bufferize.result"),
136 UnitAttr::get(func.getContext()));
137 }
138
139 // Erase the results.
140 if (failed(func.eraseResults(erasedResultIndices)))
141 return failure();
142
143 // Add the new arguments to the entry block if the function is not external.
144 if (func.isExternal())
145 return success();
146 Location loc = func.getLoc();
147 for (Type type : erasedResultTypes)
148 appendedEntryArgs.push_back(func.front().addArgument(type, loc));
149
150 return success();
151}
152
153// Updates all ReturnOps in the scope of the given func::FuncOp by either
154// keeping them as return values or copying the associated buffer contents into
155// the given out-params.
156static LogicalResult
157updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
160 auto res = func.walk([&](func::ReturnOp op) {
161 SmallVector<Value, 6> copyIntoOutParams;
162 SmallVector<Value, 6> keepAsReturnOperands;
163 for (Value operand : op.getOperands()) {
164 if (isa<MemRefType>(operand.getType()))
165 copyIntoOutParams.push_back(operand);
166 else
167 keepAsReturnOperands.push_back(operand);
168 }
169 OpBuilder builder(op);
170 SmallVector<SmallVector<Value>> dynamicSizes;
171 for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
172 bool hoistStaticAllocs =
173 options.hoistStaticAllocs &&
174 cast<MemRefType>(orig.getType()).hasStaticShape();
175 bool hoistDynamicAllocs =
176 options.hoistDynamicAllocs &&
177 !cast<MemRefType>(orig.getType()).hasStaticShape();
178 if ((hoistStaticAllocs || hoistDynamicAllocs) &&
179 isa_and_nonnull<bufferization::AllocationOpInterface>(
180 orig.getDefiningOp())) {
181 orig.replaceAllUsesWith(arg);
182 if (hoistDynamicAllocs) {
183 SmallVector<Value> dynamicSize = getDynamicSize(orig, func);
184 dynamicSizes.push_back(dynamicSize);
185 }
186 orig.getDefiningOp()->erase();
187 } else {
188 if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
189 return WalkResult::interrupt();
190 }
191 }
192 func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
193 op.erase();
194 auto dynamicSizePair =
195 std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
196 dynamicSizes);
197 map.insert(dynamicSizePair);
198 return WalkResult::advance();
199 });
200 return failure(res.wasInterrupted());
201}
202
203// Updates all CallOps in the scope of the given ModuleOp by allocating
204// temporary buffers for newly introduced out params.
205static LogicalResult
206updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
208 bool didFail = false;
209 SymbolTable symtab(module);
210 module.walk([&](func::CallOp op) {
211 auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
212 if (!callee) {
213 op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
214 << "symbol table";
215 didFail = true;
216 return;
217 }
218 if (!options.filterFn(&callee))
219 return;
220 if (callee.isExternal() || callee.isPublic())
221 return;
222
223 SmallVector<Value, 6> replaceWithNewCallResults;
224 SmallVector<Value, 6> replaceWithOutParams;
225 for (OpResult result : op.getResults()) {
226 if (isa<MemRefType>(result.getType()))
227 replaceWithOutParams.push_back(result);
228 else
229 replaceWithNewCallResults.push_back(result);
230 }
231 SmallVector<Value, 6> outParams;
232 OpBuilder builder(op);
233 SmallVector<SmallVector<Value>> dynamicSizes = map.lookup(callee);
234 size_t dynamicSizesIndex = 0;
235 for (Value memref : replaceWithOutParams) {
236 SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex
237 ? dynamicSizes[dynamicSizesIndex]
239 bool memrefStaticShape =
240 cast<MemRefType>(memref.getType()).hasStaticShape();
241 if (!memrefStaticShape && dynamicSize.empty()) {
242 op.emitError()
243 << "cannot create out param for dynamically shaped result";
244 didFail = true;
245 return;
246 }
247 auto memrefType = cast<MemRefType>(memref.getType());
248 auto allocType =
249 MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
250 AffineMap(), memrefType.getMemorySpace());
251
252 if (memrefStaticShape) {
253 dynamicSize = {};
254 } else {
255 ++dynamicSizesIndex;
256 dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize);
257 }
258 auto maybeOutParam =
259 options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
260 if (failed(maybeOutParam)) {
261 op.emitError() << "failed to create allocation op";
262 didFail = true;
263 return;
264 }
265 Value outParam = maybeOutParam.value();
266 if (!hasStaticIdentityLayout(memrefType)) {
267 // Layout maps are already checked in `updateFuncOp`.
268 assert(hasFullyDynamicLayoutMap(memrefType) &&
269 "layout map not supported");
270 outParam =
271 memref::CastOp::create(builder, op.getLoc(), memrefType, outParam);
272 }
273 memref.replaceAllUsesWith(outParam);
274 outParams.push_back(outParam);
276
277 auto newOperands = llvm::to_vector<6>(op.getOperands());
278 newOperands.append(outParams.begin(), outParams.end());
279 auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
280 replaceWithNewCallResults, [](Value v) { return v.getType(); }));
281 auto newCall = func::CallOp::create(
282 builder, op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands);
283 for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
284 std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
285 op.erase();
286 });
287
288 return failure(didFail);
289}
292 ModuleOp module,
294 // It maps the shape source of the dynamic shape memref returned by each
295 // function.
297 for (auto func : module.getOps<func::FuncOp>()) {
298 if (func.isExternal() || func.isPublic())
299 continue;
300 if (!options.filterFn(&func))
301 continue;
302 SmallVector<BlockArgument, 6> appendedEntryArgs;
303 if (failed(
304 updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
305 return failure();
306 if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) {
307 return failure();
308 }
310 if (failed(updateCalls(module, map, options)))
311 return failure();
312 return success();
313}
314
315namespace {
316struct BufferResultsToOutParamsPass
318 BufferResultsToOutParamsPass> {
319 using Base::Base;
320
321 void runOnOperation() override {
322 // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
324 options.addResultAttribute = true;
326 options.hoistStaticAllocs = true;
328 options.hoistDynamicAllocs = true;
332 return signalPassFailure();
334
335private:
337};
338} // namespace
return success()
static SmallVector< Value > getDynamicSize(Value memref, func::FuncOp funcOp)
Return the dynamic shapes of the memref based on the defining op.
static LogicalResult updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options)
bufferization::BufferResultsToOutParamsOpts::AllocationFn AllocationFn
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static SmallVector< Value > mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee, ValueRange dynamicSizes)
Returns the dynamic sizes at the callee, through the call relationship between the caller and callee.
llvm::DenseMap< func::FuncOp, SmallVector< SmallVector< Value > > > AllocDynamicSizesMap
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options)
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
This class represents an argument of a Block.
Definition Value.h:309
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
void signalPassFailure()
Signal that some invariant was broken when running.
Definition Pass.h:218
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)
Replace buffers that are returned from a function with an out parameter.
Include the generated interface declarations.
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two memrefs.
Definition Passes.h:138
std::function< FailureOr< Value >(OpBuilder &, Location, MemRefType, ValueRange)> AllocationFn
Allocator function: Generate a memref allocation with the given type.
Definition Passes.h:134