19 return walkImpl(attr, attrWalkFns, order);
22 return walkImpl(type, typeWalkFns, order);
25 template <
typename T,
typename WalkFns>
26 WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
29 auto key = std::make_pair(element.getAsOpaquePointer(), (
int)order);
37 if (walkSubElements(element, order).wasInterrupted())
42 for (
auto &walkFn : llvm::reverse(walkFns)) {
52 if (walkSubElements(element, order).wasInterrupted())
61 auto walkFn = [&](
auto element) {
63 result = walkImpl(element, order);
65 interface.walkImmediateSubElements(walkFn, walkFn);
73 template <
typename Concrete>
76 attrReplacementFns.emplace_back(std::move(fn));
79 template <
typename Concrete>
82 typeReplacementFns.push_back(std::move(fn));
85 template <
typename Concrete>
87 Operation *op,
bool replaceAttrs,
bool replaceLocs,
bool replaceTypes) {
90 auto replaceIfDifferent = [&](
auto element) {
91 auto replacement =
static_cast<Concrete *
>(
this)->replace(element);
92 return (replacement && replacement != element) ? replacement :
nullptr;
98 op->
setAttrs(cast<DictionaryAttr>(newAttrs));
102 if (!replaceTypes && !replaceLocs)
108 op->
setLoc(cast<LocationAttr>(newLoc));
114 if (
Type newType = replaceIfDifferent(result.getType()))
115 result.setType(newType);
120 for (
Block &block : region) {
123 if (
Attribute newLoc = replaceIfDifferent(arg.getLoc()))
124 arg.setLoc(cast<LocationAttr>(newLoc));
128 if (
Type newType = replaceIfDifferent(arg.getType()))
129 arg.setType(newType);
136 template <
typename Concrete>
138 Operation *op,
bool replaceAttrs,
bool replaceLocs,
bool replaceTypes) {
140 replaceElementsIn(nestedOp, replaceAttrs, replaceLocs, replaceTypes);
144 template <
typename T,
typename Replacer>
147 FailureOr<bool> &changed) {
154 newElements.push_back(
nullptr);
159 if (T result = replacer.replace(element)) {
160 newElements.push_back(result);
161 if (result != element)
168 template <
typename T,
typename Replacer>
173 FailureOr<bool> changed =
false;
174 interface.walkImmediateSubElements(
185 T result = interface;
187 result = interface.replaceImmediateSubElements(newAttrs, newTypes);
192 template <
typename T,
typename ReplaceFns,
typename Replacer>
194 Replacer &replacer) {
197 for (
auto &replaceFn : llvm::reverse(replaceFns)) {
198 if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
199 std::tie(result, walkResult) = *newRes;
220 template <
typename Concrete>
223 *
static_cast<Concrete *
>(
this));
226 template <
typename Concrete>
229 *
static_cast<Concrete *
>(
this));
238 template <
typename T>
239 T AttrTypeReplacer::cachedReplaceImpl(T element) {
240 const void *opaqueElement = element.getAsOpaquePointer();
241 auto [it, inserted] = cache.try_emplace(opaqueElement, opaqueElement);
243 return T::getFromOpaquePointer(it->second);
247 cache[opaqueElement] = result.getAsOpaquePointer();
252 return cachedReplaceImpl(attr);
264 : cache([&](void *attr) {
return breakCycleImpl(attr); }) {}
267 attrCycleBreakerFns.emplace_back(std::move(fn));
271 typeCycleBreakerFns.emplace_back(std::move(fn));
274 template <
typename T>
275 T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
276 void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
279 if (
auto resultOpt = cacheEntry.get())
280 return T::getFromOpaquePointer(*resultOpt);
284 cacheEntry.resolve(result.getAsOpaquePointer());
289 return cachedReplaceImpl(attr);
293 return cachedReplaceImpl(type);
296 std::optional<const void *>
297 CyclicAttrTypeReplacer::breakCycleImpl(
void *element) {
298 AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
299 if (
auto attr = dyn_cast<Attribute>(attrType)) {
300 for (
auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {
301 if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) {
302 return newRes->getAsOpaquePointer();
306 auto type = dyn_cast<Type>(attrType);
307 for (
auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {
308 if (std::optional<Type> newRes = cyclicReplaceFn(type)) {
309 return newRes->getAsOpaquePointer();
322 walkAttrsFn(element);
327 walkTypesFn(element);
static void updateSubElementImpl(T element, Replacer &replacer, SmallVectorImpl< T > &newElements, FailureOr< bool > &changed)
static T replaceElementImpl(T element, ReplaceFns &replaceFns, Replacer &replacer)
Shared implementation of replacing a given attribute or type element.
static T replaceSubElements(T interface, Replacer &replacer)
Attribute replace(Attribute attr)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
void addCycleBreaker(CycleBreakerFn< Attribute > fn)
Register a cycle-breaking function.
Attribute replace(Attribute attr)
std::function< std::optional< T >(T)> CycleBreakerFn
A cycle-breaking function.
A cache for replacer-like functions that map values between two domains.
CacheEntry lookupOrInit(InT element)
Lookup the cache for a pre-calculated replacement for element.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A utility result that is used to signal how to proceed with an ongoing walk:
bool wasSkipped() const
Returns true if the walk was skipped.
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
Attribute replaceBase(Attribute attr)
Invokes the registered replacement functions from most recently registered to least recently register...
std::function< ReplaceFnResult< T >(T)> ReplaceFn
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
Include the generated interface declarations.
WalkOrder
Traversal order for region, block and operation walk utilities.