MLIR  20.0.0git
FunctionInterfaces.cpp
Go to the documentation of this file.
1 //===- FunctionInterfaces.cpp - Utility types for function-like ops -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // Tablegen Interface Definitions
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Interfaces/FunctionInterfaces.cpp.inc"
18 
19 //===----------------------------------------------------------------------===//
20 // Function Arguments and Results.
21 //===----------------------------------------------------------------------===//
22 
23 static bool isEmptyAttrDict(Attribute attr) {
24  return llvm::cast<DictionaryAttr>(attr).empty();
25 }
26 
27 DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
28  unsigned index) {
29  ArrayAttr attrs = op.getArgAttrsAttr();
30  DictionaryAttr argAttrs =
31  attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
32  return argAttrs;
33 }
34 
35 DictionaryAttr
37  unsigned index) {
38  ArrayAttr attrs = op.getResAttrsAttr();
39  DictionaryAttr resAttrs =
40  attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
41  return resAttrs;
42 }
43 
45 function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
46  auto argDict = getArgAttrDict(op, index);
47  return argDict ? argDict.getValue() : std::nullopt;
48 }
49 
52  unsigned index) {
53  auto resultDict = getResultAttrDict(op, index);
54  return resultDict ? resultDict.getValue() : std::nullopt;
55 }
56 
57 /// Get either the argument or result attributes array.
58 template <bool isArg>
59 static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
60  if constexpr (isArg)
61  return op.getArgAttrsAttr();
62  else
63  return op.getResAttrsAttr();
64 }
65 
66 /// Set either the argument or result attributes array.
67 template <bool isArg>
68 static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
69  if constexpr (isArg)
70  op.setArgAttrsAttr(attrs);
71  else
72  op.setResAttrsAttr(attrs);
73 }
74 
75 /// Erase either the argument or result attributes array.
76 template <bool isArg>
77 static void removeArgResAttrs(FunctionOpInterface op) {
78  if constexpr (isArg)
79  op.removeArgAttrsAttr();
80  else
81  op.removeResAttrsAttr();
82 }
83 
84 /// Set all of the argument or result attribute dictionaries for a function.
85 template <bool isArg>
86 static void setAllArgResAttrDicts(FunctionOpInterface op,
87  ArrayRef<Attribute> attrs) {
88  if (llvm::all_of(attrs, isEmptyAttrDict))
89  removeArgResAttrs<isArg>(op);
90  else
91  setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
92 }
93 
95  FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
96  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
97 }
98 
99 void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
100  ArrayRef<Attribute> attrs) {
101  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
102  return !attr ? DictionaryAttr::get(op->getContext()) : attr;
103  });
104  setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
105 }
106 
108  FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
109  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
110 }
111 
113  ArrayRef<Attribute> attrs) {
114  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
115  return !attr ? DictionaryAttr::get(op->getContext()) : attr;
116  });
117  setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
118 }
119 
120 /// Update the given index into an argument or result attribute dictionary.
121 template <bool isArg>
122 static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
123  unsigned index, DictionaryAttr attrs) {
124  ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
125  if (!allAttrs) {
126  if (attrs.empty())
127  return;
128 
129  // If this attribute is not empty, we need to create a new attribute array.
130  SmallVector<Attribute, 8> newAttrs(numTotalIndices,
132  newAttrs[index] = attrs;
133  setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
134  return;
135  }
136  // Check to see if the attribute is different from what we already have.
137  if (allAttrs[index] == attrs)
138  return;
139 
140  // If it is, check to see if the attribute array would now contain only empty
141  // dictionaries.
142  ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
143  if (attrs.empty() &&
144  llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
145  llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
146  return removeArgResAttrs<isArg>(op);
147 
148  // Otherwise, create a new attribute array with the updated dictionary.
149  SmallVector<Attribute, 8> newAttrs(rawAttrArray);
150  newAttrs[index] = attrs;
151  setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
152 }
153 
154 void function_interface_impl::setArgAttrs(FunctionOpInterface op,
155  unsigned index,
156  ArrayRef<NamedAttribute> attributes) {
157  assert(index < op.getNumArguments() && "invalid argument number");
158  return setArgResAttrDict</*isArg=*/true>(
159  op, op.getNumArguments(), index,
160  DictionaryAttr::get(op->getContext(), attributes));
161 }
162 
163 void function_interface_impl::setArgAttrs(FunctionOpInterface op,
164  unsigned index,
165  DictionaryAttr attributes) {
166  return setArgResAttrDict</*isArg=*/true>(
167  op, op.getNumArguments(), index,
168  attributes ? attributes : DictionaryAttr::get(op->getContext()));
169 }
170 
172  FunctionOpInterface op, unsigned index,
173  ArrayRef<NamedAttribute> attributes) {
174  assert(index < op.getNumResults() && "invalid result number");
175  return setArgResAttrDict</*isArg=*/false>(
176  op, op.getNumResults(), index,
177  DictionaryAttr::get(op->getContext(), attributes));
178 }
179 
180 void function_interface_impl::setResultAttrs(FunctionOpInterface op,
181  unsigned index,
182  DictionaryAttr attributes) {
183  assert(index < op.getNumResults() && "invalid result number");
184  return setArgResAttrDict</*isArg=*/false>(
185  op, op.getNumResults(), index,
186  attributes ? attributes : DictionaryAttr::get(op->getContext()));
187 }
188 
190  FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
192  unsigned originalNumArgs, Type newType) {
193  assert(argIndices.size() == argTypes.size());
194  assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
195  assert(argIndices.size() == argLocs.size());
196  if (argIndices.empty())
197  return;
198 
199  // There are 3 things that need to be updated:
200  // - Function type.
201  // - Arg attrs.
202  // - Block arguments of entry block.
203  Block &entry = op->getRegion(0).front();
204 
205  // Update the argument attributes of the function.
206  ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
207  if (oldArgAttrs || !argAttrs.empty()) {
208  SmallVector<DictionaryAttr, 4> newArgAttrs;
209  newArgAttrs.reserve(originalNumArgs + argIndices.size());
210  unsigned oldIdx = 0;
211  auto migrate = [&](unsigned untilIdx) {
212  if (!oldArgAttrs) {
213  newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
214  } else {
215  auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
216  newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
217  oldArgAttrRange.begin() + untilIdx);
218  }
219  oldIdx = untilIdx;
220  };
221  for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
222  migrate(argIndices[i]);
223  newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
224  }
225  migrate(originalNumArgs);
226  setAllArgAttrDicts(op, newArgAttrs);
227  }
228 
229  // Update the function type and any entry block arguments.
230  op.setFunctionTypeAttr(TypeAttr::get(newType));
231  for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
232  entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
233 }
234 
236  FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
237  TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
238  unsigned originalNumResults, Type newType) {
239  assert(resultIndices.size() == resultTypes.size());
240  assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
241  if (resultIndices.empty())
242  return;
243 
244  // There are 2 things that need to be updated:
245  // - Function type.
246  // - Result attrs.
247 
248  // Update the result attributes of the function.
249  ArrayAttr oldResultAttrs = op.getResAttrsAttr();
250  if (oldResultAttrs || !resultAttrs.empty()) {
251  SmallVector<DictionaryAttr, 4> newResultAttrs;
252  newResultAttrs.reserve(originalNumResults + resultIndices.size());
253  unsigned oldIdx = 0;
254  auto migrate = [&](unsigned untilIdx) {
255  if (!oldResultAttrs) {
256  newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
257  } else {
258  auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
259  newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
260  oldResultAttrsRange.begin() + untilIdx);
261  }
262  oldIdx = untilIdx;
263  };
264  for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
265  migrate(resultIndices[i]);
266  newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
267  : resultAttrs[i]);
268  }
269  migrate(originalNumResults);
270  setAllResultAttrDicts(op, newResultAttrs);
271  }
272 
273  // Update the function type.
274  op.setFunctionTypeAttr(TypeAttr::get(newType));
275 }
276 
278  FunctionOpInterface op, const BitVector &argIndices, Type newType) {
279  // There are 3 things that need to be updated:
280  // - Function type.
281  // - Arg attrs.
282  // - Block arguments of entry block.
283  Block &entry = op->getRegion(0).front();
284 
285  // Update the argument attributes of the function.
286  if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
287  SmallVector<DictionaryAttr, 4> newArgAttrs;
288  newArgAttrs.reserve(argAttrs.size());
289  for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
290  if (!argIndices[i])
291  newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i]));
292  setAllArgAttrDicts(op, newArgAttrs);
293  }
294 
295  // Update the function type and any entry block arguments.
296  op.setFunctionTypeAttr(TypeAttr::get(newType));
297  entry.eraseArguments(argIndices);
298 }
299 
301  FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
302  // There are 2 things that need to be updated:
303  // - Function type.
304  // - Result attrs.
305 
306  // Update the result attributes of the function.
307  if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
308  SmallVector<DictionaryAttr, 4> newResultAttrs;
309  newResultAttrs.reserve(resAttrs.size());
310  for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
311  if (!resultIndices[i])
312  newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i]));
313  setAllResultAttrDicts(op, newResultAttrs);
314  }
315 
316  // Update the function type.
317  op.setFunctionTypeAttr(TypeAttr::get(newType));
318 }
319 
320 //===----------------------------------------------------------------------===//
321 // Function type signature.
322 //===----------------------------------------------------------------------===//
323 
324 void function_interface_impl::setFunctionType(FunctionOpInterface op,
325  Type newType) {
326  unsigned oldNumArgs = op.getNumArguments();
327  unsigned oldNumResults = op.getNumResults();
328  op.setFunctionTypeAttr(TypeAttr::get(newType));
329  unsigned newNumArgs = op.getNumArguments();
330  unsigned newNumResults = op.getNumResults();
331 
332  // Functor used to update the argument and result attributes of the function.
333  auto emptyDict = DictionaryAttr::get(op.getContext());
334  auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
335  constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
336 
337  if (oldCount == newCount)
338  return;
339  // The new type has no arguments/results, just drop the attribute.
340  if (newCount == 0)
341  return removeArgResAttrs<isArgVal>(op);
342  ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
343  if (!attrs)
344  return;
345 
346  // The new type has less arguments/results, take the first N attributes.
347  if (newCount < oldCount)
348  return setAllArgResAttrDicts<isArgVal>(
349  op, attrs.getValue().take_front(newCount));
350 
351  // Otherwise, the new type has more arguments/results. Initialize the new
352  // arguments/results with empty dictionary attributes.
353  SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
354  newAttrs.resize(newCount, emptyDict);
355  setAllArgResAttrDicts<isArgVal>(op, newAttrs);
356  };
357 
358  // Update the argument and result attributes.
359  updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
360  updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
361 }
static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, unsigned index, DictionaryAttr attrs)
Update the given index into an argument or result attribute dictionary.
static void removeArgResAttrs(FunctionOpInterface op)
Erase either the argument or result attributes array.
static bool isEmptyAttrDict(Attribute attr)
static ArrayAttr getArgResAttrs(FunctionOpInterface op)
Get either the argument or result attributes array.
static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs)
Set either the argument or result attributes array.
static void setAllArgResAttrDicts(FunctionOpInterface op, ArrayRef< Attribute > attrs)
Set all of the argument or result attribute dictionaries for a function.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument insertArgument(args_iterator it, Type type, Location loc)
Insert one value to the position in the argument list indicated by the given iterator.
Definition: Block.cpp:186
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:200
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
Block & front()
Definition: Region.h:65
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
ArrayRef< NamedAttribute > getResultAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the result at 'index'.
void setAllResultAttrDicts(FunctionOpInterface op, ArrayRef< DictionaryAttr > attrs)
void insertFunctionArguments(FunctionOpInterface op, ArrayRef< unsigned > argIndices, TypeRange argTypes, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< Location > argLocs, unsigned originalNumArgs, Type newType)
Insert the specified arguments and update the function type attribute.
void setResultAttrs(FunctionOpInterface op, unsigned index, ArrayRef< NamedAttribute > attributes)
Set the attributes held by the result at 'index'.
void eraseFunctionResults(FunctionOpInterface op, const BitVector &resultIndices, Type newType)
Erase the specified results and update the function type attribute.
void setArgAttrs(FunctionOpInterface op, unsigned index, ArrayRef< NamedAttribute > attributes)
Set the attributes held by the argument at 'index'.
ArrayRef< NamedAttribute > getArgAttrs(FunctionOpInterface op, unsigned index)
Return all of the attributes for the argument at 'index'.
void setAllArgAttrDicts(FunctionOpInterface op, ArrayRef< DictionaryAttr > attrs)
Set all of the argument or result attribute dictionaries for a function.
void insertFunctionResults(FunctionOpInterface op, ArrayRef< unsigned > resultIndices, TypeRange resultTypes, ArrayRef< DictionaryAttr > resultAttrs, unsigned originalNumResults, Type newType)
Insert the specified results and update the function type attribute.
DictionaryAttr getResultAttrDict(FunctionOpInterface op, unsigned index)
Returns the dictionary attribute corresponding to the result at 'index'.
DictionaryAttr getArgAttrDict(FunctionOpInterface op, unsigned index)
Returns the dictionary attribute corresponding to the argument at 'index'.
void eraseFunctionArguments(FunctionOpInterface op, const BitVector &argIndices, Type newType)
Erase the specified arguments and update the function type attribute.
void setFunctionType(FunctionOpInterface op, Type newType)
Set a FunctionOpInterface operation's type signature.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...