MLIR  19.0.0git
TransformDialect.h
Go to the documentation of this file.
1 //===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
10 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
11 
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Support/LLVM.h"
15 #include "mlir/Support/TypeID.h"
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/StringMap.h"
18 #include <optional>
19 
20 namespace mlir {
21 namespace transform {
22 
23 namespace detail {
24 /// Concrete base class for CRTP TransformDialectDataBase. Must not be used
25 /// directly.
27 public:
28  virtual ~TransformDialectDataBase() = default;
29 
30  /// Returns the dynamic type ID of the subclass.
31  TypeID getTypeID() const { return typeID; }
32 
33 protected:
34  /// Must be called by the subclass with the appropriate type ID.
36  : typeID(typeID), ctx(ctx) {}
37 
38  /// Return the MLIR context.
39  MLIRContext *getContext() const { return ctx; }
40 
41 private:
42  /// The type ID of the subclass.
43  const TypeID typeID;
44 
45  /// The MLIR context.
46  MLIRContext *ctx;
47 };
48 } // namespace detail
49 
50 /// Base class for additional data owned by the Transform dialect. Extensions
51 /// may communicate with each other using this data. The data object is
52 /// identified by the TypeID of the specific data subclass, querying the data of
53 /// the same subclass returns a reference to the same object. When a Transform
54 /// dialect extension is initialized, it can populate the data in the specific
55 /// subclass. When a Transform op is applied, it can read (but not mutate) the
56 /// data in the specific subclass, including the data provided by other
57 /// extensions.
58 ///
59 /// This follows CRTP: derived classes must list themselves as template
60 /// argument.
61 template <typename DerivedTy>
63 protected:
64  /// Forward the TypeID of the derived class to the base.
67 };
68 
69 #ifndef NDEBUG
70 namespace detail {
71 /// Asserts that the operations provided as template arguments implement the
72 /// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic
73 /// assertion since interface implementations may be registered at runtime.
74 void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context);
75 
76 /// Asserts that the type provided as template argument implements the
77 /// TransformHandleTypeInterface. This must be a dynamic assertion since
78 /// interface implementations may be registered at runtime.
80  MLIRContext *context);
81 } // namespace detail
82 #endif // NDEBUG
83 } // namespace transform
84 } // namespace mlir
85 
86 #include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
87 
88 namespace mlir {
89 namespace transform {
90 
91 /// Base class for extensions of the Transform dialect that supports injecting
92 /// operations into the Transform dialect at load time. Concrete extensions are
93 /// expected to derive this class and register operations in the constructor.
94 /// They can be registered with the DialectRegistry and automatically applied
95 /// to the Transform dialect when it is loaded.
96 ///
97 /// Derived classes are expected to define a `void init()` function in which
98 /// they can call various protected methods of the base class to register
99 /// extension operations and declare their dependencies.
100 ///
101 /// By default, the extension is configured both for construction of the
102 /// Transform IR and for its application to some payload. If only the
103 /// construction is desired, the extension can be switched to "build-only" mode
104 /// that avoids loading the dialects that are only necessary for transforming
105 /// the payload. To perform the switch, the extension must be wrapped into the
106 /// `BuildOnly` class template (see below) when it is registered, as in:
107 ///
108 /// dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>();
109 ///
110 /// instead of:
111 ///
112 /// dialectRegistry.addExtension<MyTransformDialectExt>();
113 ///
114 /// Derived classes must reexport the constructor of this class or otherwise
115 /// forward its boolean argument to support this behavior.
116 template <typename DerivedTy, typename... ExtraDialects>
118  : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
119  using Initializer = std::function<void(TransformDialect *)>;
120  using DialectLoader = std::function<void(MLIRContext *)>;
121 
122 public:
123  /// Extension application hook. Actually loads the dependent dialects and
124  /// registers the additional operations. Not expected to be called directly.
125  void apply(MLIRContext *context, TransformDialect *transformDialect,
126  ExtraDialects *...) const final {
127  for (const DialectLoader &loader : dialectLoaders)
128  loader(context);
129 
130  // Only load generated dialects if the user intends to apply
131  // transformations specified by the extension.
132  if (!buildOnly)
133  for (const DialectLoader &loader : generatedDialectLoaders)
134  loader(context);
135 
136  for (const Initializer &init : initializers)
137  init(transformDialect);
138  }
139 
140 protected:
141  using Base = TransformDialectExtension<DerivedTy, ExtraDialects...>;
142 
143  /// Extension constructor. The argument indicates whether to skip generated
144  /// dialects when applying the extension.
145  explicit TransformDialectExtension(bool buildOnly = false)
146  : buildOnly(buildOnly) {
147  static_cast<DerivedTy *>(this)->init();
148  }
149 
150  /// Registers a custom initialization step to be performed when the extension
151  /// is applied to the dialect while loading. This is discouraged in favor of
152  /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer`
153  /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It
154  /// will be called during the extension initialization and given the current
155  /// MLIR context. This may be used to attach additional interfaces that cannot
156  /// be attached elsewhere.
157  template <typename Func>
158  void addCustomInitializationStep(Func &&func) {
159  std::function<void(MLIRContext *)> initializer = func;
160  dialectLoaders.push_back(
161  [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); });
162  }
163 
164  /// Registers the given function as one of the initializers for the
165  /// dialect-owned data of the kind specified as template argument. The
166  /// function must be convertible to the `void (DataTy &)` form. It will be
167  /// called during the extension initialization and will be given a mutable
168  /// reference to `DataTy`. The callback is expected to append data to the
169  /// given storage, and is not allowed to remove or destructively mutate the
170  /// existing data. The order in which callbacks from different extensions are
171  /// executed is unspecified so the callbacks may not rely on data being
172  /// already present. `DataTy` must be a class deriving `TransformDialectData`.
173  template <typename DataTy, typename Func>
174  void addDialectDataInitializer(Func &&func) {
175  static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>,
176  "only classes deriving TransformDialectData are accepted");
177 
178  std::function<void(DataTy &)> initializer = func;
179  initializers.push_back(
180  [init = std::move(initializer)](TransformDialect *transformDialect) {
181  init(transformDialect->getOrCreateExtraData<DataTy>());
182  });
183  }
184 
185  /// Hook for derived classes to inject constructor behavior.
186  void init() {}
187 
188  /// Injects the operations into the Transform dialect. The operations must
189  /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the
190  /// implementations must be already available when the operation is injected.
191  template <typename... OpTys>
193  initializers.push_back([](TransformDialect *transformDialect) {
194  transformDialect->addOperationsChecked<OpTys...>();
195  });
196  }
197 
198  /// Injects the types into the Transform dialect. The types must implement
199  /// the TransformHandleTypeInterface and the implementation must be already
200  /// available when the type is injected. Furthermore, the types must provide
201  /// a `getMnemonic` static method returning an object convertible to
202  /// `StringRef` that is unique across all injected types.
203  template <typename... TypeTys>
204  void registerTypes() {
205  initializers.push_back([](TransformDialect *transformDialect) {
206  transformDialect->addTypesChecked<TypeTys...>();
207  });
208  }
209 
210  /// Declares that this Transform dialect extension depends on the dialect
211  /// provided as template parameter. When the Transform dialect is loaded,
212  /// dependent dialects will be loaded as well. This is intended for dialects
213  /// that contain attributes and types used in creation and canonicalization of
214  /// the injected operations, similarly to how the dialect definition may list
215  /// dependent dialects. This is *not* intended for dialects whose entities
216  /// may be produced when applying the transformations specified by ops
217  /// registered by this extension.
218  template <typename DialectTy>
220  dialectLoaders.push_back(
221  [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
222  }
223 
224  /// Declares that the transformations associated with the operations
225  /// registered by this dialect extension may produce operations from the
226  /// dialect provided as template parameter while processing payload IR that
227  /// does not contain the operations from said dialect. This is similar to
228  /// dependent dialects of a pass. These dialects will be loaded along with the
229  /// transform dialect unless the extension is in the build-only mode.
230  template <typename DialectTy>
232  generatedDialectLoaders.push_back(
233  [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
234  }
235 
236 private:
237  /// Callbacks performing extension initialization, e.g., registering ops,
238  /// types and defining the additional data.
239  SmallVector<Initializer> initializers;
240 
241  /// Callbacks loading the dependent dialects, i.e. the dialects needed for the
242  /// extension ops.
243  SmallVector<DialectLoader> dialectLoaders;
244 
245  /// Callbacks loading the generated dialects, i.e. the dialects produced when
246  /// applying the transformations.
247  SmallVector<DialectLoader> generatedDialectLoaders;
248 
249  /// Indicates that the extension is in build-only mode.
250  bool buildOnly;
251 };
252 
253 template <typename OpTy>
254 void TransformDialect::addOperationIfNotRegistered() {
255  std::optional<RegisteredOperationName> opName =
256  RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext());
257  if (!opName) {
258  addOperations<OpTy>();
259 #ifndef NDEBUG
260  StringRef name = OpTy::getOperationName();
262 #endif // NDEBUG
263  return;
264  }
265 
266  if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
267  return;
268 
269  reportDuplicateOpRegistration(OpTy::getOperationName());
270 }
271 
272 template <typename Type>
273 void TransformDialect::addTypeIfNotRegistered() {
274  // Use the address of the parse method as a proxy for identifying whether we
275  // are registering the same type class for the same mnemonic.
276  StringRef mnemonic = Type::getMnemonic();
277  auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse);
278  if (!inserted) {
279  const ExtensionTypeParsingHook &parsingHook = it->getValue();
280  if (parsingHook != &Type::parse)
281  reportDuplicateTypeRegistration(mnemonic);
282  else
283  return;
284  }
285  typePrintingHooks.try_emplace(
286  TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) {
287  printer << Type::getMnemonic();
288  cast<Type>(type).print(printer);
289  });
290  addTypes<Type>();
291 
292 #ifndef NDEBUG
294  getContext());
295 #endif // NDEBUG
296 }
297 
298 template <typename DataTy>
299 DataTy &TransformDialect::getOrCreateExtraData() {
300  TypeID typeID = TypeID::get<DataTy>();
301  auto it = extraData.find(typeID);
302  if (it != extraData.end())
303  return static_cast<DataTy &>(*it->getSecond());
304 
305  auto emplaced =
306  extraData.try_emplace(typeID, std::make_unique<DataTy>(getContext()));
307  return static_cast<DataTy &>(*emplaced.first->getSecond());
308 }
309 
/// Wraps the Transform dialect extension `DerivedTy` so that it is always
/// constructed in build-only mode. The wrapper's default constructor passes
/// `true` as the extension's build-only constructor argument. Register it as
/// `dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>();`.
template <typename DerivedTy>
class BuildOnly : public DerivedTy {
public:
  /// Constructs the wrapped extension with its build-only flag set.
  BuildOnly() : DerivedTy(true /* buildOnly */) {}
};
317 
318 } // namespace transform
319 } // namespace mlir
320 
321 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
static MLIRContext * getContext(OpFoldResult val)
This class represents a dialect extension anchored on the given set of dialects.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void loadDialect()
Load a dialect in the context.
Definition: MLIRContext.h:107
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
A wrapper for transform dialect extensions that forces them to be constructed in the build-only mode.
Base class for additional data owned by the Transform dialect.
TransformDialectData(MLIRContext *ctx)
Forward the TypeID of the derived class to the base.
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
void init()
Hook for derived classes to inject constructor behavior.
TransformDialectExtension(bool buildOnly=false)
Extension constructor.
void registerTransformOps()
Injects the operations into the Transform dialect.
void declareDependentDialect()
Declares that this Transform dialect extension depends on the dialect provided as template parameter.
void declareGeneratedDialect()
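Declares that the transformations associated with the operations registered by this dialect extension may produce operations from the dialect provided as template parameter while processing payload IR that does not contain the operations from said dialect.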
void addDialectDataInitializer(Func &&func)
Registers the given function as one of the initializers for the dialect-owned data of the kind specified as template argument.
void apply(MLIRContext *context, TransformDialect *transformDialect, ExtraDialects *...) const final
Extension application hook.
void registerTypes()
Injects the types into the Transform dialect.
void addCustomInitializationStep(Func &&func)
Registers a custom initialization step to be performed when the extension is applied to the dialect while loading.
Concrete base class for CRTP TransformDialectDataBase.
TransformDialectDataBase(TypeID typeID, MLIRContext *ctx)
Must be called by the subclass with the appropriate type ID.
MLIRContext * getContext() const
Return the MLIR context.
TypeID getTypeID() const
Returns the dynamic type ID of the subclass.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
void checkImplementsTransformHandleTypeInterface(TypeID typeID, MLIRContext *context)
Asserts that the type provided as template argument implements the TransformHandleTypeInterface.
void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context)
Asserts that the operations provided as template arguments implement the TransformOpInterface and MemoryEffectsOpInterface.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...