MLIR  16.0.0git
FunctionInterfaces.cpp
Go to the documentation of this file.
1 //===- FunctionInterfaces.cpp - Function Interfaces -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // Tablegen Interface Definitions
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/IR/FunctionOpInterfaces.cpp.inc"
18 
19 //===----------------------------------------------------------------------===//
20 // Function Arguments and Results.
21 //===----------------------------------------------------------------------===//
22 
23 static bool isEmptyAttrDict(Attribute attr) {
24  return attr.cast<DictionaryAttr>().empty();
25 }
26 
28  unsigned index) {
  // All argument dictionaries live in one ArrayAttr named by
  // getArgDictAttrName(). If the op has no such array, a null DictionaryAttr
  // is returned; otherwise entry `index` is returned. `index` is not
  // range-checked in this function.
29  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
30  DictionaryAttr argAttrs =
31  attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
32  return argAttrs;
33 }
34 
/// Returns the attribute dictionary of result `index`, taken from the
/// ArrayAttr named by getResultDictAttrName(). A null DictionaryAttr is
/// returned when the op carries no such array; `index` is not range-checked
/// here.
35 DictionaryAttr
37  unsigned index) {
38  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
39  DictionaryAttr resAttrs =
40  attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
41  return resAttrs;
42 }
43 
45  Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
46  DictionaryAttr attrs) {
  // Write `attrs` into slot `index` of the ArrayAttr named `attrName`. The
  // behavior depends on the array's current state:
  //  - If it does not exist, it is created on demand: `numTotalIndices`
  //    entries, all empty dictionaries except slot `index`.
  //  - If the slot already equals `attrs`, the array is left untouched.
  //  - If the update would leave only empty dictionaries, the array is
  //    removed entirely.
  // NOTE(review): attrs.empty() is called without a null check, so `attrs`
  // is presumably required to be non-null.
47  ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
48  if (!allAttrs) {
  // No array yet: an empty dictionary needs no storage at all.
49  if (attrs.empty())
50  return;
51 
52  // If this attribute is not empty, we need to create a new attribute array.
53  SmallVector<Attribute, 8> newAttrs(numTotalIndices,
54  DictionaryAttr::get(op->getContext()));
55  newAttrs[index] = attrs;
56  op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
57  return;
58  }
59  // Check to see if the attribute is different from what we already have.
60  if (allAttrs[index] == attrs)
61  return;
62 
63  // If it is, check to see if the attribute array would now contain only empty
64  // dictionaries.
  // Slot `index` is skipped here because it is about to be replaced by
  // `attrs`, whose emptiness is tested directly.
65  ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
66  if (attrs.empty() &&
67  llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
68  llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
69  op->removeAttr(attrName);
70  return;
71  }
72 
73  // Otherwise, create a new attribute array with the updated dictionary.
74  SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
75  newAttrs[index] = attrs;
76  op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
77 }
78 
79 /// Set all of the argument or result attribute dictionaries for a function.
80 static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
81  ArrayRef<Attribute> attrs) {
82  if (llvm::all_of(attrs, isEmptyAttrDict))
83  op->removeAttr(attrName);
84  else
85  op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
86 }
87 
// DictionaryAttr overload: view the same elements as a generic Attribute
// array and forward to the ArrayRef<Attribute> overload below.
90  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
91 }
93  Operation *op, ArrayRef<Attribute> attrs) {
  // Null entries are replaced by empty dictionaries, so every element handed
  // on below is a DictionaryAttr.
94  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
95  return !attr ? DictionaryAttr::get(op->getContext()) : attr;
96  });
98  llvm::to_vector<8>(wrappedAttrs));
99 }
100 
102  Operation *op, ArrayRef<DictionaryAttr> attrs) {
  // DictionaryAttr overload: view the same elements as a generic Attribute
  // array and forward to the ArrayRef<Attribute> overload.
103  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
104 }
106  Operation *op, ArrayRef<Attribute> attrs) {
  // Null entries are replaced by empty dictionaries, so every element handed
  // on below is a DictionaryAttr.
107  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
108  return !attr ? DictionaryAttr::get(op->getContext()) : attr;
109  });
111  llvm::to_vector<8>(wrappedAttrs));
112 }
113 
115  Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
117  unsigned originalNumArgs, Type newType) {
  // Insert argTypes[i] (with attributes argAttrs[i] and location argLocs[i])
  // before position argIndices[i] of the ORIGINAL argument list. `argAttrs`
  // may be empty, meaning the new arguments get no attributes. `newType`
  // becomes the function type.
  // NOTE(review): the migrate loop and the `argIndices[i] + i` offset below
  // assume argIndices is sorted in non-decreasing order; confirm with
  // callers.
118  assert(argIndices.size() == argTypes.size());
119  assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
120  assert(argIndices.size() == argLocs.size());
121  if (argIndices.empty())
122  return;
123 
124  // There are 3 things that need to be updated:
125  // - Function type.
126  // - Arg attrs.
127  // - Block arguments of entry block.
128  Block &entry = op->getRegion(0).front();
129 
130  // Update the argument attributes of the function.
  // Skipped entirely when there are neither old nor new attributes.
131  auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
132  if (oldArgAttrs || !argAttrs.empty()) {
133  SmallVector<DictionaryAttr, 4> newArgAttrs;
134  newArgAttrs.reserve(originalNumArgs + argIndices.size());
135  unsigned oldIdx = 0;
  // Copy old entries [oldIdx, untilIdx) into newArgAttrs. When the op had no
  // attribute array, null DictionaryAttr placeholders are appended instead.
136  auto migrate = [&](unsigned untilIdx) {
137  if (!oldArgAttrs) {
138  newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
139  } else {
140  auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
141  newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
142  oldArgAttrRange.begin() + untilIdx);
143  }
144  oldIdx = untilIdx;
145  };
146  for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
147  migrate(argIndices[i]);
148  newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
149  }
  // Copy whatever old entries remain after the last insertion point.
