MLIR 22.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Utilities to support the Func dialect ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the Func dialect.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/IRMapping.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/Support/DebugLog.h"
20
21#define DEBUG_TYPE "func-utils"
22
23using namespace mlir;
24
25/// This method creates an inverse mapping of the provided map `oldToNew`.
26/// Given an array where `oldIdxToNewIdx[i] = j` means old index `i` maps
27/// to new index `j`,
28/// This method returns a vector where `result[j]` contains all old indices
29/// that map to new index `j`.
30///
31/// Example:
32/// ```
33/// oldIdxToNewIdx = [0, 1, 2, 2, 3]
34/// getInverseMapping(oldIdxToNewIdx) = [[0], [1], [2, 3], [4]]
35/// ```
36///
39 int numOfNewIdxs = 0;
40 if (!oldIdxToNewIdx.empty())
41 numOfNewIdxs = 1 + *llvm::max_element(oldIdxToNewIdx);
42 llvm::SmallVector<llvm::SmallVector<int>> newToOldIdxs(numOfNewIdxs);
43 for (auto [oldIdx, newIdx] : llvm::enumerate(oldIdxToNewIdx))
44 newToOldIdxs[newIdx].push_back(oldIdx);
45 return newToOldIdxs;
46}
47
48/// This method returns a new vector of elements that are mapped from the
49/// `origElements` based on the `newIdxToOldIdxs` mapping. This function assumes
50/// that the `newIdxToOldIdxs` mapping is valid, i.e. for each new index, there
51/// is at least one old index that maps to it. Also, It assumes that mapping to
52/// the same old index has the same element in the `origElements` vector.
53template <typename Element>
55 ArrayRef<Element> origElements,
56 const llvm::SmallVector<llvm::SmallVector<int>> &newIdxToOldIdxs) {
57 SmallVector<Element> newElements;
58 for (const auto &oldIdxs : newIdxToOldIdxs) {
59 assert(llvm::all_of(oldIdxs,
60 [&origElements](int idx) -> bool {
61 return idx >= 0 &&
62 static_cast<size_t>(idx) < origElements.size();
63 }) &&
64 "idx must be less than the number of elements in the original "
65 "elements");
66 assert(!oldIdxs.empty() && "oldIdx must not be empty");
67 Element origTypeToCheck = origElements[oldIdxs.front()];
68 assert(llvm::all_of(oldIdxs,
69 [&](int idx) -> bool {
70 return origElements[idx] == origTypeToCheck;
71 }) &&
72 "all oldIdxs must be equal");
73 newElements.push_back(origTypeToCheck);
74 }
75 return newElements;
76}
77
78FailureOr<func::FuncOp>
79func::replaceFuncWithNewMapping(RewriterBase &rewriter, func::FuncOp funcOp,
80 ArrayRef<int> oldArgIdxToNewArgIdx,
81 ArrayRef<int> oldResIdxToNewResIdx) {
82 // Generate an empty new function operation with the same name as the
83 // original.
84 assert(funcOp.getNumArguments() == oldArgIdxToNewArgIdx.size() &&
85 "oldArgIdxToNewArgIdx must match the number of arguments in the "
86 "function");
87 assert(
88 funcOp.getNumResults() == oldResIdxToNewResIdx.size() &&
89 "oldResIdxToNewResIdx must match the number of results in the function");
90
91 if (!funcOp.getBody().hasOneBlock())
92 return rewriter.notifyMatchFailure(
93 funcOp, "expected function to have exactly one block");
94
95 // We may have some duplicate arguments in the old function, i.e.
96 // in the mapping `newArgIdxToOldArgIdxs` for some new argument index
97 // there may be multiple old argument indices.
98 llvm::SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
99 getInverseMapping(oldArgIdxToNewArgIdx);
100 SmallVector<Type> newInputTypes = getMappedElements(
101 funcOp.getFunctionType().getInputs(), newArgIdxToOldArgIdxs);
102
104 for (const auto &oldArgIdxs : newArgIdxToOldArgIdxs)
105 locs.push_back(funcOp.getArgument(oldArgIdxs.front()).getLoc());
106
107 llvm::SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
108 getInverseMapping(oldResIdxToNewResIdx);
109 SmallVector<Type> newOutputTypes = getMappedElements(
110 funcOp.getFunctionType().getResults(), newResToOldResIdxs);
111
112 rewriter.setInsertionPoint(funcOp);
113 auto newFuncOp = func::FuncOp::create(
114 rewriter, funcOp.getLoc(), funcOp.getName(),
115 rewriter.getFunctionType(newInputTypes, newOutputTypes));
116
117 Region &newRegion = newFuncOp.getBody();
118 rewriter.createBlock(&newRegion, newRegion.begin(), newInputTypes, locs);
119 newFuncOp.setVisibility(funcOp.getVisibility());
120
121 // Map the arguments of the original function to the new function in
122 // the new order and adjust the attributes accordingly.
123 IRMapping operandMapper;
124 SmallVector<DictionaryAttr> argAttrs, resultAttrs;
125 funcOp.getAllArgAttrs(argAttrs);
126 for (auto [oldArgIdx, newArgIdx] : llvm::enumerate(oldArgIdxToNewArgIdx))
127 operandMapper.map(funcOp.getArgument(oldArgIdx),
128 newFuncOp.getArgument(newArgIdx));
129 for (auto [newArgIdx, oldArgIdx] : llvm::enumerate(newArgIdxToOldArgIdxs))
130 newFuncOp.setArgAttrs(newArgIdx, argAttrs[oldArgIdx.front()]);
131
132 funcOp.getAllResultAttrs(resultAttrs);
133 for (auto [newResIdx, oldResIdx] : llvm::enumerate(newResToOldResIdxs))
134 newFuncOp.setResultAttrs(newResIdx, resultAttrs[oldResIdx.front()]);
135
136 // Clone the operations from the original function to the new function.
137 rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
138 for (Operation &op : funcOp.getOps())
139 rewriter.clone(op, operandMapper);
140
141 // Handle the return operation.
142 auto returnOp = cast<func::ReturnOp>(
143 newFuncOp.getFunctionBody().begin()->getTerminator());
144 SmallVector<Value> newReturnValues;
145 for (const auto &oldResIdxs : newResToOldResIdxs)
146 newReturnValues.push_back(returnOp.getOperand(oldResIdxs.front()));
147
148 rewriter.setInsertionPoint(returnOp);
149 func::ReturnOp::create(rewriter, newFuncOp.getLoc(), newReturnValues);
150 rewriter.eraseOp(returnOp);
151
152 rewriter.eraseOp(funcOp);
153
154 return newFuncOp;
155}
156
157func::CallOp
158func::replaceCallOpWithNewMapping(RewriterBase &rewriter, func::CallOp callOp,
159 ArrayRef<int> oldArgIdxToNewArgIdx,
160 ArrayRef<int> oldResIdxToNewResIdx) {
161 assert(callOp.getNumOperands() == oldArgIdxToNewArgIdx.size() &&
162 "oldArgIdxToNewArgIdx must match the number of operands in the call "
163 "operation");
164 assert(callOp.getNumResults() == oldResIdxToNewResIdx.size() &&
165 "oldResIdxToNewResIdx must match the number of results in the call "
166 "operation");
167
168 SmallVector<Value> origOperands = callOp.getOperands();
169 SmallVector<llvm::SmallVector<int>> newArgIdxToOldArgIdxs =
170 getInverseMapping(oldArgIdxToNewArgIdx);
171 SmallVector<Value> newOperandsValues =
172 getMappedElements<Value>(origOperands, newArgIdxToOldArgIdxs);
173 SmallVector<llvm::SmallVector<int>> newResToOldResIdxs =
174 getInverseMapping(oldResIdxToNewResIdx);
175 SmallVector<Type> origResultTypes = llvm::to_vector(callOp.getResultTypes());
176 SmallVector<Type> newResultTypes =
177 getMappedElements<Type>(origResultTypes, newResToOldResIdxs);
178
179 // Replace the kernel call operation with a new one that has the
180 // mapped arguments.
181 rewriter.setInsertionPoint(callOp);
182 auto newCallOp =
183 func::CallOp::create(rewriter, callOp.getLoc(), callOp.getCallee(),
184 newResultTypes, newOperandsValues);
185 newCallOp.setNoInlineAttr(callOp.getNoInlineAttr());
186 for (auto &&[oldResIdx, newResIdx] : llvm::enumerate(oldResIdxToNewResIdx))
187 rewriter.replaceAllUsesWith(callOp.getResult(oldResIdx),
188 newCallOp.getResult(newResIdx));
189 rewriter.eraseOp(callOp);
190
191 return newCallOp;
192}
193
194FailureOr<std::pair<func::FuncOp, func::CallOp>>
195func::deduplicateArgsOfFuncOp(RewriterBase &rewriter, func::FuncOp funcOp,
196 ModuleOp moduleOp) {
198 auto traversalResult = moduleOp.walk([&](func::CallOp callOp) {
199 if (callOp.getCallee() == funcOp.getSymName()) {
200 if (!callOps.empty())
201 // Only support one callOp for now
202 return WalkResult::interrupt();
203 callOps.push_back(callOp);
204 }
205 return WalkResult::advance();
206 });
207
208 if (traversalResult.wasInterrupted()) {
209 LDBG() << "function " << funcOp.getName() << " has more than one callOp";
210 return failure();
211 }
212
213 if (callOps.empty()) {
214 LDBG() << "function " << funcOp.getName() << " does not have any callOp";
215 return failure();
216 }
217
218 func::CallOp callOp = callOps.front();
219
220 // Create mapping for arguments (deduplicate operands)
221 SmallVector<int> oldArgIdxToNewArgIdx(callOp.getNumOperands());
222 llvm::DenseMap<Value, int> valueToNewArgIdx;
223 for (auto [operandIdx, operand] : llvm::enumerate(callOp.getOperands())) {
224 auto [iterator, inserted] = valueToNewArgIdx.insert(
225 {operand, static_cast<int>(valueToNewArgIdx.size())});
226 // Reduce the duplicate operands and maintain the original order.
227 oldArgIdxToNewArgIdx[operandIdx] = iterator->second;
228 }
229
230 bool hasDuplicateOperands =
231 valueToNewArgIdx.size() != callOp.getNumOperands();
232 if (!hasDuplicateOperands) {
233 LDBG() << "function " << funcOp.getName()
234 << " does not have duplicate operands";
235 return failure();
236 }
237
238 // Create identity mapping for results (no deduplication needed)
239 SmallVector<int> oldResIdxToNewResIdx(callOp.getNumResults());
240 for (int resultIdx : llvm::seq<int>(0, callOp.getNumResults()))
241 oldResIdxToNewResIdx[resultIdx] = resultIdx;
242
243 // Apply the transformation to create new function and call operations
244 FailureOr<func::FuncOp> newFuncOpOrFailure = replaceFuncWithNewMapping(
245 rewriter, funcOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
246 if (failed(newFuncOpOrFailure)) {
247 LDBG() << "failed to replace function signature with name "
248 << funcOp.getName() << " with new order";
249 return failure();
250 }
251
252 func::CallOp newCallOp = replaceCallOpWithNewMapping(
253 rewriter, callOp, oldArgIdxToNewArgIdx, oldResIdxToNewResIdx);
254
255 return std::make_pair(*newFuncOpOrFailure, newCallOp);
256}
257
258FailureOr<func::FuncOp>
259func::lookupFnDecl(SymbolOpInterface symTable, StringRef name,
260 FunctionType funcT, SymbolTableCollection *symbolTables) {
261 FuncOp func;
262 if (symbolTables) {
263 func = symbolTables->lookupSymbolIn<FuncOp>(
264 symTable, StringAttr::get(symTable->getContext(), name));
265 } else {
266 func = llvm::dyn_cast_or_null<FuncOp>(
267 SymbolTable::lookupSymbolIn(symTable, name));
268 }
269
270 if (!func)
271 return func;
272
273 mlir::FunctionType foundFuncT = func.getFunctionType();
274 // Assert the signature of the found function is same as expected
275 if (funcT != foundFuncT) {
276 return func.emitError("matched function '")
277 << name << "' but with different type: " << foundFuncT
278 << " (expected " << funcT << ")";
279 }
280 return func;
281}
static llvm::SmallVector< llvm::SmallVector< int > > getInverseMapping(ArrayRef< int > oldIdxToNewIdx)
This method creates an inverse mapping of the provided map oldToNew.
Definition Utils.cpp:38
static SmallVector< Element > getMappedElements(ArrayRef< Element > origElements, const llvm::SmallVector< llvm::SmallVector< int > > &newIdxToOldIdxs)
This method returns a new vector of elements that are mapped from the origElements based on the newId...
Definition Utils.cpp:54
… if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method …
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator begin()
Definition Region.h:55
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents a collection of SymbolTables.
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
mlir::func::CallOp replaceCallOpWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::CallOp callOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new call operation with the values as the original call operation, but with the arguments m...
FailureOr< FuncOp > lookupFnDecl(SymbolOpInterface symTable, StringRef name, FunctionType funcT, SymbolTableCollection *symbolTables=nullptr)
Look up a FuncOp with signature `resultTypes(paramTypes)` and name `name`.
Definition Utils.cpp:259
mlir::FailureOr< mlir::func::FuncOp > replaceFuncWithNewMapping(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, ArrayRef< int > oldArgIdxToNewArgIdx, ArrayRef< int > oldResIdxToNewResIdx)
Creates a new function operation with the same name as the original function operation,...
mlir::FailureOr< std::pair< mlir::func::FuncOp, mlir::func::CallOp > > deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp, mlir::ModuleOp moduleOp)
This utility function examines all call operations within the given moduleOp that target the specifie...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.