MLIR  19.0.0git
BufferResultsToOutParams.cpp
Go to the documentation of this file.
1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Pass/Pass.h"
15 
16 namespace mlir {
17 namespace bufferization {
18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
20 } // namespace bufferization
21 } // namespace mlir
22 
23 using namespace mlir;
25 
26 /// Return `true` if the given MemRef type has a fully dynamic layout.
27 static bool hasFullyDynamicLayoutMap(MemRefType type) {
28  int64_t offset;
30  if (failed(getStridesAndOffset(type, strides, offset)))
31  return false;
32  if (!llvm::all_of(strides, ShapedType::isDynamic))
33  return false;
34  if (!ShapedType::isDynamic(offset))
35  return false;
36  return true;
37 }
38 
39 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
40 /// no layout).
41 static bool hasStaticIdentityLayout(MemRefType type) {
42  return type.getLayout().isIdentity();
43 }
44 
45 // Updates the func op and entry block.
46 //
47 // Any args appended to the entry block are added to `appendedEntryArgs`.
48 // If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
49 // to each newly created function argument.
50 static LogicalResult
51 updateFuncOp(func::FuncOp func,
52  SmallVectorImpl<BlockArgument> &appendedEntryArgs,
53  bool addResultAttribute) {
54  auto functionType = func.getFunctionType();
55 
56  // Collect information about the results will become appended arguments.
57  SmallVector<Type, 6> erasedResultTypes;
58  BitVector erasedResultIndices(functionType.getNumResults());
59  for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
60  if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
61  if (!hasStaticIdentityLayout(memrefType) &&
62  !hasFullyDynamicLayoutMap(memrefType)) {
63  // Only buffers with static identity layout can be allocated. These can
64  // be casted to memrefs with fully dynamic layout map. Other layout maps
65  // are not supported.
66  return func->emitError()
67  << "cannot create out param for result with unsupported layout";
68  }
69  erasedResultIndices.set(resultType.index());
70  erasedResultTypes.push_back(memrefType);
71  }
72  }
73 
74  // Add the new arguments to the function type.
75  auto newArgTypes = llvm::to_vector<6>(
76  llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
77  auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
78  functionType.getResults());
79  func.setType(newFunctionType);
80 
81  // Transfer the result attributes to arg attributes.
82  auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
83  for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
84  func.setArgAttrs(functionType.getNumInputs() + i,
85  func.getResultAttrs(*erasedIndicesIt));
86  if (addResultAttribute)
87  func.setArgAttr(functionType.getNumInputs() + i,
88  StringAttr::get(func.getContext(), "bufferize.result"),
89  UnitAttr::get(func.getContext()));
90  }
91 
92  // Erase the results.
93  func.eraseResults(erasedResultIndices);
94 
95  // Add the new arguments to the entry block if the function is not external.
96  if (func.isExternal())
97  return success();
98  Location loc = func.getLoc();
99  for (Type type : erasedResultTypes)
100  appendedEntryArgs.push_back(func.front().addArgument(type, loc));
101 
102  return success();
103 }
104 
105 // Updates all ReturnOps in the scope of the given func::FuncOp by either
106 // keeping them as return values or copying the associated buffer contents into
107 // the given out-params.
108 static LogicalResult updateReturnOps(func::FuncOp func,
109  ArrayRef<BlockArgument> appendedEntryArgs,
110  MemCpyFn memCpyFn,
111  bool hoistStaticAllocs) {
112  auto res = func.walk([&](func::ReturnOp op) {
113  SmallVector<Value, 6> copyIntoOutParams;
114  SmallVector<Value, 6> keepAsReturnOperands;
115  for (Value operand : op.getOperands()) {
116  if (isa<MemRefType>(operand.getType()))
117  copyIntoOutParams.push_back(operand);
118  else
119  keepAsReturnOperands.push_back(operand);
120  }
121  OpBuilder builder(op);
122  for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
123  if (hoistStaticAllocs && isa<memref::AllocOp>(orig.getDefiningOp()) &&
124  mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
125  orig.replaceAllUsesWith(arg);
126  orig.getDefiningOp()->erase();
127  } else {
128  if (failed(memCpyFn(builder, op.getLoc(), orig, arg)))
129  return WalkResult::interrupt();
130  }
131  }
132  builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
133  op.erase();
134  return WalkResult::advance();
135  });
136  return failure(res.wasInterrupted());
137 }
138 
139 // Updates all CallOps in the scope of the given ModuleOp by allocating
140 // temporary buffers for newly introduced out params.
141 static LogicalResult
142 updateCalls(ModuleOp module,
144  bool didFail = false;
145  SymbolTable symtab(module);
146  module.walk([&](func::CallOp op) {
147  auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
148  if (!callee) {
149  op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
150  << "symbol table";
151  didFail = true;
152  return;
153  }
154  if (!options.filterFn(&callee))
155  return;
156  SmallVector<Value, 6> replaceWithNewCallResults;
157  SmallVector<Value, 6> replaceWithOutParams;
158  for (OpResult result : op.getResults()) {
159  if (isa<MemRefType>(result.getType()))
160  replaceWithOutParams.push_back(result);
161  else
162  replaceWithNewCallResults.push_back(result);
163  }
164  SmallVector<Value, 6> outParams;
165  OpBuilder builder(op);
166  for (Value memref : replaceWithOutParams) {
167  if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
168  op.emitError()
169  << "cannot create out param for dynamically shaped result";
170  didFail = true;
171  return;
172  }
173  auto memrefType = cast<MemRefType>(memref.getType());
174  auto allocType =
175  MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
176  AffineMap(), memrefType.getMemorySpace());
177  Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
178  if (!hasStaticIdentityLayout(memrefType)) {
179  // Layout maps are already checked in `updateFuncOp`.
180  assert(hasFullyDynamicLayoutMap(memrefType) &&
181  "layout map not supported");
182  outParam =
183  builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
184  }
185  memref.replaceAllUsesWith(outParam);
186  outParams.push_back(outParam);
187  }
188 
189  auto newOperands = llvm::to_vector<6>(op.getOperands());
190  newOperands.append(outParams.begin(), outParams.end());
191  auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
192  replaceWithNewCallResults, [](Value v) { return v.getType(); }));
193  auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
194  newResultTypes, newOperands);
195  for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
196  std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
197  op.erase();
198  });
199 
200  return failure(didFail);
201 }
202 
204  ModuleOp module,
206  for (auto func : module.getOps<func::FuncOp>()) {
207  if (!options.filterFn(&func))
208  continue;
209  SmallVector<BlockArgument, 6> appendedEntryArgs;
210  if (failed(
211  updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
212  return failure();
213  if (func.isExternal())
214  continue;
215  auto defaultMemCpyFn = [](OpBuilder &builder, Location loc, Value from,
216  Value to) {
217  builder.create<memref::CopyOp>(loc, from, to);
218  return success();
219  };
220  if (failed(updateReturnOps(func, appendedEntryArgs,
221  options.memCpyFn.value_or(defaultMemCpyFn),
222  options.hoistStaticAllocs))) {
223  return failure();
224  }
225  }
226  if (failed(updateCalls(module, options)))
227  return failure();
228  return success();
229 }
230 
231 namespace {
232 struct BufferResultsToOutParamsPass
233  : bufferization::impl::BufferResultsToOutParamsBase<
234  BufferResultsToOutParamsPass> {
235  explicit BufferResultsToOutParamsPass(
237  : options(options) {}
238 
239  void runOnOperation() override {
240  // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
241  if (addResultAttribute)
242  options.addResultAttribute = true;
243  if (hoistStaticAllocs)
244  options.hoistStaticAllocs = true;
245 
246  if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
247  options)))
248  return signalPassFailure();
249  }
250 
251 private:
253 };
254 } // namespace
255 
258  return std::make_unique<BufferResultsToOutParamsPass>(options);
259 }
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static LogicalResult updateCalls(ModuleOp module, const bufferization::BufferResultsToOutParamsOpts &options)
static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, MemCpyFn memCpyFn, bool hoistStaticAllocs)
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This is a value defined by a result of an operation.
Definition: Value.h:457
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
result_range getResults()
Definition: Operation.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static WalkResult advance()
Definition: Visitors.h:51
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)
Replace buffers that are returned from a function with an out parameter.
std::unique_ptr< Pass > createBufferResultsToOutParamsPass(const BufferResultsToOutParamsOpts &options={})
Creates a pass that converts memref function results to out-params.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two memrefs.
Definition: Passes.h:154