14 #ifndef MLIR_IR_ATTRTYPESUBELEMENTS_H
15 #define MLIR_IR_ATTRTYPESUBELEMENTS_H
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/DenseMap.h"
40 template <WalkOrder Order,
typename T>
42 return walkImpl(element, Order);
46 return walk<WalkOrder::PostOrder, T>(element);
69 attrWalkFns.emplace_back(std::move(fn));
76 template <
typename FnT,
77 typename T =
typename llvm::function_traits<
78 std::decay_t<FnT>>::template arg_t<0>,
79 typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
81 typename ResultT = std::invoke_result_t<FnT, T>>
82 std::enable_if_t<!std::is_same_v<T, BaseT> || std::is_same_v<ResultT, void>>
85 if (
auto derived = dyn_cast<T>(base)) {
86 if constexpr (std::is_convertible_v<ResultT, WalkResult>)
87 return callback(derived);
100 template <
typename T,
typename WalkFns>
104 template <
typename T>
108 std::vector<WalkFn<Attribute>> attrWalkFns;
109 std::vector<WalkFn<Type>> typeWalkFns;
134 bool replaceLocs =
false,
bool replaceTypes =
false);
139 bool replaceLocs =
false,
140 bool replaceTypes =
false);
155 template <
typename T>
157 template <
typename T>
185 template <
typename FnT,
186 typename T =
typename llvm::function_traits<
187 std::decay_t<FnT>>::template arg_t<0>,
188 typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
190 typename ResultT = std::invoke_result_t<FnT, T>>
191 std::enable_if_t<!std::is_same_v<T, BaseT> ||
192 !std::is_convertible_v<ResultT, ReplaceFnResult<BaseT>>>
196 if (
auto derived = dyn_cast<T>(base)) {
197 if constexpr (std::is_convertible_v<ResultT, std::optional<BaseT>>) {
198 std::optional<BaseT> result = callback(derived);
202 return callback(derived);
211 template <
typename T,
typename ReplaceFns>
212 T replaceImpl(T element, ReplaceFns &replaceFns);
215 template <
typename T>
216 T replaceSubElements(T interface);
219 std::vector<ReplaceFn<Attribute>> attrReplacementFns;
220 std::vector<ReplaceFn<Type>> typeReplacementFns;
236 : walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {}
243 template <
typename RangeT>
245 for (
auto element : elements)
256 template <
typename T>
265 repls = repls.drop_front(n);
293 template <
typename T,
typename Enable =
void>
296 static inline void walk(
const T ¶m,
300 template <
typename ParamT>
301 static inline decltype(
auto)
replace(ParamT &¶m,
304 return std::forward<ParamT>(param);
313 template <
typename T>
316 template <
typename... Ts>
318 (!llvm::is_detected<detail::has_default_sub_element_handler_t, Ts>::value ||
322 template <
typename T>
324 T, std::enable_if_t<std::is_base_of_v<Attribute, T> ||
325 std::is_base_of_v<Type, T>>> {
333 if constexpr (std::is_base_of_v<Attribute, T>) {
341 template <
typename T>
343 std::enable_if_t<has_sub_attr_or_type_v<T>>> {
348 for (
const T &subElement : param)
354 if constexpr (std::is_base_of_v<Attribute, T> &&
355 sizeof(T) ==
sizeof(
void *)) {
357 return ArrayRef<T>((
const T *)attrs.data(), attrs.size());
358 }
else if constexpr (std::is_base_of_v<Type, T> &&
359 sizeof(T) ==
sizeof(
void *)) {
361 return ArrayRef<T>((
const T *)types.data(), types.size());
365 for (
const T &element : param)
366 newElements.emplace_back(
367 EltHandler::replace(element, attrRepls, typeRepls));
373 template <
typename... Ts>
375 std::tuple<Ts...>, std::enable_if_t<has_sub_attr_or_type_v<Ts...>>> {
376 static void walk(
const std::tuple<Ts...> ¶m,
379 [&](
const Ts &...params) {
384 static auto replace(
const std::tuple<Ts...> ¶m,
388 [&](
const Ts &...params)
390 params, attrRepls, typeRepls))...> {
399 template <
typename T>
401 template <
typename... Ts>
402 struct is_tuple<std::tuple<Ts...>> :
public std::true_type {};
404 template <
typename T>
406 template <
typename... Ts>
407 struct is_pair<std::pair<Ts...>> :
public std::true_type {};
409 template <
typename T,
typename... Ts>
411 template <
typename T,
typename... Ts>
417 template <
typename T>
421 using ImplT =
typename T::ImplType;
425 if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) {
426 auto key =
static_cast<ImplT *
>(derived.getImpl())->getAsKey();
438 template <
typename T,
typename... Ts>
441 if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) {
443 return T::get(std::forward<Ts>(params)...);
446 return T::get(ctx, std::forward<Ts>(params)...);
456 template <
typename T>
459 using ImplT =
typename T::ImplType;
460 if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) {
461 auto key =
static_cast<ImplT *
>(derived.getImpl())->getAsKey();
470 auto buildReplacement = [&](
auto newKey,
MLIRContext *ctx) {
471 if constexpr (
is_tuple<decltype(key)>::value ||
472 is_pair<decltype(key)>::value) {
474 [&](
auto &&...params) {
475 return constructSubElementReplacement<T>(
476 ctx, std::forward<decltype(params)>(params)...);
480 return constructSubElementReplacement<T>(ctx, newKey);
487 key, attrRepls, typeRepls);
489 if constexpr (std::is_convertible_v<decltype(newKey),
LogicalResult>)
490 return succeeded(newKey) ? buildReplacement(*newKey, ctx) :
nullptr;
492 return buildReplacement(newKey, ctx);
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
std::enable_if_t<!std::is_same_v< T, BaseT >||!std::is_convertible_v< ResultT, ReplaceFnResult< BaseT > > > addReplacement(FnT &&callback)
Register a replacement function that doesn't match the default signature, either because it uses a de...
std::function< ReplaceFnResult< T >(T)> ReplaceFn
std::optional< std::pair< T, WalkResult > > ReplaceFnResult
A replacement mapping function, which returns either std::nullopt (to signal the element wasn't handl...
Attribute replace(Attribute attr)
Replace the given attribute/type, and recursively replace any sub elements.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
This class is used by AttrTypeSubElementHandler instances to process sub element replacements.
AttrTypeSubElementReplacements(ArrayRef< T > repls)
ArrayRef< T > take_front(unsigned n)
Take the first N replacements as an ArrayRef, dropping them from this replacement list.
WalkResult walk(T element)
std::function< WalkResult(T)> WalkFn
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
void addWalk(WalkFn< Type > &&fn)
std::enable_if_t<!std::is_same_v< T, BaseT >||std::is_same_v< ResultT, void > > addWalk(FnT &&callback)
Register a replacement function that doesn't match the default signature, either because it uses a de...
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
decltype(std::declval< T >().getAsKey()) has_get_as_key
auto constructSubElementReplacement(MLIRContext *ctx, Ts &&...params)
This function invokes the proper get method for a type T with the given values.
decltype(T::DefaultHandlerTag) has_default_sub_element_handler_t
auto replaceImmediateSubElementsImpl(T derived, ArrayRef< Attribute > &replAttrs, ArrayRef< Type > &replTypes)
This function provides the underlying implementation for the SubElementInterface replace method,...
void walkImmediateSubElementsImpl(T derived, function_ref< void(Attribute)> walkAttrsFn, function_ref< void(Type)> walkTypesFn)
This function provides the underlying implementation for the SubElementInterface walk method,...
decltype(T::get(std::declval< Ts >()...)) has_get_method
@ Type
An inlay hint that for a type annotation.
This header declares functions that assist transformations in the MemRef dialect.
constexpr bool has_sub_attr_or_type_v
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
WalkOrder
Traversal order for region, block and operation walk utilities.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static void walk(ArrayRef< T > param, AttrTypeImmediateSubElementWalker &walker)
static auto replace(ArrayRef< T > param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
static T replace(T param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
static void walk(T param, AttrTypeImmediateSubElementWalker &walker)
static void walk(const std::tuple< Ts... > ¶m, AttrTypeImmediateSubElementWalker &walker)
static auto replace(const std::tuple< Ts... > ¶m, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
This class provides support for interacting with the SubElementInterfaces for different types of para...
static void walk(const T ¶m, AttrTypeImmediateSubElementWalker &walker)
Default walk implementation that does nothing.
void DefaultHandlerTag
Tag indicating that this handler does not support sub-elements.
static decltype(auto) replace(ParamT &¶m, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
Default replace implementation just forwards the parameter.
This class represents an efficient way to signal success or failure.