MLIR  22.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the Func dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Func dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/Support/DebugLog.h"
20 
21 #define DEBUG_TYPE "func-utils"
22 
23 using namespace mlir;
24 
25 /// This method creates an inverse mapping of the provided map `oldToNew`.
26 /// Given an array where `oldIdxToNewIdx[i] = j` means old index `i` maps
27 /// to new index `j`,
28 /// This method returns a vector where `result[j]` contains all old indices
29 /// that map to new index `j`.
30 ///
31 /// Example:
32 /// ```
33 /// oldIdxToNewIdx = [0, 1, 2, 2, 3]
34 /// getInverseMapping(oldIdxToNewIdx) = [[0], [1], [2, 3], [4]]
35 /// ```
36 ///
39  int numOfNewIdxs = 0;
40  if (!oldIdxToNewIdx.empty())
41  numOfNewIdxs = 1 + *llvm::max_element(oldIdxToNewIdx);
42  llvm::SmallVector<llvm::SmallVector<int>> newToOldIdxs(numOfNewIdxs);
43  for (auto [oldIdx, newIdx] : llvm::enumerate(oldIdxToNewIdx))
44  newToOldIdxs[newIdx].push_back(oldIdx);
45  return newToOldIdxs;
46 }
47 
48 /// This method returns a new vector of elements that are mapped from the
49 /// `origElements` based on the `newIdxToOldIdxs` mapping. This function assumes
50 /// that the `newIdxToOldIdxs` mapping is valid, i.e. for each new index, there
51 /// is at least one old index that maps to it. Also, It assumes that mapping to
52 /// the same old index has the same element in the `origElements` vector.
53 template <typename Element>
55  ArrayRef<Element> origElements,
56  const llvm::SmallVector<llvm::SmallVector<int>> &newIdxToOldIdxs) {
57  SmallVector<Element> newElements;
58  for (const auto &oldIdxs : newIdxToOldIdxs) {
59  assert(llvm::all_of(oldIdxs,
60  [&origElements](int idx) -> bool {
61  return idx >= 0 &&
62  static_cast<size_t>(idx) < origElements.size();
63  }) &&
64  "idx must be less than the number of elements in the original "
65  "elements");
66  assert(!oldIdxs.empty() && "oldIdx must not be empty");
67  Element origTypeToCheck = origElements[oldIdxs.front()];
68  assert(llvm::all_of(oldIdxs,
69  [&](int idx) -> bool {
70  return origElements[idx] == origTypeToCheck;
71  }) &&
72  "all oldIdxs must be equal");
73  newElements.push_back(origTypeToCheck);
74  }
75  return newElements;
76 }
77 
78 FailureOr<func::FuncOp>
79 func::replaceFuncWithNewMapping(RewriterBase &rewriter, func::FuncOp funcOp,
80  ArrayRef<int> oldArgIdxToNewArgIdx,
81  ArrayRef<int> oldResIdxToNewResIdx) {
82  // Generate an empty new function operation with the same name as the
83  // original.
84  assert(funcOp.getNumArguments() == oldArgIdxToNewArgIdx.size() &&
85  "oldArgIdxToNewArgIdx must match the number of arguments in the "
86  "function");
87  assert(
88  funcOp.getNumResults() == oldResIdxToNewResIdx.size() &&
89  "oldResIdxToNewResIdx must match the number of results in the function");
90 
91  if (!funcOp.getBody().hasOneBlock())
92  return rewriter.notifyMatchFailure(
93  funcOp, "expected function to have exactly one block");
94 
95  // We may have some duplicate arguments in the old function, i.e.
96  // in the mapping `newArgIdxToOldArgIdxs` for some new argument index
97  // there may be multiple old argument indices.
98  llvm::SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
99  getInverseMapping(oldArgIdxToNewArgIdx);
100  SmallVector<Type> newInputTypes = getMappedElements(
101  funcOp.getFunctionType().getInputs(), newArgIdxToOldArgIdxs);
102 
104  for (const auto &oldArgIdxs : newArgIdxToOldArgIdxs)
105  locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());
106 
107  llvm::SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
108  getInverseMapping(oldResIdxToNewResIdx);
109  SmallVector<Type> newOutputTypes = getMappedElements(
110  funcOp.getFunctionType().getResults(), newResToOldResIdxs);
111 
112  rewriter.setInsertionPoint(funcOp);
113  auto newFuncOp = func::FuncOp::create(
114  rewriter, funcOp.getLoc(), funcOp.getName(),
115  rewriter.getFunctionType(newInputTypes, newOutputTypes));
116 
117  Region &newRegion = newFuncOp.getBody();
118  rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
119  newFuncOp.setVisibility(funcOp.getVisibility());
120 
121  // Map the arguments of the original function to the new function in
122  // the new order and adjust the attributes accordingly.
123  IRMapping operandMapper;
124  SmallVector<DictionaryAttr> argAttrs, resultAttrs;
125  funcOp.getAllArgAttrs(argAttrs);
126  for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgIdxToNewArgIdx))
127  operandMapper.map(funcOp.getArgument(oldArgIdx),
128  newFuncOp.getArgument(newArgIdx));
129  for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(newArgIdxToOldArgIdxs))
130  newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
131 
132  funcOp.getAllResultAttrs(resultAttrs);
133  for (auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs))
134  newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);
135 
136  // Clone the operations from the original function to the new function.
137  rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
138  for (Operation &op : funcOp.getOps())
139  rewriter.clone(op, operandMapper);
140 
141  // Handle the return operation.
142  auto returnOp = cast<func::ReturnOp>(
143  newFuncOp.getFunctionBody().begin()->getTerminator());
144  SmallVector<Value> newReturnValues;
145  for (const auto &oldResIdxs : newResToOldResIdxs)
146  newReturnValues.push_back(returnOp.getOperand(oldResIdxs.front()));
147 
148  rewriter.setInsertionPoint(returnOp);
149  func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
150  rewriter.eraseOp(returnOp);
151 
152  rewriter.eraseOp(funcOp);
153 
154  return newFuncOp;
155 }
156 
157 func::CallOp
158 func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp,
159  ArrayRef<int> oldArgIdxToNewArgIdx,
160  ArrayRef<int> oldResIdxToNewResIdx) {
161  assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
162  "oldArgIdxToNewArgIdx must match the number of operands in the call "
163  "operation");
164  assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
165  "oldResIdxToNewResIdx must match the number of results in the call "
166  "operation");
167 
168  SmallVector<Value> origOperands = callOp.getOperands();
169  SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
170  getInverseMapping(oldArgIdxToNewArgIdx);
171  SmallVector<Value> newOperandsValues =
172  getMappedElements<Value>(origOperands, newArgIdxToOldArgIdxs);
173  SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
174  getInverseMapping(oldResIdxToNewResIdx);
175  SmallVector<Type> origResultTypes = llvm::to_vector(callOp.getResultTypes());
176  SmallVector<Type> newResultTypes =
177  getMappedElements<Type>(origResultTypes, newResToOldResIdxs);
178 
179  // Replace the kernel call operation with a new one that has the
180  // mapped arguments.
181  rewriter.setInsertionPoint(callOp);
182  auto newCallOp =
183  func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
184  newResultTypes, newOperandsValues);
185  newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
186  for (auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResIdxToNewResIdx))
187  rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
188  newCallOp.getResult(newResIdx));
189  rewriter.eraseOp(callOp);
190 
191  return newCallOp;
192 }
193 
194 FailureOr<std::pair<func::FuncOp, func::CallOp>>
195 func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
196  ModuleOp moduleOp) {
198  auto traversalResult = moduleOp.walk([&](func::CallOp callOp) {
199  if (callOp.getCallee() == funcOp.getSymName()) {
200  if (!callOps.empty())
201  // Only support one callOp for now
202  return WalkResult::interrupt();
203  callOps.push_back(callOp);
204  }
205  return WalkResult::advance();
206  });
207 
208  if (traversalResult.wasInterrupted()) {
209  LDBG() << "function " << funcOp.getName() << " has more than one callOp";
210  return failure();
211  }
212 
213  if (callOps.empty()) {
214  LDBG() << "function " << funcOp.getName() << " does not have any callOp";
215  return failure();
216  }
217 
218  func::CallOp callOp = callOps.front();
219 
220  // Create mapping for arguments (deduplicate operands)
221  SmallVector<int> oldArgIdxToNewArgIdx(callOp.getNumOperands());
222  llvm::DenseMap<Value, int> valueToNewArgIdx;
223  for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
224  auto [iterator, inserted] = valueToNewArgIdx.insert(
225  {operand, static_cast<int>(valueToNewArgIdx.size())});
226  // Reduce the duplicate operands and maintain the original order.
227  oldArgIdxToNewArgIdx[operandIdx] = iterator->second;
228  }
229 
230  bool hasDuplicateOperands =
231  valueToNewArgIdx.size() != callOp.getNumOperands();
232  if (!hasDuplicateOperands) {
233  LDBG() << "function " << funcOp.getName()
234  << " does not have duplicate operands";
235  return failure();
236  }
237 
238  // Create identity mapping for results (no deduplication needed)
239  SmallVector<int> oldResIdxToNewResIdx(callOp.getNumResults());
240  for (int resultIdx : llvm::seq<int>(0, callOp.getNumResults()))
241  oldResIdxToNewResIdx[resultIdx] = resultIdx;
242 
243  // Apply the transformation to create new function and call operations
244  FailureOr<func::FuncOp> newFuncOpOrFailure = replaceFuncWithNewMapping(
245  rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
246  if (failed(newFuncOpOrFailure)) {
247  LDBG() << "failed to replace function signature with name "
248  << funcOp.getName() << " with new order";
249  return failure();
250  }
251 
252  func::CallOp newCallOp = replaceCallOpWithNewMapping(
253  rewriter, callOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
254 
255  return std::make_pair(*newFuncOpOrFailure, newCallOp);
256 }
static SmallVector< Element > getMappedElements(ArrayRef< Element > origElements, const llvm::SmallVector< llvm::SmallVector< int >> &newIdxToOldIdxs)
This method returns a new vector of elements that are mapped from the origElements based on the newId...
Definition: Utils.cpp:54
static llvm::SmallVector< llvm::SmallVector< int > > getInverseMapping(ArrayRef< int > oldIdxToNewIdx)
This method creates an inverse mapping of the provided map oldIdxToNewIdx.
Definition: Utils.cpp:38
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:429
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:552
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator begin()
Definition: Region.h:55
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
static WalkResult advance()
Definition: WalkResult.h:47
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
mlir::func::CallOp replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::CallOp callOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new call operation with the values as the original call operation, but with the arguments m...
mlir::FailureOr< mlir::func::FuncOp > replaceFuncWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new function operation with the same name as the original function operation,...
mlir::FailureOr< std::pair< mlir::func::FuncOp, mlir::func::CallOp > > deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp)
This utility function examines all call operations within the given moduleOp that target the specifie...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.