MLIR  16.0.0git
SubElementInterfaces.h
Go to the documentation of this file.
1 //===- SubElementInterfaces.h - Attr and Type SubElements -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains interfaces and utilities for querying the sub elements of
10 // an attribute or type.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_IR_SUBELEMENTINTERFACES_H
15 #define MLIR_IR_SUBELEMENTINTERFACES_H
16 
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Types.h"
19 #include "mlir/IR/Visitors.h"
20 
21 namespace mlir {
22 //===----------------------------------------------------------------------===//
23 /// AttrTypeReplacer
24 //===----------------------------------------------------------------------===//
25 
26 /// This class provides a utility for replacing attributes/types, and their sub
27 /// elements. Multiple replacement functions may be registered.
29 public:
30  //===--------------------------------------------------------------------===//
31  // Application
32  //===--------------------------------------------------------------------===//
33 
34  /// Replace the elements within the given operation. If `replaceAttrs` is
35  /// true, this updates the attribute dictionary of the operation. If
36  /// `replaceLocs` is true, this also updates its location, and the locations
37  /// of any nested block arguments. If `replaceTypes` is true, this also
38  /// updates the result types of the operation, and the types of any nested
39  /// block arguments.
40  void replaceElementsIn(Operation *op, bool replaceAttrs = true,
41  bool replaceLocs = false, bool replaceTypes = false);
42 
43  /// Replace the elements within the given operation, and all nested
44  /// operations.
45  void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs = true,
46  bool replaceLocs = false,
47  bool replaceTypes = false);
48 
49  /// Replace the given attribute/type, and recursively replace any sub
50  /// elements. Returns either the new attribute/type, or nullptr in the case of
51  /// failure.
53  Type replace(Type type);
54 
55  //===--------------------------------------------------------------------===//
56  // Registration
57  //===--------------------------------------------------------------------===//
58 
59  /// A replacement mapping function, which returns either `std::nullopt` (to signal the
60  /// element wasn't handled), or a pair of the replacement element and a
61  /// WalkResult.
62  template <typename T>
64  template <typename T>
65  using ReplaceFn = std::function<ReplaceFnResult<T>(T)>;
66 
67  /// Register a replacement function for mapping a given attribute or type. A
68  /// replacement function must be convertible to any of the following
69  /// forms (where `T` is a class derived from `Type` or `Attribute`, and `BaseT`
70  /// is either `Type` or `Attribute` respectively):
71  ///
72  /// * Optional<BaseT>(T)
73  /// - This either returns a valid Attribute/Type in the case of success,
74  /// nullptr in the case of failure, or `std::nullopt` to signify that
75  /// additional replacement functions may be applied (i.e. this function
76  /// doesn't handle that instance).
77  ///
78  /// * Optional<std::pair<BaseT, WalkResult>>(T)
79  /// - Similar to the above, but also allows specifying a WalkResult to
80  /// control the replacement of sub elements of a given attribute or
81  /// type. Returning a `skip` result, for example, will not recursively
82  /// process the resultant attribute or type value.
83  ///
84  /// Note: When replacing, the most recently added replacement functions will
85  /// be invoked first.
87  attrReplacementFns.emplace_back(std::move(fn));
88  }
90  typeReplacementFns.push_back(std::move(fn));
91  }
92 
93  /// Register a replacement function that doesn't match the default signature,
94  /// either because it uses a derived parameter type, or it uses a simplified
95  /// result type.
96  template <typename FnT,
97  typename T = typename llvm::function_traits<
98  std::decay_t<FnT>>::template arg_t<0>,
99  typename BaseT = std::conditional_t<std::is_base_of_v<Attribute, T>,
100  Attribute, Type>,
101  typename ResultT = std::invoke_result_t<FnT, T>>
102  std::enable_if_t<!std::is_same_v<T, BaseT> ||
103  !std::is_convertible_v<ResultT, ReplaceFnResult<BaseT>>>
104  addReplacement(FnT &&callback) {
105  addReplacement([callback = std::forward<FnT>(callback)](
106  BaseT base) -> ReplaceFnResult<BaseT> {
107  if (auto derived = dyn_cast<T>(base)) {
108  if constexpr (std::is_convertible_v<ResultT, Optional<BaseT>>) {
109  Optional<BaseT> result = callback(derived);
110  return result ? std::make_pair(*result, WalkResult::advance())
112  } else {
113  return callback(derived);
114  }
115  }
116  return ReplaceFnResult<BaseT>();
117  });
118  }
119 
120 private:
121  /// Internal implementation of the `replace` methods above.
122  template <typename InterfaceT, typename ReplaceFns, typename T>
123  T replaceImpl(T element, ReplaceFns &replaceFns, DenseMap<T, T> &map);
124 
125  /// Replace the sub elements of the given interface.
126  template <typename InterfaceT, typename T = typename InterfaceT::ValueType>
127  T replaceSubElements(InterfaceT interface, DenseMap<T, T> &interfaceMap);
128 
129  /// The set of replacement functions that map sub elements.
130  std::vector<ReplaceFn<Attribute>> attrReplacementFns;
131  std::vector<ReplaceFn<Type>> typeReplacementFns;
132 
133  /// The set of cached mappings for attributes/types.
135  DenseMap<Type, Type> typeMap;
136 };
137 
138 //===----------------------------------------------------------------------===//
139 /// AttrTypeSubElementHandler
140 //===----------------------------------------------------------------------===//
141 
142 /// This class is used by AttrTypeSubElementHandler instances to walk sub
143 /// attributes and types.
145 public:
147  function_ref<void(Type)> walkTypesFn)
148  : walkAttrsFn(walkAttrsFn), walkTypesFn(walkTypesFn) {}
149 
150  /// Walk an attribute.
151  void walk(Attribute element) {
152  if (element)
153  walkAttrsFn(element);
154  }
155  /// Walk a type.
156  void walk(Type element) {
157  if (element)
158  walkTypesFn(element);
159  }
160  /// Walk a range of attributes or types.
161  template <typename RangeT>
162  void walkRange(RangeT &&elements) {
163  for (auto element : elements)
164  walk(element);
165  }
166 
167 private:
168  function_ref<void(Attribute)> walkAttrsFn;
169  function_ref<void(Type)> walkTypesFn;
170 };
171 
172 /// This class is used by AttrTypeSubElementHandler instances to process sub
173 /// element replacements.
174 template <typename T>
176 public:
178 
179  /// Take the first N replacements as an ArrayRef, dropping them from
180  /// this replacement list.
181  ArrayRef<T> take_front(unsigned n) {
182  ArrayRef<T> elements = repls.take_front(n);
183  repls = repls.drop_front(n);
184  return elements;
185  }
186 
187 private:
188  /// The current set of replacements.
189  ArrayRef<T> repls;
190 };
193 
194 /// This class provides support for interacting with the
195 /// SubElementInterfaces for different types of parameters. An
196 /// implementation of this class should be provided for any parameter class
197 /// that may contain an attribute or type. There are two main methods of
198 /// this class that need to be implemented:
199 ///
200 /// - walk
201 ///
202 /// This method should traverse into any sub elements of the parameter
203 /// using the provided walker, or by invoking handlers for sub-types.
204 ///
205 /// - replace
206 ///
207 /// This method should extract any necessary sub elements using the
208 /// provided replacer, or by invoking handlers for sub-types. The new
209 /// post-replacement parameter value should be returned.
210 ///
211 template <typename T, typename Enable = void>
213  /// Default walk implementation that does nothing.
214  static inline void walk(const T &param, AttrTypeSubElementWalker &walker) {}
215 
216  /// Default replace implementation just forwards the parameter.
217  template <typename ParamT>
218  static inline decltype(auto) replace(ParamT &&param,
219  AttrSubElementReplacements &attrRepls,
220  TypeSubElementReplacements &typeRepls) {
221  return std::forward<ParamT>(param);
222  }
223 
224  /// Tag indicating that this handler does not support sub-elements.
225  using DefaultHandlerTag = void;
226 };
227 
228 /// Detect if any of the given parameter types has a sub-element handler.
229 namespace detail {
230 template <typename T>
231 using has_default_sub_element_handler_t = decltype(T::DefaultHandlerTag);
232 } // namespace detail
233 template <typename... Ts>
234 inline constexpr bool has_sub_attr_or_type_v =
236  ...);
237 
238 /// Implementation for derived Attributes and Types.
239 template <typename T>
241  T, std::enable_if_t<std::is_base_of_v<Attribute, T> ||
242  std::is_base_of_v<Type, T>>> {
243  static void walk(T param, AttrTypeSubElementWalker &walker) {
244  walker.walk(param);
245  }
246  static T replace(T param, AttrSubElementReplacements &attrRepls,
247  TypeSubElementReplacements &typeRepls) {
248  if (!param)
249  return T();
250  if constexpr (std::is_base_of_v<Attribute, T>) {
251  return cast<T>(attrRepls.take_front(1)[0]);
252  } else {
253  return cast<T>(typeRepls.take_front(1)[0]);
254  }
255  }
256 };
257 template <>
259  template <typename T>
260  static void walk(T param, AttrTypeSubElementWalker &walker) {
261  walker.walk(param.getName());
262  walker.walk(param.getValue());
263  }
264  template <typename T>
265  static T replace(T param, AttrSubElementReplacements &attrRepls,
266  TypeSubElementReplacements &typeRepls) {
267  ArrayRef<Attribute> paramRepls = attrRepls.take_front(2);
268  return T(cast<decltype(param.getName())>(paramRepls[0]), paramRepls[1]);
269  }
270 };
271 /// Implementation for derived ArrayRef.
272 template <typename T>
274  std::enable_if_t<has_sub_attr_or_type_v<T>>> {
276 
277  static void walk(ArrayRef<T> param, AttrTypeSubElementWalker &walker) {
278  for (const T &subElement : param)
279  EltHandler::walk(subElement, walker);
280  }
281  static auto replace(ArrayRef<T> param, AttrSubElementReplacements &attrRepls,
282  TypeSubElementReplacements &typeRepls) {
283  // Normal attributes/types can extract using the replacer directly.
284  if constexpr (std::is_base_of_v<Attribute, T> &&
285  sizeof(T) == sizeof(Attribute)) {
286  ArrayRef<Attribute> attrs = attrRepls.take_front(param.size());
287  return ArrayRef<T>((const T *)attrs.data(), attrs.size());
288  } else if constexpr (std::is_base_of_v<Type, T> &&
289  sizeof(T) == sizeof(Type)) {
290  ArrayRef<Type> types = typeRepls.take_front(param.size());
291  return ArrayRef<T>((const T *)types.data(), types.size());
292  } else {
293  // Otherwise, we need to allocate storage for the new elements.
294  SmallVector<T> newElements;
295  for (const T &element : param)
296  newElements.emplace_back(
297  EltHandler::replace(element, attrRepls, typeRepls));
298  return newElements;
299  }
300  }
301 };
302 /// Implementation for Tuple.
303 template <typename... Ts>
305  std::tuple<Ts...>, std::enable_if_t<has_sub_attr_or_type_v<Ts...>>> {
306  static void walk(const std::tuple<Ts...> &param,
307  AttrTypeSubElementWalker &walker) {
308  std::apply(
309  [&](const Ts &...params) {
310  (AttrTypeSubElementHandler<Ts>::walk(params, walker), ...);
311  },
312  param);
313  }
314  static auto replace(const std::tuple<Ts...> &param,
315  AttrSubElementReplacements &attrRepls,
316  TypeSubElementReplacements &typeRepls) {
317  return std::apply(
318  [&](const Ts &...params)
319  -> std::tuple<decltype(AttrTypeSubElementHandler<Ts>::replace(
320  params, attrRepls, typeRepls))...> {
321  return {AttrTypeSubElementHandler<Ts>::replace(params, attrRepls,
322  typeRepls)...};
323  },
324  param);
325  }
326 };
327 
328 namespace detail {
329 template <typename T>
330 struct is_tuple : public std::false_type {};
331 template <typename... Ts>
332 struct is_tuple<std::tuple<Ts...>> : public std::true_type {};
333 template <typename T, typename... Ts>
334 using has_get_method = decltype(T::get(std::declval<Ts>()...));
335 
336 /// This function provides the underlying implementation for the
337 /// SubElementInterface walk method, using the key type of the derived
338 /// attribute/type to interact with the individual parameters.
339 template <typename T>
341  function_ref<void(Attribute)> walkAttrsFn,
342  function_ref<void(Type)> walkTypesFn) {
343  auto key = static_cast<typename T::ImplType *>(derived.getImpl())->getAsKey();
344 
345  // If we don't have any sub-elements, there is nothing to do.
346  if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
347  return;
348  } else {
349  AttrTypeSubElementWalker walker(walkAttrsFn, walkTypesFn);
350  AttrTypeSubElementHandler<decltype(key)>::walk(key, walker);
351  }
352 }
353 
354 /// This function invokes the proper `get` method for a type `T` with the given
355 /// values.
356 template <typename T, typename... Ts>
358  // Prefer a direct `get` method if one exists.
360  (void)ctx;
361  return T::get(std::forward<Ts>(params)...);
362  } else if constexpr (llvm::is_detected<has_get_method, T, MLIRContext *,
363  Ts...>::value) {
364  return T::get(ctx, std::forward<Ts>(params)...);
365  } else {
366  // Otherwise, pass to the base get.
367  return T::Base::get(ctx, std::forward<Ts>(params)...);
368  }
369 }
370 
371 /// This function provides the underlying implementation for the
372 /// SubElementInterface replace method, using the key type of the derived
373 /// attribute/type to interact with the individual parameters.
374 template <typename T>
376  ArrayRef<Type> &replTypes) {
377  auto key = static_cast<typename T::ImplType *>(derived.getImpl())->getAsKey();
378 
379  // If we don't have any sub-elements, we can just return the original.
380  if constexpr (!has_sub_attr_or_type_v<decltype(key)>) {
381  return derived;
382 
383  // Otherwise, we need to replace any necessary sub-elements.
384  } else {
385  AttrSubElementReplacements attrRepls(replAttrs);
386  TypeSubElementReplacements typeRepls(replTypes);
387  auto newKey = AttrTypeSubElementHandler<decltype(key)>::replace(
388  key, attrRepls, typeRepls);
389  if constexpr (is_tuple<decltype(key)>::value) {
390  return std::apply(
391  [&](auto &&...params) {
392  return constructSubElementReplacement<T>(
393  derived.getContext(),
394  std::forward<decltype(params)>(params)...);
395  },
396  newKey);
397  } else {
398  return constructSubElementReplacement<T>(derived.getContext(), newKey);
399  }
400  }
401 }
402 } // namespace detail
403 } // namespace mlir
404 
405 /// Include the definitions of the sub element interfaces.
406 #include "mlir/IR/SubElementAttrInterfaces.h.inc"
407 #include "mlir/IR/SubElementTypeInterfaces.h.inc"
408 
409 #endif // MLIR_IR_SUBELEMENTINTERFACES_H
static constexpr const bool value
void addReplacement(ReplaceFn< Type > fn)
void replaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation.
void recursivelyReplaceElementsIn(Operation *op, bool replaceAttrs=true, bool replaceLocs=false, bool replaceTypes=false)
Replace the elements within the given operation, and all nested operations.
std::enable_if_t<!std::is_same_v< T, BaseT >||!std::is_convertible_v< ResultT, ReplaceFnResult< BaseT > > > addReplacement(FnT &&callback)
Register a replacement function that doesn't match the default signature, either because it uses a de...
std::function< ReplaceFnResult< T >(T)> ReplaceFn
Attribute replace(Attribute attr)
Replace the given attribute/type, and recursively replace any sub elements.
void addReplacement(ReplaceFn< Attribute > fn)
Register a replacement function for mapping a given attribute or type.
This class is used by AttrTypeSubElementHandler instances to process sub element replacements.
AttrTypeSubElementReplacements(ArrayRef< T > repls)
ArrayRef< T > take_front(unsigned n)
Take the first N replacements as an ArrayRef, dropping them from this replacement list.
AttrTypeSubElementHandler.
void walkRange(RangeT &&elements)
Walk a range of attributes or types.
void walk(Type element)
Walk a type.
void walk(Attribute element)
Walk an attribute.
AttrTypeSubElementWalker(function_ref< void(Attribute)> walkAttrsFn, function_ref< void(Type)> walkTypesFn)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:150
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
static WalkResult advance()
Definition: Visitors.h:51
decltype(T::DefaultHandlerTag) has_default_sub_element_handler_t
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.cpp:24
T replaceImmediateSubElementsImpl(T derived, ArrayRef< Attribute > &replAttrs, ArrayRef< Type > &replTypes)
This function provides the underlying implementation for the SubElementInterface replace method,...
void walkImmediateSubElementsImpl(T derived, function_ref< void(Attribute)> walkAttrsFn, function_ref< void(Type)> walkTypesFn)
This function provides the underlying implementation for the SubElementInterface walk method,...
T constructSubElementReplacement(MLIRContext *ctx, Ts &&...params)
This function invokes the proper get method for a type T with the given values.
decltype(T::get(std::declval< Ts >()...)) has_get_method
Include the generated interface declarations.
constexpr bool has_sub_attr_or_type_v
static auto replace(ArrayRef< T > param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
static void walk(T param, AttrTypeSubElementWalker &walker)
static T replace(T param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
static T replace(T param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
static void walk(const std::tuple< Ts... > &param, AttrTypeSubElementWalker &walker)
static auto replace(const std::tuple< Ts... > &param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
This class provides support for interacting with the SubElementInterfaces for different types of para...
static void walk(const T &param, AttrTypeSubElementWalker &walker)
Default walk implementation that does nothing.
void DefaultHandlerTag
Tag indicating that this handler does not support sub-elements.
static decltype(auto) replace(ParamT &&param, AttrSubElementReplacements &attrRepls, TypeSubElementReplacements &typeRepls)
Default replace implementation just forwards the parameter.