MLIR  22.0.0git
BufferResultsToOutParams.cpp
Go to the documentation of this file.
1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
14 #include "mlir/IR/Operation.h"
15 
16 namespace mlir {
17 namespace bufferization {
18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
20 } // namespace bufferization
21 } // namespace mlir
22 
23 using namespace mlir;
28 
29 /// Return `true` if the given MemRef type has a fully dynamic layout.
30 static bool hasFullyDynamicLayoutMap(MemRefType type) {
31  int64_t offset;
33  if (failed(type.getStridesAndOffset(strides, offset)))
34  return false;
35  if (!llvm::all_of(strides, ShapedType::isDynamic))
36  return false;
37  if (ShapedType::isStatic(offset))
38  return false;
39  return true;
40 }
41 
42 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
43 /// no layout).
44 static bool hasStaticIdentityLayout(MemRefType type) {
45  return type.getLayout().isIdentity();
46 }
47 
48 /// Return the dynamic shapes of the `memref` based on the defining op. If the
49 /// complete dynamic shape fails to be captured, return an empty value.
50 /// Currently, only function block arguments are supported for capturing.
51 static SmallVector<Value> getDynamicSize(Value memref, func::FuncOp funcOp) {
52  Operation *defOp = memref.getDefiningOp();
53  if (!defOp)
54  return {};
55  auto operands = defOp->getOperands();
56  SmallVector<Value> dynamicSizes;
57  for (Value size : operands) {
58  if (!isa<IndexType>(size.getType()))
59  continue;
60 
61  BlockArgument sizeSrc = dyn_cast<BlockArgument>(size);
62  if (!sizeSrc)
63  return {};
64  auto arguments = funcOp.getArguments();
65  auto iter = llvm::find(arguments, sizeSrc);
66  if (iter == arguments.end())
67  return {};
68  dynamicSizes.push_back(*iter);
69  }
70  return dynamicSizes;
71 }
72 
73 /// Returns the dynamic sizes at the callee, through the call relationship
74 /// between the caller and callee.
75 static SmallVector<Value> mapDynamicSizeAtCaller(func::CallOp call,
76  func::FuncOp callee,
77  ValueRange dynamicSizes) {
78  SmallVector<Value> mappedDynamicSizes;
79  for (Value size : dynamicSizes) {
80  for (auto [src, dst] :
81  llvm::zip_first(call.getOperands(), callee.getArguments())) {
82  if (size != dst)
83  continue;
84  mappedDynamicSizes.push_back(src);
85  }
86  }
87  assert(mappedDynamicSizes.size() == dynamicSizes.size() &&
88  "could not find all dynamic sizes");
89  return mappedDynamicSizes;
90 }
91 
92 // Updates the func op and entry block.
93 //
94 // Any args appended to the entry block are added to `appendedEntryArgs`.
95 // If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
96 // to each newly created function argument.
97 static LogicalResult
98 updateFuncOp(func::FuncOp func,
99  SmallVectorImpl<BlockArgument> &appendedEntryArgs,
100  bool addResultAttribute) {
101  auto functionType = func.getFunctionType();
102 
103  // Collect information about the results will become appended arguments.
104  SmallVector<Type, 6> erasedResultTypes;
105  BitVector erasedResultIndices(functionType.getNumResults());
106  for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
107  if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
108  if (!hasStaticIdentityLayout(memrefType) &&
109  !hasFullyDynamicLayoutMap(memrefType)) {
110  // Only buffers with static identity layout can be allocated. These can
111  // be casted to memrefs with fully dynamic layout map. Other layout maps
112  // are not supported.
113  return func->emitError()
114  << "cannot create out param for result with unsupported layout";
115  }
116  erasedResultIndices.set(resultType.index());
117  erasedResultTypes.push_back(memrefType);
118  }
119  }
120 
121  // Add the new arguments to the function type.
122  auto newArgTypes = llvm::to_vector<6>(
123  llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
124  auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
125  functionType.getResults());
126  func.setType(newFunctionType);
127 
128  // Transfer the result attributes to arg attributes.
129  auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
130  for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
131  func.setArgAttrs(functionType.getNumInputs() + i,
132  func.getResultAttrs(*erasedIndicesIt));
133  if (addResultAttribute)
134  func.setArgAttr(functionType.getNumInputs() + i,
135  StringAttr::get(func.getContext(), "bufferize.result"),
136  UnitAttr::get(func.getContext()));
137  }
138 
139  // Erase the results.
140  if (failed(func.eraseResults(erasedResultIndices)))
141  return failure();
142 
143  // Add the new arguments to the entry block if the function is not external.
144  if (func.isExternal())
145  return success();
146  Location loc = func.getLoc();
147  for (Type type : erasedResultTypes)
148  appendedEntryArgs.push_back(func.front().addArgument(type, loc));
149 
150  return success();
151 }
152 
153 // Updates all ReturnOps in the scope of the given func::FuncOp by either
154 // keeping them as return values or copying the associated buffer contents into
155 // the given out-params.
156 static LogicalResult
157 updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
160  auto res = func.walk([&](func::ReturnOp op) {
161  SmallVector<Value, 6> copyIntoOutParams;
162  SmallVector<Value, 6> keepAsReturnOperands;
163  for (Value operand : op.getOperands()) {
164  if (isa<MemRefType>(operand.getType()))
165  copyIntoOutParams.push_back(operand);
166  else
167  keepAsReturnOperands.push_back(operand);
168  }
169  OpBuilder builder(op);
170  SmallVector<SmallVector<Value>> dynamicSizes;
171  for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
172  bool hoistStaticAllocs =
173  options.hoistStaticAllocs &&
174  cast<MemRefType>(orig.getType()).hasStaticShape();
175  bool hoistDynamicAllocs =
176  options.hoistDynamicAllocs &&
177  !cast<MemRefType>(orig.getType()).hasStaticShape();
178  if ((hoistStaticAllocs || hoistDynamicAllocs) &&
179  isa_and_nonnull<bufferization::AllocationOpInterface>(
180  orig.getDefiningOp())) {
181  orig.replaceAllUsesWith(arg);
182  if (hoistDynamicAllocs) {
183  SmallVector<Value> dynamicSize = getDynamicSize(orig, func);
184  dynamicSizes.push_back(dynamicSize);
185  }
186  orig.getDefiningOp()->erase();
187  } else {
188  if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
189  return WalkResult::interrupt();
190  }
191  }
192  func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands);
193  op.erase();
194  auto dynamicSizePair =
195  std::pair<func::FuncOp, SmallVector<SmallVector<Value>>>(func,
196  dynamicSizes);
197  map.insert(dynamicSizePair);
198  return WalkResult::advance();
199  });
200  return failure(res.wasInterrupted());
201 }
202 
203 // Updates all CallOps in the scope of the given ModuleOp by allocating
204 // temporary buffers for newly introduced out params.
205 static LogicalResult
206 updateCalls(ModuleOp module, const AllocDynamicSizesMap &map,
208  bool didFail = false;
209  SymbolTable symtab(module);
210  module.walk([&](func::CallOp op) {
211  auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
212  if (!callee) {
213  op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
214  << "symbol table";
215  didFail = true;
216  return;
217  }
218  if (!options.filterFn(&callee))
219  return;
220  if (callee.isExternal() || callee.isPublic())
221  return;
222 
223  SmallVector<Value, 6> replaceWithNewCallResults;
224  SmallVector<Value, 6> replaceWithOutParams;
225  for (OpResult result : op.getResults()) {
226  if (isa<MemRefType>(result.getType()))
227  replaceWithOutParams.push_back(result);
228  else
229  replaceWithNewCallResults.push_back(result);
230  }
231  SmallVector<Value, 6> outParams;
232  OpBuilder builder(op);
233  SmallVector<SmallVector<Value>> dynamicSizes = map.lookup(callee);
234  size_t dynamicSizesIndex = 0;
235  for (Value memref : replaceWithOutParams) {
236  SmallVector<Value> dynamicSize = dynamicSizes.size() > dynamicSizesIndex
237  ? dynamicSizes[dynamicSizesIndex]
238  : SmallVector<Value>();
239  bool memrefStaticShape =
240  cast<MemRefType>(memref.getType()).hasStaticShape();
241  if (!memrefStaticShape && dynamicSize.empty()) {
242  op.emitError()
243  << "cannot create out param for dynamically shaped result";
244  didFail = true;
245  return;
246  }
247  auto memrefType = cast<MemRefType>(memref.getType());
248  auto allocType =
249  MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
250  AffineMap(), memrefType.getMemorySpace());
251 
252  if (memrefStaticShape) {
253  dynamicSize = {};
254  } else {
255  ++dynamicSizesIndex;
256  dynamicSize = mapDynamicSizeAtCaller(op, callee, dynamicSize);
257  }
258  auto maybeOutParam =
259  options.allocationFn(builder, op.getLoc(), allocType, dynamicSize);
260  if (failed(maybeOutParam)) {
261  op.emitError() << "failed to create allocation op";
262  didFail = true;
263  return;
264  }
265  Value outParam = maybeOutParam.value();
266  if (!hasStaticIdentityLayout(memrefType)) {
267  // Layout maps are already checked in `updateFuncOp`.
268  assert(hasFullyDynamicLayoutMap(memrefType) &&
269  "layout map not supported");
270  outParam =
271  memref::CastOp::create(builder, op.getLoc(), memrefType, outParam);
272  }
273  memref.replaceAllUsesWith(outParam);
274  outParams.push_back(outParam);
275  }
276 
277  auto newOperands = llvm::to_vector<6>(op.getOperands());
278  newOperands.append(outParams.begin(), outParams.end());
279  auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
280  replaceWithNewCallResults, [](Value v) { return v.getType(); }));
281  auto newCall = func::CallOp::create(
282  builder, op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands);
283  for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
284  std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
285  op.erase();
286  });
287 
288  return failure(didFail);
289 }
290 
292  ModuleOp module,
294  // It maps the shape source of the dynamic shape memref returned by each
295  // function.
297  for (auto func : module.getOps<func::FuncOp>()) {
298  if (func.isExternal() || func.isPublic())
299  continue;
300  if (!options.filterFn(&func))
301  continue;
302  SmallVector<BlockArgument, 6> appendedEntryArgs;
303  if (failed(
304  updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
305  return failure();
306  if (failed(updateReturnOps(func, appendedEntryArgs, map, options))) {
307  return failure();
308  }
309  }
310  if (failed(updateCalls(module, map, options)))
311  return failure();
312  return success();
313 }
314 
315 namespace {
316 struct BufferResultsToOutParamsPass
317  : bufferization::impl::BufferResultsToOutParamsPassBase<
318  BufferResultsToOutParamsPass> {
319  using Base::Base;
320 
321  void runOnOperation() override {
322  // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
323  if (addResultAttribute)
324  options.addResultAttribute = true;
325  if (hoistStaticAllocs)
326  options.hoistStaticAllocs = true;
327  if (hoistDynamicAllocs)
328  options.hoistDynamicAllocs = true;
329 
331  options)))
332  return signalPassFailure();
333  }
334 
335 private:
337 };
338 } // namespace
static LogicalResult updateCalls(ModuleOp module, const AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options)
bufferization::BufferResultsToOutParamsOpts::AllocationFn AllocationFn
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static SmallVector< Value > mapDynamicSizeAtCaller(func::CallOp call, func::FuncOp callee, ValueRange dynamicSizes)
Returns the dynamic sizes at the callee, through the call relationship between the caller and callee.
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static SmallVector< Value > getDynamicSize(Value memref, func::FuncOp funcOp)
Return the dynamic shapes of the memref based on the defining op.
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, AllocDynamicSizesMap &map, const bufferization::BufferResultsToOutParamsOpts &options)
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
This class represents an argument of a Block.
Definition: Value.h:309
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:207
This is a value defined by a result of an operation.
Definition: Value.h:447
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static WalkResult advance()
Definition: WalkResult.h:47
static WalkResult interrupt()
Definition: WalkResult.h:46
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)
Replace buffers that are returned from a function with an out parameter.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two memrefs.
Definition: Passes.h:139
std::function< FailureOr< Value >(OpBuilder &, Location, MemRefType, ValueRange)> AllocationFn
Allocator function: Generate a memref allocation with the given type.
Definition: Passes.h:135