19 return walkImpl(attr, attrWalkFns, order);
22 return walkImpl(type, typeWalkFns, order);
25 template <
typename T,
typename WalkFns>
26 WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
29 auto key = std::make_pair(element.getAsOpaquePointer(), (
int)order);
30 auto it = visitedAttrTypes.find(key);
31 if (it != visitedAttrTypes.end())
37 if (walkSubElements(element, order).wasInterrupted())
42 for (
auto &walkFn : llvm::reverse(walkFns)) {
52 if (walkSubElements(element, order).wasInterrupted())
61 auto walkFn = [&](
auto element) {
63 result = walkImpl(element, order);
65 interface.walkImmediateSubElements(walkFn, walkFn);
74 attrReplacementFns.emplace_back(std::move(fn));
77 typeReplacementFns.push_back(std::move(fn));
81 bool replaceLocs,
bool replaceTypes) {
84 auto replaceIfDifferent = [&](
auto element) {
85 auto replacement =
replace(element);
86 return (replacement && replacement != element) ? replacement :
nullptr;
92 op->
setAttrs(cast<DictionaryAttr>(newAttrs));
96 if (!replaceTypes && !replaceLocs)
102 op->
setLoc(cast<LocationAttr>(newLoc));
108 if (
Type newType = replaceIfDifferent(result.getType()))
109 result.setType(newType);
114 for (
Block &block : region) {
117 if (
Attribute newLoc = replaceIfDifferent(arg.getLoc()))
118 arg.setLoc(cast<LocationAttr>(newLoc));
122 if (
Type newType = replaceIfDifferent(arg.getType()))
123 arg.setType(newType);
139 template <
typename T>
149 newElements.push_back(
nullptr);
154 if (T result = replacer.
replace(element)) {
155 newElements.push_back(result);
156 if (result != element)
163 template <
typename T>
164 T AttrTypeReplacer::replaceSubElements(T interface) {
169 interface.walkImmediateSubElements(
180 T result = interface;
182 result = interface.replaceImmediateSubElements(newAttrs, newTypes);
187 template <
typename T,
typename ReplaceFns>
188 T AttrTypeReplacer::replaceImpl(T element, ReplaceFns &replaceFns) {
189 const void *opaqueElement = element.getAsOpaquePointer();
190 auto [it, inserted] = attrTypeMap.try_emplace(opaqueElement, opaqueElement);
192 return T::getFromOpaquePointer(it->second);
196 for (
auto &replaceFn : llvm::reverse(replaceFns)) {
197 if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
198 std::tie(result, walkResult) = *newRes;
205 attrTypeMap[opaqueElement] =
nullptr;
212 if (!(result = replaceSubElements(result))) {
213 attrTypeMap[opaqueElement] =
nullptr;
218 attrTypeMap[opaqueElement] = result.getAsOpaquePointer();
223 return replaceImpl(attr, attrReplacementFns);
227 return replaceImpl(type, typeReplacementFns);
236 walkAttrsFn(element);
241 walkTypesFn(element);
static void updateSubElementImpl(T element, AttrTypeReplacer &replacer, SmallVectorImpl< T > &newElements, FailureOr< bool > &changed)
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
std::function< ReplaceFnResult< T >(T)> ReplaceFn
Attribute replace(Attribute attr)
Replace the given attribute/type, and recursively replace any sub elements.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
This class provides support for representing a failure result, or a valid value of type T.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A utility result that is used to signal how to proceed with an ongoing walk:
bool wasSkipped() const
Returns true if the walk was skipped.
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
WalkOrder
Traversal order for region, block and operation walk utilities.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.