MLIR 22.0.0git
AttrTypeSubElements.cpp
Go to the documentation of this file.
1//===- AttrTypeSubElements.cpp - Attr and Type SubElement Interfaces ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Operation.h"
10#include <optional>
11
12using namespace mlir;
13
14//===----------------------------------------------------------------------===//
15// AttrTypeWalker
16//===----------------------------------------------------------------------===//
17
18WalkResult AttrTypeWalker::walkImpl(Attribute attr, WalkOrder order) {
19 return walkImpl(attr, attrWalkFns, order);
20}
21WalkResult AttrTypeWalker::walkImpl(Type type, WalkOrder order) {
22 return walkImpl(type, typeWalkFns, order);
23}
24
25template <typename T, typename WalkFns>
26WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
27 WalkOrder order) {
28 // Check if we've already walk this element before.
29 auto key = std::make_pair(element.getAsOpaquePointer(), (int)order);
30 auto [it, inserted] =
31 visitedAttrTypes.try_emplace(key, WalkResult::advance());
32 if (!inserted)
33 return it->second;
34
35 // If we are walking in post order, walk the sub elements first.
36 if (order == WalkOrder::PostOrder) {
37 if (walkSubElements(element, order).wasInterrupted())
38 return visitedAttrTypes[key] = WalkResult::interrupt();
39 }
40
41 // Walk this element, bailing if skipped or interrupted.
42 for (auto &walkFn : llvm::reverse(walkFns)) {
43 WalkResult walkResult = walkFn(element);
44 if (walkResult.wasInterrupted())
45 return visitedAttrTypes[key] = WalkResult::interrupt();
46 if (walkResult.wasSkipped())
47 return WalkResult::advance();
48 }
49
50 // If we are walking in pre-order, walk the sub elements last.
51 if (order == WalkOrder::PreOrder) {
52 if (walkSubElements(element, order).wasInterrupted())
53 return WalkResult::interrupt();
54 }
55 return WalkResult::advance();
56}
57
58template <typename T>
59WalkResult AttrTypeWalker::walkSubElements(T interface, WalkOrder order) {
60 WalkResult result = WalkResult::advance();
61 auto walkFn = [&](auto element) {
62 if (element && !result.wasInterrupted())
63 result = walkImpl(element, order);
64 };
65 interface.walkImmediateSubElements(walkFn, walkFn);
66 return result.wasInterrupted() ? result : WalkResult::advance();
67}
68
69//===----------------------------------------------------------------------===//
70// AttrTypeReplacerBase
71//===----------------------------------------------------------------------===//
72
// Register an attribute replacement function by appending it to
// attrReplacementFns. replaceElementImpl iterates the registered functions in
// reverse, so later registrations take precedence.
73template <typename Concrete>
76 attrReplacementFns.emplace_back(std::move(fn));
77}
78
// Register a type replacement function by appending it to typeReplacementFns.
// As with attributes, the most recently registered function is consulted
// first by replaceElementImpl.
// NOTE(review): the attribute overload uses emplace_back here; push_back of a
// moved function object is equivalent, but consider matching for consistency.
79template <typename Concrete>
81 ReplaceFn<Type> fn) {
82 typeReplacementFns.push_back(std::move(fn));
83}
84
// Replace the elements attached directly to `op`, as selected by the flags:
//   - replaceAttrs: the attribute dictionary.
//   - replaceLocs: the op location and the locations of the block arguments
//     in its regions.
//   - replaceTypes: the result types and the block-argument types.
// Nested operations are not visited here; recursivelyReplaceElementsIn applies
// this to each nested op via op->walk. Each element is written back only when
// its replacement is non-null and differs from the original. When neither
// replaceLocs nor replaceTypes is set, the function returns right after the
// attribute dictionary step.
85template <typename Concrete>
87 Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
88 // Functor that replaces the given element if the new value is different,
89 // otherwise returns nullptr.
90 auto replaceIfDifferent = [&](auto element) {
 // Dispatch to the derived replacer's replace() via static_cast (CRTP).
91 auto replacement = static_cast<Concrete *>(this)->replace(element);
92 return (replacement && replacement != element) ? replacement : nullptr;
93 };
94
95 // Update the attribute dictionary.
 // NOTE(review): cast<> presumably asserts if the replacer maps the
 // dictionary to something other than a DictionaryAttr.
