MLIR  16.0.0git
FunctionInterfaces.cpp
Go to the documentation of this file.
1 //===- FunctionInterfaces.cpp - Function Interfaces -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 //===----------------------------------------------------------------------===//
14 // Tablegen Interface Definitions
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/IR/FunctionOpInterfaces.cpp.inc"
18 
19 //===----------------------------------------------------------------------===//
20 // Function Arguments and Results.
21 //===----------------------------------------------------------------------===//
22 
23 static bool isEmptyAttrDict(Attribute attr) {
24  return attr.cast<DictionaryAttr>().empty();
25 }
26 
28  unsigned index) {
29  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
30  DictionaryAttr argAttrs =
31  attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
32  return argAttrs;
33 }
34 
35 DictionaryAttr
37  unsigned index) {
38  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
39  DictionaryAttr resAttrs =
40  attrs ? attrs[index].cast<DictionaryAttr>() : DictionaryAttr();
41  return resAttrs;
42 }
43 
45  Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index,
46  DictionaryAttr attrs) {
47  ArrayAttr allAttrs = op->getAttrOfType<ArrayAttr>(attrName);
48  if (!allAttrs) {
49  if (attrs.empty())
50  return;
51 
52  // If this attribute is not empty, we need to create a new attribute array.
53  SmallVector<Attribute, 8> newAttrs(numTotalIndices,
54  DictionaryAttr::get(op->getContext()));
55  newAttrs[index] = attrs;
56  op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
57  return;
58  }
59  // Check to see if the attribute is different from what we already have.
60  if (allAttrs[index] == attrs)
61  return;
62 
63  // If it is, check to see if the attribute array would now contain only empty
64  // dictionaries.
65  ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
66  if (attrs.empty() &&
67  llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
68  llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict)) {
69  op->removeAttr(attrName);
70  return;
71  }
72 
73  // Otherwise, create a new attribute array with the updated dictionary.
74  SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
75  newAttrs[index] = attrs;
76  op->setAttr(attrName, ArrayAttr::get(op->getContext(), newAttrs));
77 }
78 
79 /// Set all of the argument or result attribute dictionaries for a function.
80 static void setAllArgResAttrDicts(Operation *op, StringRef attrName,
81  ArrayRef<Attribute> attrs) {
82  if (llvm::all_of(attrs, isEmptyAttrDict))
83  op->removeAttr(attrName);
84  else
85  op->setAttr(attrName, ArrayAttr::get(op->getContext(), attrs));
86 }
87 
90  setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
91 }
93  Operation *op, ArrayRef<Attribute> attrs) {
94  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
95  return !attr ? DictionaryAttr::get(op->getContext()) : attr;
96  });
98  llvm::to_vector<8>(wrappedAttrs));
99 }
100 
102  Operation *op, ArrayRef<DictionaryAttr> attrs) {
103  setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
104 }
106  Operation *op, ArrayRef<Attribute> attrs) {
107  auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
108  return !attr ? DictionaryAttr::get(op->getContext()) : attr;
109  });
111  llvm::to_vector<8>(wrappedAttrs));
112 }
113 
115  Operation *op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
117  unsigned originalNumArgs, Type newType) {
118  assert(argIndices.size() == argTypes.size());
119  assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
120  assert(argIndices.size() == argLocs.size());
121  if (argIndices.empty())
122  return;
123 
124  // There are 3 things that need to be updated:
125  // - Function type.
126  // - Arg attrs.
127  // - Block arguments of entry block.
128  Block &entry = op->getRegion(0).front();
129 
130  // Update the argument attributes of the function.
131  auto oldArgAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName());
132  if (oldArgAttrs || !argAttrs.empty()) {
133  SmallVector<DictionaryAttr, 4> newArgAttrs;
134  newArgAttrs.reserve(originalNumArgs + argIndices.size());
135  unsigned oldIdx = 0;
136  auto migrate = [&](unsigned untilIdx) {
137  if (!oldArgAttrs) {
138  newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
139  } else {
140  auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
141  newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
142  oldArgAttrRange.begin() + untilIdx);
143  }
144  oldIdx = untilIdx;
145  };
146  for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
147  migrate(argIndices[i]);
148  newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
149  }
150  migrate(originalNumArgs);
151  setAllArgAttrDicts(op, newArgAttrs);
152  }
153 
154  // Update the function type and any entry block arguments.
155  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
156  for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
157  entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
158 }
159 
161  Operation *op, ArrayRef<unsigned> resultIndices, TypeRange resultTypes,
162  ArrayRef<DictionaryAttr> resultAttrs, unsigned originalNumResults,
163  Type newType) {
164  assert(resultIndices.size() == resultTypes.size());
165  assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
166  if (resultIndices.empty())
167  return;
168 
169  // There are 2 things that need to be updated:
170  // - Function type.
171  // - Result attrs.
172 
173  // Update the result attributes of the function.
174  auto oldResultAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName());
175  if (oldResultAttrs || !resultAttrs.empty()) {
176  SmallVector<DictionaryAttr, 4> newResultAttrs;
177  newResultAttrs.reserve(originalNumResults + resultIndices.size());
178  unsigned oldIdx = 0;
179  auto migrate = [&](unsigned untilIdx) {
180  if (!oldResultAttrs) {
181  newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
182  } else {
183  auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
184  newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
185  oldResultAttrsRange.begin() + untilIdx);
186  }
187  oldIdx = untilIdx;
188  };
189  for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
190  migrate(resultIndices[i]);
191  newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
192  : resultAttrs[i]);
193  }
194  migrate(originalNumResults);
195  setAllResultAttrDicts(op, newResultAttrs);
196  }
197 
198  // Update the function type.
199  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
200 }
201 
203  Operation *op, const BitVector &argIndices, Type newType) {
204  // There are 3 things that need to be updated:
205  // - Function type.
206  // - Arg attrs.
207  // - Block arguments of entry block.
208  Block &entry = op->getRegion(0).front();
209 
210  // Update the argument attributes of the function.
211  if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
212  SmallVector<DictionaryAttr, 4> newArgAttrs;
213  newArgAttrs.reserve(argAttrs.size());
214  for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
215  if (!argIndices[i])
216  newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
217  setAllArgAttrDicts(op, newArgAttrs);
218  }
219 
220  // Update the function type and any entry block arguments.
221  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
222  entry.eraseArguments(argIndices);
223 }
224 
226  Operation *op, const BitVector &resultIndices, Type newType) {
227  // There are 2 things that need to be updated:
228  // - Function type.
229  // - Result attrs.
230 
231  // Update the result attributes of the function.
232  if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
233  SmallVector<DictionaryAttr, 4> newResultAttrs;
234  newResultAttrs.reserve(resAttrs.size());
235  for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
236  if (!resultIndices[i])
237  newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
238  setAllResultAttrDicts(op, newResultAttrs);
239  }
240 
241  // Update the function type.
242  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
243 }
244 
246  TypeRange oldTypes, ArrayRef<unsigned> indices, TypeRange newTypes,
247  SmallVectorImpl<Type> &storage) {
248  assert(indices.size() == newTypes.size() &&
249  "mismatch between indice and type count");
250  if (indices.empty())
251  return oldTypes;
252 
253  auto fromIt = oldTypes.begin();
254  for (auto it : llvm::zip(indices, newTypes)) {
255  const auto toIt = oldTypes.begin() + std::get<0>(it);
256  storage.append(fromIt, toIt);
257  storage.push_back(std::get<1>(it));
258  fromIt = toIt;
259  }
260  storage.append(fromIt, oldTypes.end());
261  return storage;
262 }
263 
265  TypeRange types, const BitVector &indices, SmallVectorImpl<Type> &storage) {
266  if (indices.none())
267  return types;
268 
269  for (unsigned i = 0, e = types.size(); i < e; ++i)
270  if (!indices[i])
271  storage.emplace_back(types[i]);
272  return storage;
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // Function type signature.
277 //===----------------------------------------------------------------------===//
278 
280  Type newType) {
281  FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
282  unsigned oldNumArgs = funcOp.getNumArguments();
283  unsigned oldNumResults = funcOp.getNumResults();
284  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
285  unsigned newNumArgs = funcOp.getNumArguments();
286  unsigned newNumResults = funcOp.getNumResults();
287 
288  // Functor used to update the argument and result attributes of the function.
289  auto updateAttrFn = [&](StringRef attrName, unsigned oldCount,
290  unsigned newCount, auto setAttrFn) {
291  if (oldCount == newCount)
292  return;
293  // The new type has no arguments/results, just drop the attribute.
294  if (newCount == 0) {
295  op->removeAttr(attrName);
296  return;
297  }
298  ArrayAttr attrs = op->getAttrOfType<ArrayAttr>(attrName);
299  if (!attrs)
300  return;
301 
302  // The new type has less arguments/results, take the first N attributes.
303  if (newCount < oldCount)
304  return setAttrFn(op, attrs.getValue().take_front(newCount));
305 
306  // Otherwise, the new type has more arguments/results. Initialize the new
307  // arguments/results with empty attributes.
308  SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
309  newAttrs.resize(newCount);
310  setAttrFn(op, newAttrs);
311  };
312 
313  // Update the argument and result attributes.
314  updateAttrFn(
315  getArgDictAttrName(), oldNumArgs, newNumArgs,
316  [&](Operation *op, auto &&attrs) { setAllArgAttrDicts(op, attrs); });
317  updateAttrFn(
318  getResultDictAttrName(), oldNumResults, newNumResults,
319  [&](Operation *op, auto &&attrs) { setAllResultAttrDicts(op, attrs); });
320 }
static void setAllArgResAttrDicts(Operation *op, StringRef attrName, ArrayRef< Attribute > attrs)
Set all of the argument or result attribute dictionaries for a function.
static bool isEmptyAttrDict(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U cast() const
Definition: Attributes.h:137
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument insertArgument(args_iterator it, Type type, Location loc)
Insert one value to the position in the argument list indicated by the given iterator.
Definition: Block.cpp:175
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:189
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:375
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:486
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:395
Attribute removeAttr(StringAttr name)
Remove the attribute with the specified name if it exists.
Definition: Operation.h:407
Block & front()
Definition: Region.h:65
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
void setArgResAttrDict(Operation *op, StringRef attrName, unsigned numTotalIndices, unsigned index, DictionaryAttr attrs)
Update the given index into an argument or result attribute dictionary.
StringRef getArgDictAttrName()
Return the name of the attribute used for function argument attributes.
void eraseFunctionResults(Operation *op, const BitVector &resultIndices, Type newType)
Erase the specified results and update the function type attribute.
void insertFunctionResults(Operation *op, ArrayRef< unsigned > resultIndices, TypeRange resultTypes, ArrayRef< DictionaryAttr > resultAttrs, unsigned originalNumResults, Type newType)
Insert the specified results and update the function type attribute.
DictionaryAttr getResultAttrDict(Operation *op, unsigned index)
Returns the dictionary attribute corresponding to the result at 'index'.
void insertFunctionArguments(Operation *op, ArrayRef< unsigned > argIndices, TypeRange argTypes, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< Location > argLocs, unsigned originalNumArgs, Type newType)
Insert the specified arguments and update the function type attribute.
TypeRange insertTypesInto(TypeRange oldTypes, ArrayRef< unsigned > indices, TypeRange newTypes, SmallVectorImpl< Type > &storage)
Insert a set of newTypes into oldTypes at the given indices.
StringRef getTypeAttrName()
Return the name of the attribute used for function types.
DictionaryAttr getArgAttrDict(Operation *op, unsigned index)
Returns the dictionary attribute corresponding to the argument at 'index'.
void setAllArgAttrDicts(Operation *op, ArrayRef< DictionaryAttr > attrs)
Set all of the argument attribute dictionaries for a function.
void eraseFunctionArguments(Operation *op, const BitVector &argIndices, Type newType)
Erase the specified arguments and update the function type attribute.
StringRef getResultDictAttrName()
Return the name of the attribute used for function result attributes.
void setAllResultAttrDicts(Operation *op, ArrayRef< DictionaryAttr > attrs)
TypeRange filterTypesOut(TypeRange types, const BitVector &indices, SmallVectorImpl< Type > &storage)
Filters out any elements referenced by indices.
void setFunctionType(Operation *op, Type newType)
Set a FunctionOpInterface operation's type signature.
Include the generated interface declarations.