150  migrate(originalNumArgs);
151  setAllArgAttrDicts(op, newArgAttrs);
152  }
153 
154  // Update the function type and any entry block arguments.
155  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
  // Adding `i` shifts each original index past the i arguments already
  // inserted ahead of it.
156  for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
157  entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
158 }
159 
161  Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
162  ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
163  Type newType) {
  // Result counterpart of argument insertion: record resultAttrs[i] for a
  // new result placed before position resultIndices[i] of the ORIGINAL
  // result list, then set the function type to `newType`. `resultAttrs` may
  // be empty, meaning the new results get no attributes.
  // NOTE(review): the migrate loop assumes resultIndices is sorted in
  // non-decreasing order; confirm with callers.
164  assert(resultIndices.size() == resultTypes.size());
165  assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
166  if (resultIndices.empty())
167  return;
168 
169  // There are 2 things that need to be updated:
170  // - Function type.
171  // - Result attrs.
172 
173  // Update the result attributes of the function.
  // Skipped entirely when there are neither old nor new attributes.
174  auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
175  if (oldResultAttrs || !resultAttrs.empty()) {
176  SmallVector<DictionaryAttr, 4> newResultAttrs;
177  newResultAttrs.reserve(originalNumResults + resultIndices.size());
178  unsigned oldIdx = 0;
  // Copy old entries [oldIdx, untilIdx) into newResultAttrs. When the op had
  // no attribute array, null DictionaryAttr placeholders are appended.
179  auto migrate = [&](unsigned untilIdx) {
180  if (!oldResultAttrs) {
181  newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
182  } else {
183  auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
184  newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
185  oldResultAttrsRange.begin() + untilIdx);
186  }
187  oldIdx = untilIdx;
188  };
189  for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
190  migrate(resultIndices[i]);
191  newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
192  : resultAttrs[i]);
193  }
  // Copy whatever old entries remain after the last insertion point.
