MLIR  19.0.0git
AttrTypeSubElements.cpp
Go to the documentation of this file.
1 //===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Operation.h"
10 #include <optional>
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // AttrTypeWalker
16 //===----------------------------------------------------------------------===//
17 
18 WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) {
19  return walkImpl(attr, attrWalkFns, order);
20 }
21 WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) {
22  return walkImpl(type, typeWalkFns, order);
23 }
24 
25 template <typename T, typename WalkFns>
26 WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
27  WalkOrder order) {
28  // Check if we've already walk this element before.
29  auto key = std::make_pair(element.getAsOpaquePointer(), (int)order);
30  auto it = visitedAttrTypes.find(key);
31  if (it != visitedAttrTypes.end())
32  return it->second;
33  visitedAttrTypes.try_emplace(key, WalkResult::advance());
34 
35  // If we are walking in post order, walk the sub elements first.
36  if (order == WalkOrder::PostOrder) {
37  if (walkSubElements(element, order).wasInterrupted())
38  return visitedAttrTypes[key] = WalkResult::interrupt();
39  }
40 
41  // Walk this element, bailing if skipped or interrupted.
42  for (auto &walkFn : llvm::reverse(walkFns)) {
43  WalkResult walkResult = walkFn(element);
44  if (walkResult.wasInterrupted())
45  return visitedAttrTypes[key] = WalkResult::interrupt();
46  if (walkResult.wasSkipped())
47  return WalkResult::advance();
48  }
49 
50  // If we are walking in pre-order, walk the sub elements last.
51  if (order == WalkOrder::PreOrder) {
52  if (walkSubElements(element, order).wasInterrupted())
53  return WalkResult::interrupt();
54  }
55  return WalkResult::advance();
56 }
57 
58 template <typename T>
59 WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
61  auto walkFn = [&](auto element) {
62  if (element && !result.wasInterrupted())
63  result = walkImpl(element, order);
64  };
65  interface.walkImmediateSubElements(walkFn, walkFn);
66  return result.wasInterrupted() ? result : WalkResult::advance();
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // AttrTypeReplacer
71 //===----------------------------------------------------------------------===//
72 
74  attrReplacementFns.emplace_back(std::move(fn));
75 }
77  typeReplacementFns.push_back(std::move(fn));
78 }
79 
80 void AttrTypeReplacer::replaceElementsIn(Operation *op, bool replaceAttrs,
81  bool replaceLocs, bool replaceTypes) {
82  // Functor that replaces the given element if the new value is different,
83  // otherwise returns nullptr.
84  auto replaceIfDifferent = [&](auto element) {
85  auto replacement = replace(element);
86  return (replacement && replacement != element) ? replacement : nullptr;
87  };
88 
89  // Update the attribute dictionary.
90  if (replaceAttrs) {
91  if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary()))
92  op->setAttrs(cast<DictionaryAttr>(newAttrs));
93  }
94 
95  // If we aren't updating locations or types, we're done.
96  if (!replaceTypes && !replaceLocs)
97  return;
98 
99  // Update the location.
100  if (replaceLocs) {
101  if (Attribute newLoc = replaceIfDifferent(op->getLoc()))
102  op->setLoc(cast<LocationAttr>(newLoc));
103  }
104 
105  // Update the result types.
106  if (replaceTypes) {
107  for (OpResult result : op->getResults())
108  if (Type newType = replaceIfDifferent(result.getType()))
109  result.setType(newType);
110  }
111 
112  // Update any nested block arguments.
113  for (Region &region : op->getRegions()) {
114  for (Block &block : region) {
115  for (BlockArgument &arg : block.getArguments()) {
116  if (replaceLocs) {
117  if (Attribute newLoc = replaceIfDifferent(arg.getLoc()))
118  arg.setLoc(cast<LocationAttr>(newLoc));
119  }
120 
121  if (replaceTypes) {
122  if (Type newType = replaceIfDifferent(arg.getType()))
123  arg.setType(newType);
124  }
125  }
126  }
127  }
128 }
129 
131  bool replaceAttrs,
132  bool replaceLocs,
133  bool replaceTypes) {
134  op->walk([&](Operation *nestedOp) {
135  replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
136  });
137 }
138 
139 template <typename T>
140 static void updateSubElementImpl(T element, AttrTypeReplacer &replacer,
141  SmallVectorImpl<T> &newElements,
142  FailureOr<bool> &changed) {
143  // Bail early if we failed at any point.
144  if (failed(changed))
145  return;
146 
147  // Guard against potentially null inputs. We always map null to null.
148  if (!element) {
149  newElements.push_back(nullptr);
150  return;
151  }
152 
153  // Replace the element.
154  if (T result = replacer.replace(element)) {
155  newElements.push_back(result);
156  if (result != element)
157  changed = true;
158  } else {
159  changed = failure();
160  }
161 }
162 
163 template <typename T>
164 T AttrTypeReplacer::replaceSubElements(T interface) {
165  // Walk the current sub-elements, replacing them as necessary.
167  SmallVector<Type, 16> newTypes;
168  FailureOr<bool> changed = false;
169  interface.walkImmediateSubElements(
170  [&](Attribute element) {
171  updateSubElementImpl(element, *this, newAttrs, changed);
172  },
173  [&](Type element) {
174  updateSubElementImpl(element, *this, newTypes, changed);
175  });
176  if (failed(changed))
177  return nullptr;
178 
179  // If any sub-elements changed, use the new elements during the replacement.
180  T result = interface;
181  if (*changed)
182  result = interface.replaceImmediateSubElements(newAttrs, newTypes);
183  return result;
184 }
185 
186 /// Shared implementation of replacing a given attribute or type element.
187 template <typename T, typename ReplaceFns>
188 T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
189  const void *opaqueElement = element.getAsOpaquePointer();
190  auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement);
191  if (!inserted)
192  return T::getFromOpaquePointer(it->second);
193 
194  T result = element;
195  WalkResult walkResult = WalkResult::advance();
196  for (auto &replaceFn : llvm::reverse(replaceFns)) {
197  if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
198  std::tie(result, walkResult) = *newRes;
199  break;
200  }
201  }
202 
203  // If an error occurred, return nullptr to indicate failure.
204  if (walkResult.wasInterrupted() || !result) {
205  attrTypeMap[opaqueElement] = nullptr;
206  return nullptr;
207  }
208 
209  // Handle replacing sub-elements if this element is also a container.
210  if (!walkResult.wasSkipped()) {
211  // Replace the sub elements of this element, bailing if we fail.
212  if (!(result = replaceSubElements(result))) {
213  attrTypeMap[opaqueElement] = nullptr;
214  return nullptr;
215  }
216  }
217 
218  attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
219  return result;
220 }
221 
223  return replaceImpl(attr, attrReplacementFns);
224 }
225 
227  return replaceImpl(type, typeReplacementFns);
228 }
229 
230 //===----------------------------------------------------------------------===//
231 // AttrTypeImmediateSubElementWalker
232 //===----------------------------------------------------------------------===//
233 
235  if (element)
236  walkAttrsFn(element);
237 }
238 
240  if (element)
241  walkTypesFn(element);
242 }
static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, SmallVectorImpl< T > &newElements, FailureOr< bool > &changed)
void walk(Attribute element)
Walk an attribute.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
std::function< ReplaceFnResult< T >(T)> ReplaceFn
Attribute replace(Attribute attr)
Replace the given attribute/type, and recursively replace any sub elements.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:30
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition: Operation.h:226
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:672
result_range getResults()
Definition: Operation.h:410
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
bool wasSkipped() const
Returns true if the walk was skipped.
Definition: Visitors.h:59
static WalkResult advance()
Definition: Visitors.h:52
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:56
static WalkResult interrupt()
Definition: Visitors.h:51
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
WalkOrder
Traversal order for region, block and operation walk utilities.
Definition: Visitors.h:63
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72