19 return walkImpl(attr, attrWalkFns, order);
22 return walkImpl(type, typeWalkFns, order);
25template <
typename T,
typename WalkFns>
26WalkResult AttrTypeWalker::walkImpl(T element, WalkFns &walkFns,
29 auto key = std::make_pair(element.getAsOpaquePointer(), (
int)order);
37 if (walkSubElements(element, order).wasInterrupted())
42 for (
auto &walkFn : llvm::reverse(walkFns)) {
43 WalkResult walkResult = walkFn(element);
52 if (walkSubElements(element, order).wasInterrupted())
61 auto walkFn = [&](
auto element) {
62 if (element && !
result.wasInterrupted())
63 result = walkImpl(element, order);
65 interface.walkImmediateSubElements(walkFn, walkFn);
73template <
typename Concrete>
76 attrReplacementFns.emplace_back(std::move(fn));
79template <
typename Concrete>
82 typeReplacementFns.push_back(std::move(fn));
85template <
typename Concrete>
87 Operation *op,
bool replaceAttrs,
bool replaceLocs,
bool replaceTypes) {
90 auto replaceIfDifferent = [&](
auto element) {
91 auto replacement =
static_cast<Concrete *
>(
this)->replace(element);
98 op->
setAttrs(cast<DictionaryAttr>(newAttrs));
102 if (!replaceTypes && !replaceLocs)
108 op->
setLoc(cast<LocationAttr>(newLoc));
114 if (
Type newType = replaceIfDifferent(
result.getType()))
120 for (
Block &block : region) {
123 if (
Attribute newLoc = replaceIfDifferent(arg.getLoc()))
124 arg.setLoc(cast<LocationAttr>(newLoc));
128 if (
Type newType = replaceIfDifferent(arg.getType()))
129 arg.setType(newType);
136template <
typename Concrete>
138 Operation *op,
bool replaceAttrs,
bool replaceLocs,
bool replaceTypes) {
144template <
typename T,
typename Replacer>
154 newElements.push_back(
nullptr);
159 if (T
result = replacer.replace(element)) {
160 newElements.push_back(
result);
168template <
typename T,
typename Replacer>
173 FailureOr<bool>
changed =
false;
174 interface.walkImmediateSubElements(
187 result = interface.replaceImmediateSubElements(newAttrs, newTypes);
192template <
typename T,
typename ReplaceFns,
typename Replacer>
194 Replacer &replacer) {
197 for (
auto &replaceFn : llvm::reverse(replaceFns)) {
198 if (std::optional<std::pair<T, WalkResult>> newRes = replaceFn(element)) {
199 std::tie(
result, walkResult) = *newRes;
220template <
typename Concrete>
223 *
static_cast<Concrete *
>(
this));
226template <
typename Concrete>
229 *
static_cast<Concrete *
>(
this));
239T AttrTypeReplacer::cachedReplaceImpl(T element) {
240 const void *opaqueElement = element.getAsOpaquePointer();
241 auto [it,
inserted] = cache.try_emplace(opaqueElement, opaqueElement);
243 return T::getFromOpaquePointer(it->second);
247 cache[opaqueElement] =
result.getAsOpaquePointer();
252 return cachedReplaceImpl(attr);
264 : cache([&](
void *attr) {
return breakCycleImpl(attr); }) {}
267 attrCycleBreakerFns.emplace_back(std::move(fn));
271 typeCycleBreakerFns.emplace_back(std::move(fn));
275T CyclicAttrTypeReplacer::cachedReplaceImpl(T element) {
276 void *opaqueTaggedElement = AttrOrType(element).getOpaqueValue();
279 if (
auto resultOpt = cacheEntry.
get())
280 return T::getFromOpaquePointer(*resultOpt);
289 return cachedReplaceImpl(attr);
293 return cachedReplaceImpl(type);
296std::optional<const void *>
297CyclicAttrTypeReplacer::breakCycleImpl(
void *element) {
298 AttrOrType attrType = AttrOrType::getFromOpaqueValue(element);
299 if (
auto attr = dyn_cast<Attribute>(attrType)) {
300 for (
auto &cyclicReplaceFn : llvm::reverse(attrCycleBreakerFns)) {
301 if (std::optional<Attribute> newRes = cyclicReplaceFn(attr)) {
302 return newRes->getAsOpaquePointer();
306 auto type = dyn_cast<Type>(attrType);
307 for (
auto &cyclicReplaceFn : llvm::reverse(typeCycleBreakerFns)) {
308 if (std::optional<Type> newRes = cyclicReplaceFn(type)) {
309 return newRes->getAsOpaquePointer();
322 walkAttrsFn(element);
327 walkTypesFn(element);
static void updateSubElementImpl(T element, Replacer &replacer, SmallVectorImpl< T > &newElements, FailureOr< bool > &changed)
static T replaceElementImpl(T element, ReplaceFns &replaceFns, Replacer &replacer)
Shared implementation of replacing a given attribute or type element.
static T replaceSubElements(T interface, Replacer &replacer)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
Attribute replace(Attribute attr)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
void addCycleBreaker(CycleBreakerFn< Attribute > fn)
Register a cycle-breaking function.
Attribute replace(Attribute attr)
std::function< std::optional< T >(T)> CycleBreakerFn
A cycle-breaking function.
CacheEntry lookupOrInit(InT element)
Lookup the cache for a pre-calculated replacement for element.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
void setLoc(Location loc)
Set the source location the operation was defined or derived from.
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Location getLoc()
The source location the operation was defined or derived from.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getResults()
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A utility result that is used to signal how to proceed with an ongoing walk:
bool wasSkipped() const
Returns true if the walk was skipped.
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
This class provides a base utility for replacing attributes/types, and their sub elements.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
Attribute replaceBase(Attribute attr)
Invokes the registered replacement functions from most recently registered to least recently register...
std::function< ReplaceFnResult< T >(T)> ReplaceFn
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
WalkOrder
Traversal order for region, block and operation walk utilities.
A possibly unresolved cache entry.
void resolve(OutT result)
Resolve an unresolved cache entry by providing the result to be stored in the cache.
const std::optional< OutT > & get() const
Get the resolved result if one exists.