MLIR  21.0.0git
BufferResultsToOutParams.cpp
Go to the documentation of this file.
1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
14 #include "mlir/IR/Operation.h"
15 #include "mlir/Pass/Pass.h"
16 
17 namespace mlir {
18 namespace bufferization {
19 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMSPASS
20 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
21 } // namespace bufferization
22 } // namespace mlir
23 
24 using namespace mlir;
27 
28 /// Return `true` if the given MemRef type has a fully dynamic layout.
29 static bool hasFullyDynamicLayoutMap(MemRefType type) {
30  int64_t offset;
32  if (failed(type.getStridesAndOffset(strides, offset)))
33  return false;
34  if (!llvm::all_of(strides, ShapedType::isDynamic))
35  return false;
36  if (!ShapedType::isDynamic(offset))
37  return false;
38  return true;
39 }
40 
41 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
42 /// no layout).
43 static bool hasStaticIdentityLayout(MemRefType type) {
44  return type.getLayout().isIdentity();
45 }
46 
47 // Updates the func op and entry block.
48 //
49 // Any args appended to the entry block are added to `appendedEntryArgs`.
50 // If `addResultAttribute` is true, adds the unit attribute `bufferize.result`
51 // to each newly created function argument.
52 static LogicalResult
53 updateFuncOp(func::FuncOp func,
54  SmallVectorImpl<BlockArgument> &appendedEntryArgs,
55  bool addResultAttribute) {
56  auto functionType = func.getFunctionType();
57 
58  // Collect information about the results will become appended arguments.
59  SmallVector<Type, 6> erasedResultTypes;
60  BitVector erasedResultIndices(functionType.getNumResults());
61  for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
62  if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
63  if (!hasStaticIdentityLayout(memrefType) &&
64  !hasFullyDynamicLayoutMap(memrefType)) {
65  // Only buffers with static identity layout can be allocated. These can
66  // be casted to memrefs with fully dynamic layout map. Other layout maps
67  // are not supported.
68  return func->emitError()
69  << "cannot create out param for result with unsupported layout";
70  }
71  erasedResultIndices.set(resultType.index());
72  erasedResultTypes.push_back(memrefType);
73  }
74  }
75 
76  // Add the new arguments to the function type.
77  auto newArgTypes = llvm::to_vector<6>(
78  llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
79  auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
80  functionType.getResults());
81  func.setType(newFunctionType);
82 
83  // Transfer the result attributes to arg attributes.
84  auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
85  for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
86  func.setArgAttrs(functionType.getNumInputs() + i,
87  func.getResultAttrs(*erasedIndicesIt));
88  if (addResultAttribute)
89  func.setArgAttr(functionType.getNumInputs() + i,
90  StringAttr::get(func.getContext(), "bufferize.result"),
91  UnitAttr::get(func.getContext()));
92  }
93 
94  // Erase the results.
95  if (failed(func.eraseResults(erasedResultIndices)))
96  return failure();
97 
98  // Add the new arguments to the entry block if the function is not external.
99  if (func.isExternal())
100  return success();
101  Location loc = func.getLoc();
102  for (Type type : erasedResultTypes)
103  appendedEntryArgs.push_back(func.front().addArgument(type, loc));
104 
105  return success();
106 }
107 
108 // Updates all ReturnOps in the scope of the given func::FuncOp by either
109 // keeping them as return values or copying the associated buffer contents into
110 // the given out-params.
111 static LogicalResult
112 updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs,
114  auto res = func.walk([&](func::ReturnOp op) {
115  SmallVector<Value, 6> copyIntoOutParams;
116  SmallVector<Value, 6> keepAsReturnOperands;
117  for (Value operand : op.getOperands()) {
118  if (isa<MemRefType>(operand.getType()))
119  copyIntoOutParams.push_back(operand);
120  else
121  keepAsReturnOperands.push_back(operand);
122  }
123  OpBuilder builder(op);
124  for (auto [orig, arg] : llvm::zip(copyIntoOutParams, appendedEntryArgs)) {
125  if (options.hoistStaticAllocs &&
126  isa_and_nonnull<bufferization::AllocationOpInterface>(
127  orig.getDefiningOp()) &&
128  mlir::cast<MemRefType>(orig.getType()).hasStaticShape()) {
129  orig.replaceAllUsesWith(arg);
130  orig.getDefiningOp()->erase();
131  } else {
132  if (failed(options.memCpyFn(builder, op.getLoc(), orig, arg)))
133  return WalkResult::interrupt();
134  }
135  }
136  builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
137  op.erase();
138  return WalkResult::advance();
139  });
140  return failure(res.wasInterrupted());
141 }
142 
143 // Updates all CallOps in the scope of the given ModuleOp by allocating
144 // temporary buffers for newly introduced out params.
145 static LogicalResult
146 updateCalls(ModuleOp module,
148  bool didFail = false;
149  SymbolTable symtab(module);
150  module.walk([&](func::CallOp op) {
151  auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
152  if (!callee) {
153  op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
154  << "symbol table";
155  didFail = true;
156  return;
157  }
158  if (!options.filterFn(&callee))
159  return;
160  SmallVector<Value, 6> replaceWithNewCallResults;
161  SmallVector<Value, 6> replaceWithOutParams;
162  for (OpResult result : op.getResults()) {
163  if (isa<MemRefType>(result.getType()))
164  replaceWithOutParams.push_back(result);
165  else
166  replaceWithNewCallResults.push_back(result);
167  }
168  SmallVector<Value, 6> outParams;
169  OpBuilder builder(op);
170  for (Value memref : replaceWithOutParams) {
171  if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
172  op.emitError()
173  << "cannot create out param for dynamically shaped result";
174  didFail = true;
175  return;
176  }
177  auto memrefType = cast<MemRefType>(memref.getType());
178  auto allocType =
179  MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
180  AffineMap(), memrefType.getMemorySpace());
181  auto maybeOutParam =
182  options.allocationFn(builder, op.getLoc(), allocType);
183  if (failed(maybeOutParam)) {
184  op.emitError() << "failed to create allocation op";
185  didFail = true;
186  return;
187  }
188  Value outParam = maybeOutParam.value();
189  if (!hasStaticIdentityLayout(memrefType)) {
190  // Layout maps are already checked in `updateFuncOp`.
191  assert(hasFullyDynamicLayoutMap(memrefType) &&
192  "layout map not supported");
193  outParam =
194  builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
195  }
196  memref.replaceAllUsesWith(outParam);
197  outParams.push_back(outParam);
198  }
199 
200  auto newOperands = llvm::to_vector<6>(op.getOperands());
201  newOperands.append(outParams.begin(), outParams.end());
202  auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
203  replaceWithNewCallResults, [](Value v) { return v.getType(); }));
204  auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
205  newResultTypes, newOperands);
206  for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
207  std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
208  op.erase();
209  });
210 
211  return failure(didFail);
212 }
213 
215  ModuleOp module,
217  for (auto func : module.getOps<func::FuncOp>()) {
218  if (!options.filterFn(&func))
219  continue;
220  SmallVector<BlockArgument, 6> appendedEntryArgs;
221  if (failed(
222  updateFuncOp(func, appendedEntryArgs, options.addResultAttribute)))
223  return failure();
224  if (func.isExternal())
225  continue;
226  if (failed(updateReturnOps(func, appendedEntryArgs, options))) {
227  return failure();
228  }
229  }
230  if (failed(updateCalls(module, options)))
231  return failure();
232  return success();
233 }
234 
235 namespace {
236 struct BufferResultsToOutParamsPass
237  : bufferization::impl::BufferResultsToOutParamsPassBase<
238  BufferResultsToOutParamsPass> {
239  using Base::Base;
240 
241  void runOnOperation() override {
242  // Convert from pass options in tablegen to BufferResultsToOutParamsOpts.
243  if (addResultAttribute)
244  options.addResultAttribute = true;
245  if (hoistStaticAllocs)
246  options.hoistStaticAllocs = true;
247 
248  if (failed(bufferization::promoteBufferResultsToOutParams(getOperation(),
249  options)))
250  return signalPassFailure();
251  }
252 
253 private:
255 };
256 } // namespace
static LogicalResult updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs, const bufferization::BufferResultsToOutParamsOpts &options)
bufferization::BufferResultsToOutParamsOpts::AllocationFn AllocationFn
bufferization::BufferResultsToOutParamsOpts::MemCpyFn MemCpyFn
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs, bool addResultAttribute)
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static LogicalResult updateCalls(ModuleOp module, const bufferization::BufferResultsToOutParamsOpts &options)
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This is a value defined by a result of an operation.
Definition: Value.h:433
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static WalkResult advance()
Definition: Visitors.h:51
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOpts &options)
Replace buffers that are returned from a function with an out parameter.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::function< LogicalResult(OpBuilder &, Location, Value, Value)> MemCpyFn
Memcpy function: Generate a memcpy between two memrefs.
Definition: Passes.h:137
std::function< FailureOr< Value >(OpBuilder &, Location, MemRefType)> AllocationFn
Allocator function: Generate a memref allocation with the given type.
Definition: Passes.h:133