MLIR  18.0.0git
BufferResultsToOutParams.cpp
Go to the documentation of this file.
1 //===- BufferResultsToOutParams.cpp - Calling convention conversion -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/IR/Operation.h"
14 #include "mlir/Pass/Pass.h"
15 
16 namespace mlir {
17 namespace bufferization {
18 #define GEN_PASS_DEF_BUFFERRESULTSTOOUTPARAMS
19 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
20 } // namespace bufferization
21 } // namespace mlir
22 
23 using namespace mlir;
24 
25 /// Return `true` if the given MemRef type has a fully dynamic layout.
26 static bool hasFullyDynamicLayoutMap(MemRefType type) {
27  int64_t offset;
29  if (failed(getStridesAndOffset(type, strides, offset)))
30  return false;
31  if (!llvm::all_of(strides, ShapedType::isDynamic))
32  return false;
33  if (!ShapedType::isDynamic(offset))
34  return false;
35  return true;
36 }
37 
38 /// Return `true` if the given MemRef type has a static identity layout (i.e.,
39 /// no layout).
40 static bool hasStaticIdentityLayout(MemRefType type) {
41  return type.getLayout().isIdentity();
42 }
43 
44 // Updates the func op and entry block.
45 //
46 // Any args appended to the entry block are added to `appendedEntryArgs`.
47 static LogicalResult
48 updateFuncOp(func::FuncOp func,
49  SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
50  auto functionType = func.getFunctionType();
51 
52  // Collect information about the results will become appended arguments.
53  SmallVector<Type, 6> erasedResultTypes;
54  BitVector erasedResultIndices(functionType.getNumResults());
55  for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
56  if (auto memrefType = dyn_cast<MemRefType>(resultType.value())) {
57  if (!hasStaticIdentityLayout(memrefType) &&
58  !hasFullyDynamicLayoutMap(memrefType)) {
59  // Only buffers with static identity layout can be allocated. These can
60  // be casted to memrefs with fully dynamic layout map. Other layout maps
61  // are not supported.
62  return func->emitError()
63  << "cannot create out param for result with unsupported layout";
64  }
65  erasedResultIndices.set(resultType.index());
66  erasedResultTypes.push_back(memrefType);
67  }
68  }
69 
70  // Add the new arguments to the function type.
71  auto newArgTypes = llvm::to_vector<6>(
72  llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
73  auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
74  functionType.getResults());
75  func.setType(newFunctionType);
76 
77  // Transfer the result attributes to arg attributes.
78  auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
79  for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
80  func.setArgAttrs(functionType.getNumInputs() + i,
81  func.getResultAttrs(*erasedIndicesIt));
82  }
83 
84  // Erase the results.
85  func.eraseResults(erasedResultIndices);
86 
87  // Add the new arguments to the entry block if the function is not external.
88  if (func.isExternal())
89  return success();
90  Location loc = func.getLoc();
91  for (Type type : erasedResultTypes)
92  appendedEntryArgs.push_back(func.front().addArgument(type, loc));
93 
94  return success();
95 }
96 
97 // Updates all ReturnOps in the scope of the given func::FuncOp by either
98 // keeping them as return values or copying the associated buffer contents into
99 // the given out-params.
100 static void updateReturnOps(func::FuncOp func,
101  ArrayRef<BlockArgument> appendedEntryArgs) {
102  func.walk([&](func::ReturnOp op) {
103  SmallVector<Value, 6> copyIntoOutParams;
104  SmallVector<Value, 6> keepAsReturnOperands;
105  for (Value operand : op.getOperands()) {
106  if (isa<MemRefType>(operand.getType()))
107  copyIntoOutParams.push_back(operand);
108  else
109  keepAsReturnOperands.push_back(operand);
110  }
111  OpBuilder builder(op);
112  for (auto t : llvm::zip(copyIntoOutParams, appendedEntryArgs))
113  builder.create<memref::CopyOp>(op.getLoc(), std::get<0>(t),
114  std::get<1>(t));
115  builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands);
116  op.erase();
117  });
118 }
119 
120 // Updates all CallOps in the scope of the given ModuleOp by allocating
121 // temporary buffers for newly introduced out params.
122 static LogicalResult
123 updateCalls(ModuleOp module,
125  bool didFail = false;
126  SymbolTable symtab(module);
127  module.walk([&](func::CallOp op) {
128  auto callee = symtab.lookup<func::FuncOp>(op.getCallee());
129  if (!callee) {
130  op.emitError() << "cannot find callee '" << op.getCallee() << "' in "
131  << "symbol table";
132  didFail = true;
133  return;
134  }
135  if (!options.filterFn(&callee))
136  return;
137  SmallVector<Value, 6> replaceWithNewCallResults;
138  SmallVector<Value, 6> replaceWithOutParams;
139  for (OpResult result : op.getResults()) {
140  if (isa<MemRefType>(result.getType()))
141  replaceWithOutParams.push_back(result);
142  else
143  replaceWithNewCallResults.push_back(result);
144  }
145  SmallVector<Value, 6> outParams;
146  OpBuilder builder(op);
147  for (Value memref : replaceWithOutParams) {
148  if (!cast<MemRefType>(memref.getType()).hasStaticShape()) {
149  op.emitError()
150  << "cannot create out param for dynamically shaped result";
151  didFail = true;
152  return;
153  }
154  auto memrefType = cast<MemRefType>(memref.getType());
155  auto allocType =
156  MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
157  AffineMap(), memrefType.getMemorySpace());
158  Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
159  if (!hasStaticIdentityLayout(memrefType)) {
160  // Layout maps are already checked in `updateFuncOp`.
161  assert(hasFullyDynamicLayoutMap(memrefType) &&
162  "layout map not supported");
163  outParam =
164  builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
165  }
166  memref.replaceAllUsesWith(outParam);
167  outParams.push_back(outParam);
168  }
169 
170  auto newOperands = llvm::to_vector<6>(op.getOperands());
171  newOperands.append(outParams.begin(), outParams.end());
172  auto newResultTypes = llvm::to_vector<6>(llvm::map_range(
173  replaceWithNewCallResults, [](Value v) { return v.getType(); }));
174  auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(),
175  newResultTypes, newOperands);
176  for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults()))
177  std::get<0>(t).replaceAllUsesWith(std::get<1>(t));
178  op.erase();
179  });
180 
181  return failure(didFail);
182 }
183 
185  ModuleOp module,
187  for (auto func : module.getOps<func::FuncOp>()) {
188  if (!options.filterFn(&func))
189  continue;
190  SmallVector<BlockArgument, 6> appendedEntryArgs;
191  if (failed(updateFuncOp(func, appendedEntryArgs)))
192  return failure();
193  if (func.isExternal())
194  continue;
195  updateReturnOps(func, appendedEntryArgs);
196  }
197  if (failed(updateCalls(module, options)))
198  return failure();
199  return success();
200 }
201 
202 namespace {
203 struct BufferResultsToOutParamsPass
204  : bufferization::impl::BufferResultsToOutParamsBase<
205  BufferResultsToOutParamsPass> {
206  explicit BufferResultsToOutParamsPass(
208  : options(options) {}
209 
210  void runOnOperation() override {
212  options)))
213  return signalPassFailure();
214  }
215 
216 private:
218 };
219 } // namespace
220 
223  return std::make_unique<BufferResultsToOutParamsPass>(options);
224 }
static bool hasStaticIdentityLayout(MemRefType type)
Return true if the given MemRef type has a static identity layout (i.e., no layout).
static LogicalResult updateFuncOp(func::FuncOp func, SmallVectorImpl< BlockArgument > &appendedEntryArgs)
static LogicalResult updateCalls(ModuleOp module, const bufferization::BufferResultsToOutParamsOptions &options)
static bool hasFullyDynamicLayoutMap(MemRefType type)
Return true if the given MemRef type has a fully dynamic layout.
static void updateReturnOps(func::FuncOp func, ArrayRef< BlockArgument > appendedEntryArgs)
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:44
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This is a value defined by a result of an operation.
Definition: Value.h:448
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void replaceAllUsesWith(ValuesT &&values)
Replace all uses of results of this operation with the provided 'values'.
Definition: Operation.h:272
result_range getResults()
Definition: Operation.h:410
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition: SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
LogicalResult promoteBufferResultsToOutParams(ModuleOp module, const BufferResultsToOutParamsOptions &options)
Replace buffers that are returned from a function with an out parameter.
std::unique_ptr< Pass > createBufferResultsToOutParamsPass(const BufferResultsToOutParamsOptions &options={})
Creates a pass that converts memref function results to out-params.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26