MLIR  20.0.0git
AttrTypeSubElements.cpp
Go to the documentation of this file.
1 //===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Operation.h"
10 #include <optional>
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // AttrTypeWalker
16 //===----------------------------------------------------------------------===//
17 
18 WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) {
19  return walkImpl(attr, attrWalkFns, order);
20 }
21 WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) {
22  return walkImpl(type, typeWalkFns, order);
23 }
24 
25 template <typename T, typename WalkFns>
26 WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
27  WalkOrder order) {
28  // Check if we've already walk this element before.
29  auto key = std::make_pair(element.getAsOpaquePointer(), (int)order);
30  auto [it, inserted] =
31  visitedAttrTypes.try_emplace(key, WalkResult::advance());
32  if (!inserted)
33  return it->second;
34 
35  // If we are walking in post order, walk the sub elements first.
36  if (order == WalkOrder::PostOrder) {
37  if (walkSubElements(element, order).wasInterrupted())
38  return visitedAttrTypes[key] = WalkResult::interrupt();
39  }
40 
41  // Walk this element, bailing if skipped or interrupted.
42  for (auto &walkFn : llvm::reverse(walkFns)) {
43  WalkResult walkResult = walkFn(element);
44  if (walkResult.wasInterrupted())
45  return visitedAttrTypes[key] = WalkResult::interrupt();
46  if (walkResult.wasSkipped())
47  return WalkResult::advance();
48  }
49 
50  // If we are walking in pre-order, walk the sub elements last.
51  if (order == WalkOrder::PreOrder) {
52  if (walkSubElements(element, order).wasInterrupted())
53  return WalkResult::interrupt();
54  }
55  return WalkResult::advance();
56 }
57 
58 template <typename T>
59 WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
61  auto walkFn = [&](auto element) {
62  if (element && !result.wasInterrupted())
63  result = walkImpl(element, order);
64  };
65  interface.walkImmediateSubElements(walkFn, walkFn);
66  return result.wasInterrupted() ? result : WalkResult::advance();
67 }
68 
69 //===----------------------------------------------------------------------===//
// AttrTypeReplacerBase
71 //===----------------------------------------------------------------------===//
72 
73 template <typename Concrete>
76  attrReplacementFns.emplace_back(std::move(fn));
77 }
78 
79 template <typename Concrete>
81  ReplaceFn<Type> fn) {
82  typeReplacementFns.push_back(std::move(fn));
83 }
84 
85 template <typename Concrete>
87  Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
88  // Functor that replaces the given element if the new value is different,
89  // otherwise returns nullptr.
90  auto replaceIfDifferent = [&](auto element) {
91  auto replacement = static_cast<Concrete *>(this)->replace(element);
92  return (replacement && replacement != element) ? replacement : nullptr;
93  };
94 
95  // Update the attribute dictionary.
96  if (replaceAttrs) {
97  if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary()))
98  op->setAttrs(cast<DictionaryAttr>(newAttrs));
99  }
100 
101  // If we aren't updating locations or types, we're done.
102  if (!replaceTypes && !replaceLocs)
103  return;
104 
105  // Update the location.
106  if (replaceLocs) {
107  if (Attribute newLoc = replaceIfDifferent(op->getLoc()))
108  op->setLoc(cast<LocationAttr>(newLoc));
109  }
110 
111  // Update the result types.
112  if (replaceTypes) {
113  for (OpResult result : op->getResults())
114  if (Type newType = replaceIfDifferent(result.getType()))
115  result.setType(newType);
116  }
117 
118  // Update any nested block arguments.
119  for (Region &region : op->getRegions()) {
120  for (Block &block : region) {
121  for (BlockArgument &arg : block.getArguments()) {
122  if (replaceLocs) {
123  if (Attribute newLoc = replaceIfDifferent(arg.getLoc()))
124  arg.setLoc(cast<LocationAttr>(newLoc));
125  }
126 
127  if (replaceTypes) {
128  if (Type newType = replaceIfDifferent(arg.getType()))
129  arg.setType(newType);
130  }
131  }
132  }
133  }
134 }
135 
136 template <typename Concrete>
138  Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
139  op->walk([&](Operation *nestedOp) {
140  replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
141  });
142 }
143 
144 template <typename T, typename Replacer>
145 static void updateSubElementImpl(T element, Replacer &replacer,
146  SmallVectorImpl<T> &newElements,
147  FailureOr<bool> &changed) {
148  // Bail early if we failed at any point.
149  if (failed(changed))
150  return;
151 
152  // Guard against potentially null inputs. We always map null to null.
153  if (!element) {
154  newElements.push_back(nullptr);
155  return;
156  }
157 
158  // Replace the element.
159  if (T result = replacer.replace(element)) {
160  newElements.push_back(result);
161  if (result != element)
162  changed = true;
163  } else {
164  changed = failure();
165  }
166 }
167 
168 template <typename T, typename Replacer>
169 static T replaceSubElements(T interface, Replacer &replacer) {
170  // Walk the current sub-elements, replacing them as necessary.
172  SmallVector<Type, 16> newTypes;
173  FailureOr<bool> changed = false;
174  interface.walkImmediateSubElements(
175  [&](Attribute element) {
176  updateSubElementImpl(element, replacer, newAttrs, changed);
177  },
178  [&](Type element) {
179  updateSubElementImpl(element, replacer, newTypes, changed);
180  });
181  if (failed(changed))
182  return nullptr;
183 
184  // If any sub-elements changed, use the new elements during the replacement.
185  T result = interface;
186  if (*changed)
187  result = interface.replaceImmediateSubElements(newAttrs, newTypes);
188  return result;
189 }
190 
191 /// Shared implementation of replacing a given attribute or type element.
192 template <typename T, typename ReplaceFns, typename Replacer>
193 static T replaceElementImpl(T element, ReplaceFns &replaceFns,
194  Replacer &replacer) {
195  T result = element;
196  WalkResult walkResult = WalkResult::advance();
197  for (auto &replaceFn : llvm::reverse(replaceFns)) {
198  if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
199  std::tie(result, walkResult) = *newRes;
200  break;
201  }
202  }
203 
204  // If an error occurred, return nullptr to indicate failure.
205  if (walkResult.wasInterrupted() || !result) {
206  return nullptr;
207  }
208 
209  // Handle replacing sub-elements if this element is also a container.
210  if (!walkResult.wasSkipped()) {
211  // Replace the sub elements of this element, bailing if we fail.
212  if (!(result = replaceSubElements(result, replacer))) {
213  return nullptr;
214  }
215  }
216 
217  return result;
218 }
219 
220 template <typename Concrete>
222  return replaceElementImpl(attr, attrReplacementFns,
223  *static_cast<Concrete *>(this));
224 }
225 
226 template <typename Concrete>
228  return replaceElementImpl(type, typeReplacementFns,
229  *static_cast<Concrete *>(this));
230 }
231 
232 //===----------------------------------------------------------------------===//
// AttrTypeReplacer
234 //===----------------------------------------------------------------------===//
235 
237 
238 template <typename T>
239 T AttrTypeReplacer::cachedReplaceImpl(T element) {
240  const void *opaqueElement = element.getAsOpaquePointer();
241  auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement);
242  if (!inserted)
243  return T::getFromOpaquePointer(it->second);
244 
245  T result = replaceBase(element);
246 
247  cache[opaqueElement] = result.getAsOpaquePointer();
248  return result;
249 }
250 
252  return cachedReplaceImpl(attr);
253 }
254 
255 Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); }
256 
257 //===----------------------------------------------------------------------===//
// CyclicAttrTypeReplacer
259 //===----------------------------------------------------------------------===//
260 
262 
264  : cache([&](void *attr) { return breakCycleImpl(attr); }) {}
265 
267  attrCycleBreakerFns.emplace_back(std::move(fn));
268 }
269 
271  typeCycleBreakerFns.emplace_back(std::move(fn));
272 }
273 
274 template <typename T>
275 T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
276  void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
278  cache.lookupOrInit(opaqueTaggedElement);
279  if (auto resultOpt = cacheEntry.get())
280  return T::getFromOpaquePointer(*resultOpt);
281 
282  T result = replaceBase(element);
283 
284  cacheEntry.resolve(result.getAsOpaquePointer());
285  return result;
286 }
287 
289  return cachedReplaceImpl(attr);
290 }
291 
293  return cachedReplaceImpl(type);
294 }
295 
296 std::optional<const void *>
297 CyclicAttrTypeReplacer::breakCycleImpl(void *element) {
298  AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
299  if (auto attr = dyn_cast<Attribute>(attrType)) {
300  for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {
301  if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) {
302  return newRes->getAsOpaquePointer();
303  }
304  }
305  } else {
306  auto type = dyn_cast<Type>(attrType);
307  for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {
308  if (std::optional<Type> newRes = cyclicReplaceFn(type)) {
309  return newRes->getAsOpaquePointer();
310  }
311  }
312  }
313  return std::nullopt;
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // AttrTypeImmediateSubElementWalker
318 //===----------------------------------------------------------------------===//
319 
321  if (element)
322  walkAttrsFn(element);
323 }
324 
326  if (element)
327  walkTypesFn(element);
328 }
static void updateSubElementImpl(T element, Replacer &replacer, SmallVectorImpl< T > &newElements, FailureOr< bool > &changed)
static T replaceElementImpl(T element, ReplaceFns &replaceFns, Replacer &replacer)
Shared implementation of replacing a given attribute or type element.
static T replaceSubElements(T interface, Replacer &replacer)
void walk(Attribute element)
Walk an attribute.
Attribute replace(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
void addCycleBreaker(CycleBreakerFn< Attribute > fn)
Register a cycle-breaking function.
Attribute replace(Attribute attr)
std::function< std::optional< T >(T)> CycleBreakerFn
A cycle-breaking function.
A cache for replacer-like functions that map values between two domains.
CacheEntry lookupOrInit(InT element)
Lookup the cache for a pre-calculated replacement for element.
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
result_range getResults()
Definition: Operation.h:410
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
bool wasSkipped() const
Returns true if the walk was skipped.
Definition: Visitors.h:58
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
Attribute replaceBase(Attribute attr)
Invokes the registered replacement functions from most recently registered to least recently register...
std::function< ReplaceFnResult< T >(T)> ReplaceFn
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
Include the generated interface declarations.
WalkOrder
Traversal order for region, block and operation walk utilities.
Definition: Visitors.h:62