194  migrate(originalNumResults);
195  setAllResultAttrDicts(op, newResultAttrs);
196  }
197 
198  // Update the function type.
199  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
200 }
201 
203  Operation *op, const BitVector &argIndices, Type newType) {
  // Erase every argument whose bit is set in `argIndices`. Dictionaries of
  // the surviving arguments keep their relative order, and `newType` becomes
  // the function type.
  // NOTE(review): argAttrs is indexed by bit position, so argIndices
  // presumably has exactly one bit per existing argument.
204  // There are 3 things that need to be updated:
205  // - Function type.
206  // - Arg attrs.
207  // - Block arguments of entry block.
208  Block &entry = op->getRegion(0).front();
209 
210  // Update the argument attributes of the function.
211  if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
212  SmallVector<DictionaryAttr, 4> newArgAttrs;
213  newArgAttrs.reserve(argAttrs.size());
  // Keep only the dictionaries of arguments that are not being erased.
214  for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
215  if (!argIndices[i])
216  newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
217  setAllArgAttrDicts(op, newArgAttrs);
218  }
219 
220  // Update the function type and any entry block arguments.
221  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
222  entry.eraseArguments(argIndices);
223 }
224 
226  Operation *op, const BitVector &resultIndices, Type newType) {
  // Drop the attribute dictionaries of every result whose bit is set in
  // `resultIndices` (survivors keep their relative order), then set the
  // function type to `newType`.
  // NOTE(review): resAttrs is indexed by bit position, so resultIndices
  // presumably has exactly one bit per existing result.
227  // There are 2 things that need to be updated:
228  // - Function type.
229  // - Result attrs.
230 
231  // Update the result attributes of the function.
232  if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
233  SmallVector<DictionaryAttr, 4> newResultAttrs;
234  newResultAttrs.reserve(resAttrs.size());
  // Keep only the dictionaries of results that are not being erased.
235  for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
236  if (!resultIndices[i])
237  newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
238  setAllResultAttrDicts(op, newResultAttrs);
239  }
240 
241  // Update the function type.
242  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
243 }
244 
246  TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes,
247  SmallVectorImpl<Type> &storage) {
  // Place newTypes[i] immediately before oldTypes[indices[i]], where each
  // index refers to the ORIGINAL list. If `indices` is empty, `oldTypes` is
  // returned unchanged. Otherwise the merged sequence is appended to
  // `storage` (which is not cleared first), and the returned range refers to
  // `storage`, so `storage` must outlive it.
  // NOTE(review): the iterator arithmetic below requires indices to be
  // non-decreasing and no larger than oldTypes.size().
248  assert(indices.size() == newTypes.size() &&
249  "mismatch between indice and type count");
250  if (indices.empty())
251  return oldTypes;
252 
253  auto fromIt = oldTypes.begin();
254  for (auto it : llvm::zip(indices, newTypes)) {
255  const auto toIt = oldTypes.begin() + std::get<0>(it);
  // Copy the untouched old types up to the insertion point, then the new one.
256  storage.append(fromIt, toIt);
257  storage.push_back(std::get<1>(it));
258  fromIt = toIt;
259  }
260  storage.append(fromIt, oldTypes.end());
261  return storage;
262 }
263 
265  TypeRange types, const BitVector &indices, SmallVectorImpl<Type> &storage) {
  // Drop every type whose bit is set in `indices`. If no bit is set, `types`
  // is returned unchanged. Otherwise the kept types are appended to `storage`
  // (not cleared first), and the returned range refers to `storage`.
  // NOTE(review): indices[i] is read for every position of `types`, so
  // `indices` presumably has at least types.size() bits.
266  if (indices.none())
267  return types;
268 
269  for (unsigned i = 0, e = types.size(); i < e; ++i)
270  if (!indices[i])
271  storage.emplace_back(types[i]);
272  return storage;
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // Function type signature.
277 //===----------------------------------------------------------------------===//
278 
280  Type newType) {
  // Replace the function type with `newType`, then resize the argument and
  // result attribute arrays to the new counts. Counts are read through
  // FunctionOpInterface before and after the type attribute changes.
281  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
282  unsigned oldNumArgs = funcOp.getNumArguments();
283  unsigned oldNumResults = funcOp.getNumResults();
284  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
285  unsigned newNumArgs = funcOp.getNumArguments();
286  unsigned newNumResults = funcOp.getNumResults();
287 
288  // Functor used to update the argument and result attributes of the function.
  // `setAttrFn` stores the resized array; it is a generic lambda so that it
  // accepts both the ArrayRef slice and the SmallVector built below.
289  auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
290  unsigned newCount, auto setAttrFn) {
291  if (oldCount == newCount)
292  return;
293  // The new type has no arguments/results, just drop the attribute.
294  if (newCount == 0) {
295  op->removeAttr(attrName);
296  return;
297  }
  // No existing attribute array means there is nothing to resize.
298  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
299  if (!attrs)
300  return;
301 
302  // The new type has fewer arguments/results, take the first N attributes.
303  if (newCount < oldCount)
304  return setAttrFn(op, attrs.getValue().take_front(newCount));
305 
306  // Otherwise, the new type has more arguments/results. Initialize the new
307  // arguments/results with empty attributes.
  // resize() appends null Attributes; the ArrayRef<Attribute> setters
  // replace null entries with empty dictionaries.
308  SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
309  newAttrs.resize(newCount);
310  setAttrFn(op, newAttrs);
311  };
312 
313  // Update the argument and result attributes.
314  updateAttrFn(
315  getArgDictAttrName(), oldNumArgs, newNumArgs,
316  [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); });
317  updateAttrFn(
318  getResultDictAttrName(), oldNumResults, newNumResults,
319  [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
320 }
Include the generated interface declarations.
StringRef getResultDictAttrName()
Return the name of the attribute used for function result attributes.
U cast() const
Definition: Attributes.h:135
DictionaryAttr getResultAttrDict(Operation *op, unsigned index)
Returns the dictionary attribute corresponding to the result at 'index'.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
StringRef getArgDictAttrName()
Return the name of the attribute used for function argument attributes.
Block represents an ordered list of Operations.
Definition: Block.h:29
Block & front()
Definition: Region.h:65
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:375
BlockArgument insertArgument(args_iterator it, Type type, Location loc)
Insert one value to the position in the argument list indicated by the given iterator.
Definition: Block.cpp:175
void setAllArgAttrDicts(Operation *op, ArrayRef< DictionaryAttr > attrs)
Set all of the argument or result attribute dictionaries for a function.
void insertFunctionResults(Operation *op, ArrayRef< unsigned > resultIndices, TypeRange resultTypes, ArrayRef< DictionaryAttr > resultAttrs, unsigned originalNumResults, Type newType)
Insert the specified results and update the function type attribute.
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
void eraseFunctionResults(Operation *op, const BitVector &resultIndices, Type newType)
Erase the specified results and update the function type attribute.
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
DictionaryAttr getArgAttrDict(Operation *op, unsigned index)
Returns the dictionary attribute corresponding to the argument at 'index'.
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
void eraseArguments(ArrayRef< unsigned > argIndices)
Erases the arguments listed in argIndices and removes them from the argument list.
Definition: Block.cpp:189
void setArgResAttrDict(Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, DictionaryAttr attrs)
Update the given index into an argument or result attribute dictionary.
Attributes are known-constant values of operations.
Definition: Attributes.h:24
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:407
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
void insertFunctionArguments(Operation *op, ArrayRef< unsigned > argIndices, TypeRange argTypes, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< Location > argLocs, unsigned originalNumArgs, Type newType)
Insert the specified arguments and update the function type attribute.
void setFunctionType(Operation *op, Type newType)
Set a FunctionOpInterface operation's type signature.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:395
void eraseFunctionArguments(Operation *op, const BitVector &argIndices, Type newType)
Erase the specified arguments and update the function type attribute.
void setAllResultAttrDicts(Operation *op, ArrayRef< DictionaryAttr > attrs)
static bool isEmptyAttrDict(Attribute attr)
static void setAllArgResAttrDicts(Operation *op, StringRef attrName, ArrayRef< Attribute > attrs)
Set all of the argument or result attribute dictionaries for a function.
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486