96 if (replaceAttrs) {
97 if (auto newAttrs = replaceIfDifferent(op->getAttrDictionary()))
98 op->setAttrs(cast<DictionaryAttr>(newAttrs));
99 }
100
101 // If we aren't updating locations or types, we're done.
102 if (!replaceTypes && !replaceLocs)
103 return;
104
105 // Update the location.
106 if (replaceLocs) {
107 if (Attribute newLoc = replaceIfDifferent(op->getLoc()))
108 op->setLoc(cast<LocationAttr>(newLoc));
109 }
110
111 // Update the result types.
112 if (replaceTypes) {
113 for (OpResult result : op->getResults())
114 if (Type newType = replaceIfDifferent(result.getType()))
115 result.setType(newType);
116 }
117
118 // Update any nested block arguments.
 // Only the arguments of blocks directly inside `op`'s regions are touched;
 // operations inside those blocks are left to the caller.
119 for (Region &region : op->getRegions()) {
120 for (Block &block : region) {
121 for (BlockArgument &arg : block.getArguments()) {
122 if (replaceLocs) {
123 if (Attribute newLoc = replaceIfDifferent(arg.getLoc()))
124 arg.setLoc(cast<LocationAttr>(newLoc));
125 }
126
127 if (replaceTypes) {
128 if (Type newType = replaceIfDifferent(arg.getType()))
129 arg.setType(newType);
130 }
131 }
132 }
133 }
134}
135
// Apply replaceElementsIn, with the same flags, to `op` and to every operation
// nested within it. Operation::walk visits the nested operations including
// `op` itself.
136template <typename Concrete>
138 Operation *op, bool replaceAttrs, bool replaceLocs, bool replaceTypes) {
139 op->walk([&](Operation *nestedOp) {
140 replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
141 });
142}
143
144template <typename T, typename Replacer>
145static void updateSubElementImpl(T element, Replacer &replacer,
146 SmallVectorImpl<T> &newElements,
147 FailureOr<bool> &changed) {
148 // Bail early if we failed at any point.
149 if (failed(changed))
150 return;
152 // Guard against potentially null inputs. We always map null to null.
153 if (!element) {
154 newElements.push_back(nullptr);
155 return;
156 }
157
158 // Replace the element.
159 if (T result = replacer.replace(element)) {
160 newElements.push_back(result);
161 if (result != element)
162 changed = true;
163 } else {
164 changed = failure();
165 }
166}
167
168template <typename T, typename Replacer>
169static T replaceSubElements(T interface, Replacer &replacer) {
170 // Walk the current sub-elements, replacing them as necessary.
172 SmallVector<Type, 16> newTypes;
173 FailureOr<bool> changed = false;
174 interface.walkImmediateSubElements(
175 [&](Attribute element) {
176 updateSubElementImpl(element, replacer, newAttrs, changed);
177 },
178 [&](Type element) {
179 updateSubElementImpl(element, replacer, newTypes, changed);
180 });
181 if (failed(changed))
182 return nullptr;
183
184 // If any sub-elements changed, use the new elements during the replacement.
185 T result = interface;
186 if (*changed)
187 result = interface.replaceImmediateSubElements(newAttrs, newTypes);
188 return result;
189}
190
191/// Shared implementation of replacing a given attribute or type element.
192template <typename T, typename ReplaceFns, typename Replacer>
193static T replaceElementImpl(T element, ReplaceFns &replaceFns,
194 Replacer &replacer) {
195 T result = element;
196 WalkResult walkResult = WalkResult::advance();
197 for (auto &replaceFn : llvm::reverse(replaceFns)) {
198 if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
199 std::tie(result, walkResult) = *newRes;
200 break;
201 }
202 }
203
204 // If an error occurred, return nullptr to indicate failure.
205 if (walkResult.wasInterrupted() || !result) {
206 return nullptr;
207 }
208
209 // Handle replacing sub-elements if this element is also a container.
210 if (!walkResult.wasSkipped()) {
211 // Replace the sub elements of this element, bailing if we fail.
212 if (!(result = replaceSubElements(result, replacer))) {
213 return nullptr;
214 }
215 }
216
217 return result;
218}
219
// Attribute replacement shared by concrete replacers. Runs the registered
// attribute replacement functions through replaceElementImpl. The concrete
// (derived) replacer is passed along so that sub-element replacement calls
// back into its replace().
220template <typename Concrete>
222 return replaceElementImpl(attr, attrReplacementFns,
223 *static_cast<Concrete *>(this));
224}
225
// Type replacement shared by concrete replacers. Runs the registered type
// replacement functions through replaceElementImpl. The concrete (derived)
// replacer is passed along so that sub-element replacement calls back into its
// replace().
226template <typename Concrete>
228 return replaceElementImpl(type, typeReplacementFns,
229 *static_cast<Concrete *>(this));
230}
231
232//===----------------------------------------------------------------------===//
233// AttrTypeReplacer
234//===----------------------------------------------------------------------===//
235
237
238template <typename T>
239T AttrTypeReplacer::cachedReplaceImpl(T element) {
240 const void *opaqueElement = element.getAsOpaquePointer();
241 auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement);
242 if (!inserted)
243 return T::getFromOpaquePointer(it->second);
244
245 T result = replaceBase(element);
246
247 cache[opaqueElement] = result.getAsOpaquePointer();
248 return result;
249}
250
// Attribute overload: forward to cachedReplaceImpl, which memoizes results in
// `cache`.
252 return cachedReplaceImpl(attr);
253}
254
255Type AttrTypeReplacer::replace(Type type) { return cachedReplaceImpl(type); }
256
257//===----------------------------------------------------------------------===//
258// CyclicAttrTypeReplacer
259//===----------------------------------------------------------------------===//
260
262
// Construct the cycle-aware cache with a callback that forwards to
// breakCycleImpl. NOTE(review): presumably the cache invokes this callback
// when it detects a cycle; confirm against the cache's documentation.
264 : cache([&](void *attr) { return breakCycleImpl(attr); }) {}
265
// Register an attribute cycle-breaking function. breakCycleImpl consults these
// from most recently to least recently registered.
267 attrCycleBreakerFns.emplace_back(std::move(fn));
268}
269
// Register a type cycle-breaking function. breakCycleImpl consults these from
// most recently to least recently registered.
271 typeCycleBreakerFns.emplace_back(std::move(fn));
272}
273
274template <typename T>
275T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
276 void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
278 cache.lookupOrInit(opaqueTaggedElement);
279 if (auto resultOpt = cacheEntry.get())
280 return T::getFromOpaquePointer(*resultOpt);
281
282 T result = replaceBase(element);
283
284 cacheEntry.resolve(result.getAsOpaquePointer());
285 return result;
286}
287
// Attribute overload: forward to the cycle-aware cachedReplaceImpl.
289 return cachedReplaceImpl(attr);
290}
291
// Type overload: forward to the cycle-aware cachedReplaceImpl.
293 return cachedReplaceImpl(type);
294}
295
296std::optional<const void *>
297CyclicAttrTypeReplacer::breakCycleImpl(void *element) {
298 AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
299 if (auto attr = dyn_cast<Attribute>(attrType)) {
300 for (auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {
301 if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) {
302 return newRes->getAsOpaquePointer();
303 }
304 }
305 } else {
306 auto type = dyn_cast<Type>(attrType);
307 for (auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {
308 if (std::optional<Type> newRes = cyclicReplaceFn(type)) {
309 return newRes->getAsOpaquePointer();
310 }
311 }
312 }
313 return std::nullopt;
314}
315
316//===----------------------------------------------------------------------===//
317// AttrTypeImmediateSubElementWalker
318//===----------------------------------------------------------------------===//
319
// Forward a non-null attribute to walkAttrsFn; null attributes are ignored.
321 if (element)
322 walkAttrsFn(element);
323}
324
// Forward a non-null type to walkTypesFn; null types are ignored.
326 if (element)
327 walkTypesFn(element);
328}
static void updateSubElementImpl(T element, Replacer &replacer, SmallVectorImpl< T > &newElements, FailureOr< bool > &changed)
static T replaceElementImpl(T element, ReplaceFns &replaceFns, Replacer &replacer)
Shared implementation of replacing a given attribute or type element.
static T replaceSubElements(T interface, Replacer &replacer)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
void walk(Attribute element)
Walk an attribute.
Attribute replace(Attribute attr)
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
void addCycleBreaker(CycleBreakerFn< Attribute > fn)
Register a cycle-breaking function.
Attribute replace(Attribute attr)
std::function< std::optional< T >(T)> CycleBreakerFn
A cycle-breaking function.
CacheEntry lookupOrInit(InT element)
Lookup the cache for a pre-calculated replacement for element.
This is a value defined by a result of an operation.
Definition Value.h:457
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
Definition Operation.h:226
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
result_range getResults()
Definition Operation.h:415
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
bool wasSkipped() const
Returns true if the walk was skipped.
Definition WalkResult.h:54
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
This class provides a base utility for replacing attributes/types, and their sub elements.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
Attribute replaceBase(Attribute attr)
Invokes the registered replacement functions from most recently registered to least recently register...
std::function< ReplaceFnResult< T >(T)> ReplaceFn
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
WalkOrder
Traversal order for region, block and operation walk utilities.
Definition Visitors.h:28
A possibly unresolved cache entry.
void resolve(OutT result)
Resolve an unresolved cache entry by providing the result to be stored in the cache.
const std::optional< OutT > & get() const
Get the resolved result if one exists.