MLIR 22.0.0git
FunctionInterfaces.cpp
Go to the documentation of this file.
//===- FunctionInterfaces.cpp - Utility types for function-like ops -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/FunctionInterfaces.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//

#include "mlir/Interfaces/FunctionInterfaces.cpp.inc"
18
19//===----------------------------------------------------------------------===//
20// Function Arguments and Results.
21//===----------------------------------------------------------------------===//
22
23static bool isEmptyAttrDict(Attribute attr) {
24 return llvm::cast<DictionaryAttr>(attr).empty();
25}
26
27DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
28 unsigned index) {
29 ArrayAttr attrs = op.getArgAttrsAttr();
30 DictionaryAttr argAttrs =
31 attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
32 return argAttrs;
33}
34
35DictionaryAttr
37 unsigned index) {
38 ArrayAttr attrs = op.getResAttrsAttr();
39 DictionaryAttr resAttrs =
40 attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
41 return resAttrs;
42}
43
45function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
46 auto argDict = getArgAttrDict(op, index);
47 return argDict ? argDict.getValue() : ArrayRef<NamedAttribute>();
48}
49
52 unsigned index) {
53 auto resultDict = getResultAttrDict(op, index);
54 return resultDict ? resultDict.getValue() : ArrayRef<NamedAttribute>();
55}
56
57/// Get either the argument or result attributes array.
58template <bool isArg>
59static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
60 if constexpr (isArg)
61 return op.getArgAttrsAttr();
62 else
63 return op.getResAttrsAttr();
64}
65
66/// Set either the argument or result attributes array.
67template <bool isArg>
68static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
69 if constexpr (isArg)
70 op.setArgAttrsAttr(attrs);
71 else
72 op.setResAttrsAttr(attrs);
73}
74
75/// Erase either the argument or result attributes array.
76template <bool isArg>
77static void removeArgResAttrs(FunctionOpInterface op) {
78 if constexpr (isArg)
79 op.removeArgAttrsAttr();
80 else
81 op.removeResAttrsAttr();
82}
83
84/// Set all of the argument or result attribute dictionaries for a function.
85template <bool isArg>
86static void setAllArgResAttrDicts(FunctionOpInterface op,
87 ArrayRef<Attribute> attrs) {
88 if (llvm::all_of(attrs, isEmptyAttrDict))
90 else
91 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
92}
93
95 FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
96 setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
97}
98
100 ArrayRef<Attribute> attrs) {
101 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
102 return !attr ? DictionaryAttr::get(op->getContext()) : attr;
103 });
104 setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
105}
106
108 FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
109 setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
110}
111
113 ArrayRef<Attribute> attrs) {
114 auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
115 return !attr ? DictionaryAttr::get(op->getContext()) : attr;
116 });
117 setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
118}
119
120/// Update the given index into an argument or result attribute dictionary.
121template <bool isArg>
122static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
123 unsigned index, DictionaryAttr attrs) {
124 ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
125 if (!allAttrs) {
126 if (attrs.empty())
127 return;
128
129 // If this attribute is not empty, we need to create a new attribute array.
130 SmallVector<Attribute, 8> newAttrs(numTotalIndices,
131 DictionaryAttr::get(op->getContext()));
132 newAttrs[index] = attrs;
133 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
134 return;
135 }
136 // Check to see if the attribute is different from what we already have.
137 if (allAttrs[index] == attrs)
138 return;
139
140 // If it is, check to see if the attribute array would now contain only empty
141 // dictionaries.
142 ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
143 if (attrs.empty() &&
144 llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
145 llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
146 return removeArgResAttrs<isArg>(op);
147
148 // Otherwise, create a new attribute array with the updated dictionary.
149 SmallVector<Attribute, 8> newAttrs(rawAttrArray);
150 newAttrs[index] = attrs;
151 setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
152}
153
154void function_interface_impl::setArgAttrs(FunctionOpInterface op,
155 unsigned index,
156 ArrayRef<NamedAttribute> attributes) {
157 assert(index < op.getNumArguments() && "invalid argument number");
158 return setArgResAttrDict</*isArg=*/true>(
159 op, op.getNumArguments(), index,
160 DictionaryAttr::get(op->getContext(), attributes));
161}
162
163void function_interface_impl::setArgAttrs(FunctionOpInterface op,
164 unsigned index,
165 DictionaryAttr attributes) {
166 return setArgResAttrDict</*isArg=*/true>(
167 op, op.getNumArguments(), index,
168 attributes ? attributes : DictionaryAttr::get(op->getContext()));
169}
170
172 FunctionOpInterface op, unsigned index,
173 ArrayRef<NamedAttribute> attributes) {
174 assert(index < op.getNumResults() && "invalid result number");
175 return setArgResAttrDict</*isArg=*/false>(
176 op, op.getNumResults(), index,
177 DictionaryAttr::get(op->getContext(), attributes));
178}
179
180void function_interface_impl::setResultAttrs(FunctionOpInterface op,
181 unsigned index,
182 DictionaryAttr attributes) {
183 assert(index < op.getNumResults() && "invalid result number");
184 return setArgResAttrDict</*isArg=*/false>(
185 op, op.getNumResults(), index,
186 attributes ? attributes : DictionaryAttr::get(op->getContext()));
187}
188
190 FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
192 unsigned originalNumArgs, Type newType) {
193 assert(argIndices.size() == argTypes.size());
194 assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
195 assert(argIndices.size() == argLocs.size());
196 if (argIndices.empty())
197 return;
198
199 // There are 3 things that need to be updated:
200 // - Function type.
201 // - Arg attrs.
202 // - Block arguments of entry block, if not empty.
203
204 // Update the argument attributes of the function.
205 ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
206 if (oldArgAttrs || !argAttrs.empty()) {
208 newArgAttrs.reserve(originalNumArgs + argIndices.size());
209 unsigned oldIdx = 0;
210 auto migrate = [&](unsigned untilIdx) {
211 if (!oldArgAttrs) {
212 newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
213 } else {
214 auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
215 newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
216 oldArgAttrRange.begin() + untilIdx);
217 }
218 oldIdx = untilIdx;
219 };
220 for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
221 migrate(argIndices[i]);
222 newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
223 }
224 migrate(originalNumArgs);
225 setAllArgAttrDicts(op, newArgAttrs);
226 }
227
228 // Update the function type.
229 op.setFunctionTypeAttr(TypeAttr::get(newType));
230
231 // Update entry block arguments, if not empty.
232 if (!op.isExternal()) {
233 Block &entry = op->getRegion(0).front();
234 for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
235 entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
236 }
237}
238
240 FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
241 TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
242 unsigned originalNumResults, Type newType) {
243 assert(resultIndices.size() == resultTypes.size());
244 assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
245 if (resultIndices.empty())
246 return;
247
248 // There are 2 things that need to be updated:
249 // - Function type.
250 // - Result attrs.
251
252 // Update the result attributes of the function.
253 ArrayAttr oldResultAttrs = op.getResAttrsAttr();
254 if (oldResultAttrs || !resultAttrs.empty()) {
255 SmallVector<DictionaryAttr, 4> newResultAttrs;
256 newResultAttrs.reserve(originalNumResults + resultIndices.size());
257 unsigned oldIdx = 0;
258 auto migrate = [&](unsigned untilIdx) {
259 if (!oldResultAttrs) {
260 newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
261 } else {
262 auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
263 newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
264 oldResultAttrsRange.begin() + untilIdx);
265 }
266 oldIdx = untilIdx;
267 };
268 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
269 migrate(resultIndices[i]);
270 newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
271 : resultAttrs[i]);
272 }
273 migrate(originalNumResults);
274 setAllResultAttrDicts(op, newResultAttrs);
275 }
276
277 // Update the function type.
278 op.setFunctionTypeAttr(TypeAttr::get(newType));
279}
280
282 FunctionOpInterface op, const BitVector &argIndices, Type newType) {
283 // There are 3 things that need to be updated:
284 // - Function type.
285 // - Arg attrs.
286 // - Block arguments of entry block, if not empty.
287
288 // Update the argument attributes of the function.
289 if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
291 newArgAttrs.reserve(argAttrs.size());
292 for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
293 if (!argIndices[i])
294 newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i]));
295 setAllArgAttrDicts(op, newArgAttrs);
296 }
297
298 // Update the function type.
299 op.setFunctionTypeAttr(TypeAttr::get(newType));
300
301 // Update entry block arguments, if not empty.
302 if (!op.isExternal()) {
303 Block &entry = op->getRegion(0).front();
304 entry.eraseArguments(argIndices);
305 }
306}
307
309 FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
310 // There are 2 things that need to be updated:
311 // - Function type.
312 // - Result attrs.
313
314 // Update the result attributes of the function.
315 if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
316 SmallVector<DictionaryAttr, 4> newResultAttrs;
317 newResultAttrs.reserve(resAttrs.size());
318 for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
319 if (!resultIndices[i])
320 newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i]));
321 setAllResultAttrDicts(op, newResultAttrs);
322 }
323
324 // Update the function type.
325 op.setFunctionTypeAttr(TypeAttr::get(newType));
326}
327
328//===----------------------------------------------------------------------===//
329// Function type signature.
330//===----------------------------------------------------------------------===//
331
333 Type newType) {
334 unsigned oldNumArgs = op.getNumArguments();
335 unsigned oldNumResults = op.getNumResults();
336 op.setFunctionTypeAttr(TypeAttr::get(newType));
337 unsigned newNumArgs = op.getNumArguments();
338 unsigned newNumResults = op.getNumResults();
339
340 // Functor used to update the argument and result attributes of the function.
341 auto emptyDict = DictionaryAttr::get(op.getContext());
342 auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
343 constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
344
345 if (oldCount == newCount)
346 return;
347 // The new type has no arguments/results, just drop the attribute.
348 if (newCount == 0)
351 if (!attrs)
352 return;
353
354 // The new type has less arguments/results, take the first N attributes.
355 if (newCount < oldCount)
357 op, attrs.getValue().take_front(newCount));
358
359 // Otherwise, the new type has more arguments/results. Initialize the new
360 // arguments/results with empty dictionary attributes.
361 SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
362 newAttrs.resize(newCount, emptyDict);
364 };
365
366 // Update the argument and result attributes.
367 updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
368 updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
369}
static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, unsigned index, DictionaryAttr attrs)
Update the given index into an argument or result attribute dictionary.
static void removeArgResAttrs(FunctionOpInterface op)
Erase either the argument or result attributes array.
static bool isEmptyAttrDict(Attribute attr)
static ArrayAttr getArgResAttrs(FunctionOpInterface op)
Get either the argument or result attributes array.
static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs)
Set either the argument or result attributes array.
static void setAllArgResAttrDicts(FunctionOpInterface op, ArrayRef< Attribute > attrs)
Set all of the argument or result attribute dictionaries for a function.
ArrayAttr()
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument insertArgument(args_iterator it, Type type, Location loc)
Insert one value to the position in the argument list indicated by the given iterator.
Definition Block.cpp:187
Operation & front()
Definition Block.h:153
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:201
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
ArrayRef< NamedAttribute > getResultAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the result at 'index'.
void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef< DictionaryAttr > attrs)
void insertFunctionArguments(FunctionOpInterface op, ArrayRef< unsigned > argIndices, TypeRange argTypes, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< Location > argLocs, unsigned originalNumArgs, Type newType)
Insert the specified arguments and update the function type attribute.
void setResultAttrs(FunctionOpInterface op, unsigned index, ArrayRef< NamedAttribute > attributes)
Set the attributes held by the result at 'index'.
void eraseFunctionResults(FunctionOpInterface op, const BitVector &resultIndices, Type newType)
Erase the specified results and update the function type attribute.
void setArgAttrs(FunctionOpInterface op, unsigned index, ArrayRef< NamedAttribute > attributes)
Set the attributes held by the argument at 'index'.
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef< DictionaryAttr > attrs)
Set all of the argument or result attribute dictionaries for a function.
void insertFunctionResults(FunctionOpInterface op, ArrayRef< unsigned > resultIndices, TypeRange resultTypes, ArrayRef< DictionaryAttr > resultAttrs, unsigned originalNumResults, Type newType)
Insert the specified results and update the function type attribute.
DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index)
Returns the dictionary attribute corresponding to the result at 'index'.
DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index)
Returns the dictionary attribute corresponding to the argument at 'index'.
void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices, Type newType)
Erase the specified arguments and update the function type attribute.
void setFunctionType(FunctionOpInterface op, Type newType)
Set a FunctionOpInterface operation's type signature.
Include the generated interface declarations.