MLIR 22.0.0git
TransformDialect.h
Go to the documentation of this file.
1//===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
10#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
11
12#include "mlir/IR/Dialect.h"
14#include "mlir/Support/LLVM.h"
15#include "mlir/Support/TypeID.h"
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/StringMap.h"
18#include <optional>
19
20namespace mlir {
21namespace transform {
22
23namespace detail {
24/// Concrete base class for CRTP TransformDialectDataBase. Must not be used
25/// directly.
27public:
28 virtual ~TransformDialectDataBase() = default;
29
30 /// Returns the dynamic type ID of the subclass.
31 TypeID getTypeID() const { return typeID; }
32
33protected:
34 /// Must be called by the subclass with the appropriate type ID.
36 : typeID(typeID), ctx(ctx) {}
37
38 /// Return the MLIR context.
39 MLIRContext *getContext() const { return ctx; }
40
41private:
42 /// The type ID of the subclass.
43 const TypeID typeID;
44
45 /// The MLIR context.
46 MLIRContext *ctx;
47};
48} // namespace detail
49
50/// Base class for additional data owned by the Transform dialect. Extensions
51/// may communicate with each other using this data. The data object is
52/// identified by the TypeID of the specific data subclass, querying the data of
53/// the same subclass returns a reference to the same object. When a Transform
54/// dialect extension is initialized, it can populate the data in the specific
55/// subclass. When a Transform op is applied, it can read (but not mutate) the
56/// data in the specific subclass, including the data provided by other
57/// extensions.
58///
59/// This follows CRTP: derived classes must list themselves as template
60/// argument.
61template <typename DerivedTy>
63protected:
64 /// Forward the TypeID of the derived class to the base.
67};
68
69#ifndef NDEBUG
70namespace detail {
71/// Asserts that the operations provided as template arguments implement the
72/// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic
73/// assertion since interface implementations may be registered at runtime.
74void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context);
75
76/// Asserts that the type provided as template argument implements the
77/// TransformHandleTypeInterface. This must be a dynamic assertion since
78/// interface implementations may be registered at runtime.
80 MLIRContext *context);
81} // namespace detail
82#endif // NDEBUG
83} // namespace transform
84} // namespace mlir
85
86#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
87
88namespace mlir {
89namespace transform {
90
91/// Base class for extensions of the Transform dialect that supports injecting
92/// operations into the Transform dialect at load time. Concrete extensions are
93/// expected to derive this class and register operations in the constructor.
94/// They can be registered with the DialectRegistry and automatically applied
95/// to the Transform dialect when it is loaded.
96///
97/// Derived classes are expected to define a `void init()` function in which
98/// they can call various protected methods of the base class to register
99/// extension operations and declare their dependencies.
100///
101/// By default, the extension is configured both for construction of the
102/// Transform IR and for its application to some payload. If only the
103/// construction is desired, the extension can be switched to "build-only" mode
104/// that avoids loading the dialects that are only necessary for transforming
105/// the payload. To perform the switch, the extension must be wrapped into the
106/// `BuildOnly` class template (see below) when it is registered, as in:
107///
108/// dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>();
109///
110/// instead of:
111///
112/// dialectRegistry.addExtension<MyTransformDialectExt>();
113///
114/// Derived classes must reexport the constructor of this class or otherwise
115/// forward its boolean argument to support this behavior.
116template <typename DerivedTy, typename... ExtraDialects>
118 : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
119 using Initializer = std::function<void(TransformDialect *)>;
120 using DialectLoader = std::function<void(MLIRContext *)>;
121
122public:
123 /// Extension application hook. Actually loads the dependent dialects and
124 /// registers the additional operations. Not expected to be called directly.
125 void apply(MLIRContext *context, TransformDialect *transformDialect,
126 ExtraDialects *...) const final {
127 for (const DialectLoader &loader : dialectLoaders)
128 loader(context);
129
130 // Only load generated dialects if the user intends to apply
131 // transformations specified by the extension.
132 if (!buildOnly)
133 for (const DialectLoader &loader : generatedDialectLoaders)
134 loader(context);
135
136 for (const Initializer &init : initializers)
137 init(transformDialect);
138 }
139
140protected:
141 using Base = TransformDialectExtension<DerivedTy, ExtraDialects...>;
142
143 /// Extension constructor. The argument indicates whether to skip generated
144 /// dialects when applying the extension.
145 explicit TransformDialectExtension(bool buildOnly = false)
146 : buildOnly(buildOnly) {
147 static_cast<DerivedTy *>(this)->init();
148 }
149
150 /// Registers a custom initialization step to be performed when the extension
151 /// is applied to the dialect while loading. This is discouraged in favor of
152 /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer`
153 /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It
154 /// will be called during the extension initialization and given the current
155 /// MLIR context. This may be used to attach additional interfaces that cannot
156 /// be attached elsewhere.
157 template <typename Func>
159 std::function<void(MLIRContext *)> initializer = func;
160 dialectLoaders.push_back(
161 [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); });
162 }
163
164 /// Registers the given function as one of the initializers for the
165 /// dialect-owned data of the kind specified as template argument. The
166 /// function must be convertible to the `void (DataTy &)` form. It will be
167 /// called during the extension initialization and will be given a mutable
168 /// reference to `DataTy`. The callback is expected to append data to the
169 /// given storage, and is not allowed to remove or destructively mutate the
170 /// existing data. The order in which callbacks from different extensions are
171 /// executed is unspecified so the callbacks may not rely on data being
172 /// already present. `DataTy` must be a class deriving `TransformDialectData`.
173 template <typename DataTy, typename Func>
175 static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>,
176 "only classes deriving TransformDialectData are accepted");
177
178 std::function<void(DataTy &)> initializer = func;
179 initializers.push_back(
180 [init = std::move(initializer)](TransformDialect *transformDialect) {
181 init(transformDialect->getOrCreateExtraData<DataTy>());
182 });
183 }
184
185 /// Hook for derived classes to inject constructor behavior.
186 void init() {}
187
188 /// Injects the operations into the Transform dialect. The operations must
189 /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the
190 /// implementations must be already available when the operation is injected.
191 template <typename... OpTys>
193 initializers.push_back([](TransformDialect *transformDialect) {
194 transformDialect->addOperationsChecked<OpTys...>();
195 });
196 }
197
198 /// Injects the types into the Transform dialect. The types must implement
199 /// the TransformHandleTypeInterface and the implementation must be already
200 /// available when the type is injected. Furthermore, the types must provide
201 /// a `getMnemonic` static method returning an object convertible to
202 /// `StringRef` that is unique across all injected types.
203 template <typename... TypeTys>
205 initializers.push_back([](TransformDialect *transformDialect) {
206 transformDialect->addTypesChecked<TypeTys...>();
207 });
208 }
209
210 /// Declares that this Transform dialect extension depends on the dialect
211 /// provided as template parameter. When the Transform dialect is loaded,
212 /// dependent dialects will be loaded as well. This is intended for dialects
213 /// that contain attributes and types used in creation and canonicalization of
214 /// the injected operations, similarly to how the dialect definition may list
215 /// dependent dialects. This is *not* intended for dialects entities from
216 /// which may be produced when applying the transformations specified by ops
217 /// registered by this extension.
218 template <typename DialectTy>
220 dialectLoaders.push_back(
221 [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
222 }
223
224 /// Declares that the transformations associated with the operations
225 /// registered by this dialect extension may produce operations from the
226 /// dialect provided as template parameter while processing payload IR that
227 /// does not contain the operations from said dialect. This is similar to
228 /// dependent dialects of a pass. These dialects will be loaded along with the
229 /// transform dialect unless the extension is in the build-only mode.
230 template <typename DialectTy>
232 generatedDialectLoaders.push_back(
233 [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
234 }
235
236private:
237 /// Callbacks performing extension initialization, e.g., registering ops,
238 /// types and defining the additional data.
239 SmallVector<Initializer> initializers;
240
241 /// Callbacks loading the dependent dialects, i.e. the dialect needed for the
242 /// extension ops.
243 SmallVector<DialectLoader> dialectLoaders;
244
245 /// Callbacks loading the generated dialects, i.e. the dialects produced when
246 /// applying the transformations.
247 SmallVector<DialectLoader> generatedDialectLoaders;
248
249 /// Indicates that the extension is in build-only mode.
250 bool buildOnly;
251};
252
253template <typename OpTy>
254void TransformDialect::addOperationIfNotRegistered() {
255 std::optional<RegisteredOperationName> opName =
257 if (!opName) {
258 addOperations<OpTy>();
259#ifndef NDEBUG
260 StringRef name = OpTy::getOperationName();
262#endif // NDEBUG
263 return;
264 }
265
266 if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
267 return;
268
269 reportDuplicateOpRegistration(OpTy::getOperationName());
270}
271
272template <typename Type>
273void TransformDialect::addTypeIfNotRegistered() {
274 // Use the address of the parse method as a proxy for identifying whether we
275 // are registering the same type class for the same mnemonic.
276 StringRef mnemonic = Type::getMnemonic();
277 auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse);
278 if (!inserted) {
279 const ExtensionTypeParsingHook &parsingHook = it->getValue();
280 if (parsingHook != &Type::parse)
281 reportDuplicateTypeRegistration(mnemonic);
282 else
283 return;
284 }
285 typePrintingHooks.try_emplace(
286 TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) {
287 printer << Type::getMnemonic();
288 cast<Type>(type).print(printer);
289 });
290 addTypes<Type>();
291
292#ifndef NDEBUG
294 getContext());
295#endif // NDEBUG
296}
297
298template <typename DataTy>
299DataTy &TransformDialect::getOrCreateExtraData() {
300 TypeID typeID = TypeID::get<DataTy>();
301 auto [it, inserted] = extraData.try_emplace(typeID);
302 if (inserted)
303 it->getSecond() = std::make_unique<DataTy>(getContext());
304 return static_cast<DataTy &>(*it->getSecond());
305}
306
/// A wrapper for transform dialect extensions that forces them to be
/// constructed in the build-only mode. `DerivedTy` must be constructible from
/// a single `bool` argument; the wrapper always passes `true`.
template <typename DerivedTy>
class BuildOnly : public DerivedTy {
public:
  BuildOnly() : DerivedTy(/*buildOnly=*/true) {}
};
314
315} // namespace transform
316} // namespace mlir
317
318#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
true
Given two iterators into the same block, return "true" if `a` is before `b`.
b getContext())
… if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method, …
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void loadDialect()
Load a dialect in the context.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
TransformDialectData(MLIRContext *ctx)
Forward the TypeID of the derived class to the base.
TransformDialectExtension(bool buildOnly=false)
Extension constructor.
TransformDialectExtension< DerivedTy, ExtraDialects... > Base
void registerTransformOps()
Injects the operations into the Transform dialect.
void declareDependentDialect()
Declares that this Transform dialect extension depends on the dialect provided as template parameter.
void declareGeneratedDialect()
Declares that the transformations associated with the operations registered by this dialect extension...
void addDialectDataInitializer(Func &&func)
Registers the given function as one of the initializers for the dialect-owned data of the kind specif...
void apply(MLIRContext *context, TransformDialect *transformDialect, ExtraDialects *...) const final
Extension application hook.
void registerTypes()
Injects the types into the Transform dialect.
void addCustomInitializationStep(Func &&func)
Registers a custom initialization step to be performed when the extension is applied to the dialect w...
Concrete base class for CRTP TransformDialectDataBase.
MLIRContext * getContext() const
Return the MLIR context.
TransformDialectDataBase(TypeID typeID, MLIRContext *ctx)
Must be called by the subclass with the appropriate type ID.
TypeID getTypeID() const
Returns the dynamic type ID of the subclass.
AttrTypeReplacer.
void checkImplementsTransformHandleTypeInterface(TypeID typeID, MLIRContext *context)
Asserts that the type provided as template argument implements the TransformHandleTypeInterface.
void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context)
Asserts that the operations provided as template arguments implement the TransformOpInterface and Mem...
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...