MLIR  19.0.0git
AttrTypeSubElements.h
Go to the documentation of this file.
1 //===- AttrTypeSubElements.h - Attr and Type SubElements -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains utilities for querying the sub elements of an attribute or
10 // type.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_IR_ATTRTYPESUBELEMENTS_H
15 #define MLIR_IR_ATTRTYPESUBELEMENTS_H
16 
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/IR/Visitors.h"
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include <optional>
22 
23 namespace mlir {
24 class Attribute;
25 class Type;
26 
27 //===----------------------------------------------------------------------===//
28 /// AttrTypeWalker
29 //===----------------------------------------------------------------------===//
30 
31 /// This class provides a utility for walking attributes/types, and their sub
32 /// elements. Multiple walk functions may be registered.
34 public:
35  //===--------------------------------------------------------------------===//
36  // Application
37  //===--------------------------------------------------------------------===//
38 
39  /// Walk the given attribute/type, and recursively walk any sub elements.
40  template <WalkOrder Order, typename T>
41  WalkResult walk(T element) {
42  return walkImpl(element, Order);
43  }
44  template <typename T>
45  WalkResult walk(T element) {
46  return walk<WalkOrder::PostOrder, T>(element);
47  }
48 
49  //===--------------------------------------------------------------------===//
50  // Registration
51  //===--------------------------------------------------------------------===//
52 
53  template <typename T>
54  using WalkFn = std::function<WalkResult(T)>;
55 
56  /// Register a walk function for a given attribute or type. A walk function
57  /// must be convertible to any of the following forms (where `T` is a class
58  /// derived from `Type` or `Attribute`):
59  ///
60  /// * WalkResult(T)
61  /// - Returns a walk result, which can be used to control the walk
62  ///
63  /// * void(T)
64  /// - Returns void, i.e. the walk always continues.
65  ///
66  /// Note: When walking, the most recently added walk functions will be
67  /// invoked first.
69  attrWalkFns.emplace_back(std::move(fn));
70  }
71  void addWalk(WalkFn<Type> &&fn) { typeWalkFns.push_back(std::move(fn)); }
72 
73  /// Register a walk function that doesn't match the default signature,
74  /// either because it uses a derived parameter type, or it uses a simplified
75  /// result type.
76  template <typename FnT,
77  typename T = typename llvm::function_traits<
78  std::decay_t<FnT>>::template arg_t<0>,
79  typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
80  Attribute, Type>,
81  typename ResultT = std::invoke_result_t<FnT, T>>
82  std::enable_if_t<!std::is_same_v<T, BaseT> || std::is_same_v<ResultT, void>>
83  addWalk(FnT &&callback) {
84  addWalk([callback = std::forward<FnT>(callback)](BaseT base) -> WalkResult {
85  if (auto derived = dyn_cast<T>(base)) {
86  if constexpr (std::is_convertible_v<ResultT, WalkResult>)
87  return callback(derived);
88  else
89  callback(derived);
90  }
91  return WalkResult::advance();
92  });
93  }
94 
95 private:
96  WalkResult walkImpl(Attribute attr, WalkOrder order);
97  WalkResult walkImpl(Type type, WalkOrder order);
98 
99  /// Internal implementation of the `walk` methods above.
100  template <typename T, typename WalkFns>
101  WalkResult walkImpl(T element, WalkFns &walkFns, WalkOrder order);
102 
103  /// Walk the sub elements of the given interface.
104  template <typename T>
105  WalkResult walkSubElements(T interface, WalkOrder order);
106 
107  /// The set of walk functions that map sub elements.
108  std::vector<WalkFn<Attribute>> attrWalkFns;
109  std::vector<WalkFn<Type>> typeWalkFns;
110 
111  /// The set of visited attributes/types.
113 };
114 
115 //===----------------------------------------------------------------------===//
116 /// AttrTypeReplacer
117 //===----------------------------------------------------------------------===//
118 
119 /// This class provides a utility for replacing attributes/types, and their sub
120 /// elements. Multiple replacement functions may be registered.
122 public:
123  //===--------------------------------------------------------------------===//
124  // Application
125  //===--------------------------------------------------------------------===//
126 
127  /// Replace the elements within the given operation. If `replaceAttrs` is
128  /// true, this updates the attribute dictionary of the operation. If
129  /// `replaceLocs` is true, this also updates its location, and the locations
130  /// of any nested block arguments. If `replaceTypes` is true, this also
131  /// updates the result types of the operation, and the types of any nested
132  /// block arguments.
133  void replaceElementsIn(Operation *op, bool replaceAttrs = true,
134  bool replaceLocs = false, bool replaceTypes = false);
135 
136  /// Replace the elements within the given operation, and all nested
137  /// operations.
138  void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs = true,
139  bool replaceLocs = false,
140  bool replaceTypes = false);
141 
142  /// Replace the given attribute/type, and recursively replace any sub
143  /// elements. Returns either the new attribute/type, or nullptr in the case of
144  /// failure.
146  Type replace(Type type);
147 
148  //===--------------------------------------------------------------------===//
149  // Registration
150  //===--------------------------------------------------------------------===//
151 
152  /// A replacement mapping function, which returns either std::nullopt (to
153  /// signal the element wasn't handled), or a pair of the replacement element
154  /// and a WalkResult.
155  template <typename T>
156  using ReplaceFnResult = std::optional<std::pair<T, WalkResult>>;
157  template <typename T>
158  using ReplaceFn = std::function<ReplaceFnResult<T>(T)>;
159 
160  /// Register a replacement function for mapping a given attribute or type. A
161  /// replacement function must be convertible to any of the following
162  /// forms (where `T` is a class derived from `Type` or `Attribute`, and `BaseT`
163  /// is either `Type` or `Attribute` respectively):
164  ///
165  /// * std::optional<BaseT>(T)
166  /// - This either returns a valid Attribute/Type in the case of success,
167  /// nullptr in the case of failure, or `std::nullopt` to signify that
168  /// additional replacement functions may be applied (i.e. this function
169  /// doesn't handle that instance).
170  ///
171  /// * std::optional<std::pair<BaseT, WalkResult>>(T)
172  /// - Similar to the above, but also allows specifying a WalkResult to
173  /// control the replacement of sub elements of a given attribute or
174  /// type. Returning a `skip` result, for example, will not recursively
175  /// process the resultant attribute or type value.
176  ///
177  /// Note: When replacing, the most recently added replacement functions will
178  /// be invoked first.
181 
182  /// Register a replacement function that doesn't match the default signature,
183  /// either because it uses a derived parameter type, or it uses a simplified
184  /// result type.
185  template <typename FnT,
186  typename T = typename llvm::function_traits<
187  std::decay_t<FnT>>::template arg_t<0>,
188  typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
189  Attribute, Type>,
190  typename ResultT = std::invoke_result_t<FnT, T>>
191  std::enable_if_t<!std::is_same_v<T, BaseT> ||
192  !std::is_convertible_v<ResultT, ReplaceFnResult<BaseT>>>
193  addReplacement(FnT &&callback) {
194  addReplacement([callback = std::forward<FnT>(callback)](
195  BaseT base) -> ReplaceFnResult<BaseT> {
196  if (auto derived = dyn_cast<T>(base)) {
197  if constexpr (std::is_convertible_v<ResultT, std::optional<BaseT>>) {
198  std::optional<BaseT> result = callback(derived);
199  return result ? std::make_pair(*result, WalkResult::advance())
201  } else {
202  return callback(derived);
203  }
204  }
205  return ReplaceFnResult<BaseT>();
206  });
207  }
208 
209 private:
210  /// Internal implementation of the `replace` methods above.
211  template <typename T, typename ReplaceFns>
212  T replaceImpl(T element, ReplaceFns &replaceFns);
213 
214  /// Replace the sub elements of the given interface.
215  template <typename T>
216  T replaceSubElements(T interface);
217 
218  /// The set of replacement functions that map sub elements.
219  std::vector<ReplaceFn<Attribute>> attrReplacementFns;
220  std::vector<ReplaceFn<Type>> typeReplacementFns;
221 
222  /// The set of cached mappings for attributes/types.
224 };
225 
226 //===----------------------------------------------------------------------===//
227 /// AttrTypeSubElementHandler
228 //===----------------------------------------------------------------------===//
229 
230 /// This class is used by AttrTypeSubElementHandler instances to walk sub
231 /// attributes and types.
233 public:
235  function_ref<void(Type)> walkTypesFn)
236  : walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {}
237 
238  /// Walk an attribute.
239  void walk(Attribute element);
240  /// Walk a type.
241  void walk(Type element);
242  /// Walk a range of attributes or types.
243  template <typename RangeT>
244  void walkRange(RangeT &&elements) {
245  for (auto element : elements)
246  walk(element);
247  }
248 
249 private:
250  function_ref<void(Attribute)> walkAttrsFn;
251  function_ref<void(Type)> walkTypesFn;
252 };
253 
254 /// This class is used by AttrTypeSubElementHandler instances to process sub
255 /// element replacements.
256 template <typename T>
258 public:
260 
261  /// Take the first N replacements as an ArrayRef, dropping them from
262  /// this replacement list.
263  ArrayRef<T> take_front(unsigned n) {
264  ArrayRef<T> elements = repls.take_front(n);
265  repls = repls.drop_front(n);
266  return elements;
267  }
268 
269 private:
270  /// The current set of replacements.
271  ArrayRef<T> repls;
272 };
275 
276 /// This class provides support for interacting with the
277 /// SubElementInterfaces for different types of parameters. An
278 /// implementation of this class should be provided for any parameter class
279 /// that may contain an attribute or type. There are two main methods of
280 /// this class that need to be implemented:
281 ///
282 /// - walk
283 ///
284 /// This method should traverse into any sub elements of the parameter
285 /// using the provided walker, or by invoking handlers for sub-types.
286 ///
287 /// - replace
288 ///
289 /// This method should extract any necessary sub elements using the
290 /// provided replacer, or by invoking handlers for sub-types. The new
291 /// post-replacement parameter value should be returned.
292 ///
293 template <typename T, typename Enable = void>
295  /// Default walk implementation that does nothing.
296  static inline void walk(const T &param,
298 
299  /// Default replace implementation just forwards the parameter.
300  template <typename ParamT>
301  static inline decltype(auto) replace(ParamT &&param,
302  AttrSubElementReplacements &attrRepls,
303  TypeSubElementReplacements &typeRepls) {
304  return std::forward<ParamT>(param);
305  }
306 
307  /// Tag indicating that this handler does not support sub-elements.
308  using DefaultHandlerTag = void;
309 };
310 
311 /// Detect if any of the given parameter types has a sub-element handler.
312 namespace detail {
313 template <typename T>
314 using has_default_sub_element_handler_t = decltype(T::DefaultHandlerTag);
315 } // namespace detail
316 template <typename... Ts>
317 inline constexpr bool has_sub_attr_or_type_v =
318  (!llvm::is_detected<detail::has_default_sub_element_handler_t, Ts>::value ||
319  ...);
320 
321 /// Implementation for derived Attributes and Types.
322 template <typename T>
324  T, std::enable_if_t<std::is_base_of_v<Attribute, T> ||
325  std::is_base_of_v<Type, T>>> {
326  static void walk(T param, AttrTypeImmediateSubElementWalker &walker) {
327  walker.walk(param);
328  }
329  static T replace(T param, AttrSubElementReplacements &attrRepls,
330  TypeSubElementReplacements &typeRepls) {
331  if (!param)
332  return T();
333  if constexpr (std::is_base_of_v<Attribute, T>) {
334  return cast<T>(attrRepls.take_front(1)[0]);
335  } else {
336  return cast<T>(typeRepls.take_front(1)[0]);
337  }
338  }
339 };
340 /// Implementation for derived ArrayRef.
341 template <typename T>
343  std::enable_if_t<has_sub_attr_or_type_v<T>>> {
345 
346  static void walk(ArrayRef<T> param,
348  for (const T &subElement : param)
349  EltHandler::walk(subElement, walker);
350  }
351  static auto replace(ArrayRef<T> param, AttrSubElementReplacements &attrRepls,
352  TypeSubElementReplacements &typeRepls) {
353  // Normal attributes/types can extract using the replacer directly.
354  if constexpr (std::is_base_of_v<Attribute, T> &&
355  sizeof(T) == sizeof(void *)) {
356  ArrayRef<Attribute> attrs = attrRepls.take_front(param.size());
357  return ArrayRef<T>((const T *)attrs.data(), attrs.size());
358  } else if constexpr (std::is_base_of_v<Type, T> &&
359  sizeof(T) == sizeof(void *)) {
360  ArrayRef<Type> types = typeRepls.take_front(param.size());
361  return ArrayRef<T>((const T *)types.data(), types.size());
362  } else {
363  // Otherwise, we need to allocate storage for the new elements.
364  SmallVector<T> newElements;
365  for (const T &element : param)
366  newElements.emplace_back(
367  EltHandler::replace(element, attrRepls, typeRepls));
368  return newElements;
369  }
370  }
371 };
372 /// Implementation for Tuple.
373 template <typename... Ts>
375  std::tuple<Ts...>, std::enable_if_t<has_sub_attr_or_type_v<Ts...>>> {
376  static void walk(const std::tuple<Ts...> &param,
378  std::apply(
379  [&](const Ts &...params) {
380  (AttrTypeSubElementHandler<Ts>::walk(params, walker), ...);
381  },
382  param);
383  }
384  static auto replace(const std::tuple<Ts...> &param,
385  AttrSubElementReplacements &attrRepls,
386  TypeSubElementReplacements &typeRepls) {
387  return std::apply(
388  [&](const Ts &...params)
389  -> std::tuple<decltype(AttrTypeSubElementHandler<Ts>::replace(
390  params, attrRepls, typeRepls))...> {
391  return {AttrTypeSubElementHandler<Ts>::replace(params, attrRepls,
392  typeRepls)...};
393  },
394  param);
395  }
396 };
397 
398 namespace detail {
/// Trait that is true exactly when `T` is a specialization of `std::tuple`.
template <typename Candidate>
struct is_tuple : std::bool_constant<false> {};
template <typename... Elements>
struct is_tuple<std::tuple<Elements...>> : std::bool_constant<true> {};

/// Trait that is true exactly when `T` is a specialization of `std::pair`.
template <typename Candidate>
struct is_pair : std::bool_constant<false> {};
template <typename First, typename Second>
struct is_pair<std::pair<First, Second>> : std::bool_constant<true> {};

/// Detection alias naming the result of a static `T::get(args...)` call.
/// Substitution fails when no such call is valid, which makes it usable with
/// a detection-idiom helper such as `llvm::is_detected`.
template <typename Owner, typename... Args>
using has_get_method = decltype(Owner::get(std::declval<Args>()...));

/// Detection alias naming the result of calling `getAsKey()` on a `T` rvalue.
/// The trailing parameter pack is accepted but not used.
template <typename Owner, typename... Ignored>
using has_get_as_key = decltype(std::declval<Owner>().getAsKey());
413 
414 /// This function provides the underlying implementation for the
415 /// SubElementInterface walk method, using the key type of the derived
416 /// attribute/type to interact with the individual parameters.
417 template <typename T>
419  function_ref<void(Attribute)> walkAttrsFn,
420  function_ref<void(Type)> walkTypesFn) {
421  using ImplT = typename T::ImplType;
422  (void)derived;
423  (void)walkAttrsFn;
424  (void)walkTypesFn;
425  if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) {
426  auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey();
427 
428  // If we don't have any sub-elements, there is nothing to do.
429  if constexpr (!has_sub_attr_or_type_v<decltype(key)>)
430  return;
431  AttrTypeImmediateSubElementWalker walker(walkAttrsFn, walkTypesFn);
432  AttrTypeSubElementHandler<decltype(key)>::walk(key, walker);
433  }
434 }
435 
436 /// This function invokes the proper `get` method for a type `T` with the given
437 /// values.
438 template <typename T, typename... Ts>
439 auto constructSubElementReplacement(MLIRContext *ctx, Ts &&...params) {
440  // Prefer a direct `get` method if one exists.
441  if constexpr (llvm::is_detected<has_get_method, T, Ts...>::value) {
442  (void)ctx;
443  return T::get(std::forward<Ts>(params)...);
444  } else if constexpr (llvm::is_detected<has_get_method, T, MLIRContext *,
445  Ts...>::value) {
446  return T::get(ctx, std::forward<Ts>(params)...);
447  } else {
448  // Otherwise, pass to the base get.
449  return T::Base::get(ctx, std::forward<Ts>(params)...);
450  }
451 }
452 
453 /// This function provides the underlying implementation for the
454 /// SubElementInterface replace method, using the key type of the derived
455 /// attribute/type to interact with the individual parameters.
456 template <typename T>
458  ArrayRef<Type> &replTypes) {
459  using ImplT = typename T::ImplType;
460  if constexpr (llvm::is_detected<has_get_as_key, ImplT>::value) {
461  auto key = static_cast<ImplT *>(derived.getImpl())->getAsKey();
462 
463  // If we don't have any sub-elements, we can just return the original.
464  if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
465  return derived;
466 
467  // Otherwise, we need to replace any necessary sub-elements.
468  } else {
469  // Functor used to build the replacement on success.
470  auto buildReplacement = [&](auto newKey, MLIRContext *ctx) {
471  if constexpr (is_tuple<decltype(key)>::value ||
472  is_pair<decltype(key)>::value) {
473  return std::apply(
474  [&](auto &&...params) {
475  return constructSubElementReplacement<T>(
476  ctx, std::forward<decltype(params)>(params)...);
477  },
478  newKey);
479  } else {
480  return constructSubElementReplacement<T>(ctx, newKey);
481  }
482  };
483 
484  AttrSubElementReplacements attrRepls(replAttrs);
485  TypeSubElementReplacements typeRepls(replTypes);
486  auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace(
487  key, attrRepls, typeRepls);
488  MLIRContext *ctx = derived.getContext();
489  if constexpr (std::is_convertible_v<decltype(newKey), LogicalResult>)
490  return succeeded(newKey) ? buildReplacement(*newKey, ctx) : nullptr;
491  else
492  return buildReplacement(newKey, ctx);
493  }
494  } else {
495  return derived;
496  }
497 }
498 } // namespace detail
499 } // namespace mlir
500 
501 #endif // MLIR_IR_ATTRTYPESUBELEMENTS_H
void walk(Attribute element)
Walk an attribute.
AttrTypeImmediateSubElementWalker(function_ref< void(Attribute)> walkAttrsFn, function_ref< void(Type)> walkTypesFn)
void walkRange(RangeT &&elements)
Walk a range of attributes or types.
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
std::enable_if_t<!std::is_same_v< T, BaseT >||!std::is_convertible_v< ResultT, ReplaceFnResult< BaseT > > > addReplacement(FnT &&callback)
Register a replacement function that doesn't match the default signature, either because it uses a de...
std::function< ReplaceFnResult< T >(T)> ReplaceFn
std::optional< std::pair< T, WalkResult > > ReplaceFnResult
A replacement mapping function, which returns either std::nullopt (to signal the element wasn't handl...
Attribute replace(Attribute attr)
Replace the given attribute/type, and recursively replace any sub elements.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
This class is used by AttrTypeSubElementHandler instances to process sub element replacements.
AttrTypeSubElementReplacements(ArrayRef< T > repls)
ArrayRef< T > take_front(unsigned n)
Take the first N replacements as an ArrayRef, dropping them from this replacement list.
WalkResult walk(T element)
std::function< WalkResult(T)> WalkFn
void addWalk(WalkFn< Attribute > &&fn)
Register a walk function for a given attribute or type.
void addWalk(WalkFn< Type > &&fn)
std::enable_if_t<!std::is_same_v< T, BaseT >||std::is_same_v< ResultT, void > > addWalk(FnT &&callback)
Register a walk function that doesn't match the default signature, either because it uses a de...
WalkResult walk(T element)
Walk the given attribute/type, and recursively walk any sub elements.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:34
static WalkResult advance()
Definition: Visitors.h:52
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
decltype(std::declval< T >().getAsKey()) has_get_as_key
auto constructSubElementReplacement(MLIRContext *ctx, Ts &&...params)
This function invokes the proper get method for a type T with the given values.
decltype(T::DefaultHandlerTag) has_default_sub_element_handler_t
auto replaceImmediateSubElementsImpl(T derived, ArrayRef< Attribute > &replAttrs, ArrayRef< Type > &replTypes)
This function provides the underlying implementation for the SubElementInterface replace method,...
void walkImmediateSubElementsImpl(T derived, function_ref< void(Attribute)> walkAttrsFn, function_ref< void(Type)> walkTypesFn)
This function provides the underlying implementation for the SubElementInterface walk method,...
decltype(T::get(std::declval< Ts >()...)) has_get_method
@ Type
An inlay hint for a type annotation.
Include the generated interface declarations.
constexpr bool has_sub_attr_or_type_v
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
WalkOrder
Traversal order for region, block and operation walk utilities.
Definition: Visitors.h:63
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
static void walk(ArrayRef< T > param, AttrTypeImmediateSubElementWalker &walker)
static auto replace(ArrayRef< T > param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
static T replace(T param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
static void walk(const std::tuple< Ts... > &param, AttrTypeImmediateSubElementWalker &walker)
static auto replace(const std::tuple< Ts... > &param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
This class provides support for interacting with the SubElementInterfaces for different types of para...
static void walk(const T &param, AttrTypeImmediateSubElementWalker &walker)
Default walk implementation that does nothing.
void DefaultHandlerTag
Tag indicating that this handler does not support sub-elements.
static decltype(auto) replace(ParamT &&param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
Default replace implementation just forwards the parameter.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26