MLIR 23.0.0git
TransformDialect.h
Go to the documentation of this file.
//===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
10#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
11
12#include "mlir/IR/Dialect.h"
14#include "mlir/Support/LLVM.h"
15#include "mlir/Support/TypeID.h"
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/StringMap.h"
18#include <optional>
19
20namespace mlir {
21namespace transform {
22
23namespace detail {
24/// Concrete base class for CRTP TransformDialectDataBase. Must not be used
25/// directly.
27public:
28 virtual ~TransformDialectDataBase() = default;
29
30 /// Returns the dynamic type ID of the subclass.
31 TypeID getTypeID() const { return typeID; }
32
33protected:
34 /// Must be called by the subclass with the appropriate type ID.
36 : typeID(typeID), ctx(ctx) {}
37
38 /// Return the MLIR context.
39 MLIRContext *getContext() const { return ctx; }
40
41private:
42 /// The type ID of the subclass.
43 const TypeID typeID;
44
45 /// The MLIR context.
46 MLIRContext *ctx;
47};
48} // namespace detail
49
50/// Base class for additional data owned by the Transform dialect. Extensions
51/// may communicate with each other using this data. The data object is
52/// identified by the TypeID of the specific data subclass, querying the data of
53/// the same subclass returns a reference to the same object. When a Transform
54/// dialect extension is initialized, it can populate the data in the specific
55/// subclass. When a Transform op is applied, it can read (but not mutate) the
56/// data in the specific subclass, including the data provided by other
57/// extensions.
58///
59/// This follows CRTP: derived classes must list themselves as template
60/// argument.
61template <typename DerivedTy>
63protected:
64 /// Forward the TypeID of the derived class to the base.
67};
68
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
namespace detail {
/// Asserts that the registered operation with the given `name` implements the
/// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic
/// assertion since interface implementations may be registered at runtime.
void checkImplementsTransformOpInterface(StringRef name, MLIRContext *context);

/// Asserts that the type identified by `typeID` implements the
/// TransformHandleTypeInterface. This must be a dynamic assertion since
/// interface implementations may be registered at runtime.
void checkImplementsTransformHandleTypeInterface(TypeID typeID,
                                                 MLIRContext *context);
} // namespace detail
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
83} // namespace transform
84} // namespace mlir
85
86#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
87
88namespace mlir {
89namespace transform {
90
91/// Base class for extensions of the Transform dialect that supports injecting
92/// operations into the Transform dialect at load time. Concrete extensions are
93/// expected to derive this class and register operations in the constructor.
94/// They can be registered with the DialectRegistry and automatically applied
95/// to the Transform dialect when it is loaded.
96///
97/// Derived classes are expected to define a `void init()` function in which
98/// they can call various protected methods of the base class to register
99/// extension operations and declare their dependencies.
100///
101/// By default, the extension is configured both for construction of the
102/// Transform IR and for its application to some payload. If only the
103/// construction is desired, the extension can be switched to "build-only" mode
104/// that avoids loading the dialects that are only necessary for transforming
105/// the payload. To perform the switch, the extension must be wrapped into the
106/// `BuildOnly` class template (see below) when it is registered, as in:
107///
108/// dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>();
109///
110/// instead of:
111///
112/// dialectRegistry.addExtension<MyTransformDialectExt>();
113///
114/// Derived classes must reexport the constructor of this class or otherwise
115/// forward its boolean argument to support this behavior.
116template <typename DerivedTy, typename... ExtraDialects>
118 : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
119 using Initializer = std::function<void(TransformDialect *)>;
120 using DialectLoader = std::function<void(MLIRContext *)>;
121
122public:
123 /// Extension application hook. Actually loads the dependent dialects and
124 /// registers the additional operations. Not expected to be called directly.
125 void apply(MLIRContext *context, TransformDialect *transformDialect,
126 ExtraDialects *...) const final {
127 for (const DialectLoader &loader : dialectLoaders)
128 loader(context);
129
130 // Only load generated dialects if the user intends to apply
131 // transformations specified by the extension.
132 if (!buildOnly)
133 for (const DialectLoader &loader : generatedDialectLoaders)
134 loader(context);
135
136 for (const Initializer &init : initializers)
137 init(transformDialect);
138 }
139
140protected:
141 using Base = TransformDialectExtension<DerivedTy, ExtraDialects...>;
142
143 /// Extension constructor. The argument indicates whether to skip generated
144 /// dialects when applying the extension.
145 explicit TransformDialectExtension(bool buildOnly = false)
146 : buildOnly(buildOnly) {
147 static_cast<DerivedTy *>(this)->init();
148 }
149
150 /// Registers a custom initialization step to be performed when the extension
151 /// is applied to the dialect while loading. This is discouraged in favor of
152 /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer`
153 /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It
154 /// will be called during the extension initialization and given the current
155 /// MLIR context. This may be used to attach additional interfaces that cannot
156 /// be attached elsewhere.
157 template <typename Func>
159 std::function<void(MLIRContext *)> initializer = func;
160 dialectLoaders.push_back(
161 [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); });
162 }
163
164 /// Registers the given function as one of the initializers for the
165 /// dialect-owned data of the kind specified as template argument. The
166 /// function must be convertible to the `void (DataTy &)` form. It will be
167 /// called during the extension initialization and will be given a mutable
168 /// reference to `DataTy`. The callback is expected to append data to the
169 /// given storage, and is not allowed to remove or destructively mutate the
170 /// existing data. The order in which callbacks from different extensions are
171 /// executed is unspecified so the callbacks may not rely on data being
172 /// already present. `DataTy` must be a class deriving `TransformDialectData`.
173 template <typename DataTy, typename Func>
175 static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>,
176 "only classes deriving TransformDialectData are accepted");
177
178 std::function<void(DataTy &)> initializer = func;
179 initializers.push_back(
180 [init = std::move(initializer)](TransformDialect *transformDialect) {
181 init(transformDialect->getOrCreateExtraData<DataTy>());
182 });
183 }
184
185 /// Hook for derived classes to inject constructor behavior.
186 void init() {}
187
188 /// Injects the operations into the Transform dialect. The operations must
189 /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the
190 /// implementations must be already available when the operation is injected.
191 template <typename... OpTys>
193 initializers.push_back([](TransformDialect *transformDialect) {
194 transformDialect->addOperationsChecked<OpTys...>();
195 });
196 }
197
198 /// Injects the attributes into the Transform dialect. The attributes must
199 /// provide a `getMnemonic` static method returning an object convertible to
200 /// `StringRef` that is unique across all injected attributes.
201 template <typename... AttrTys>
203 initializers.push_back([](TransformDialect *transformDialect) {
204 transformDialect->addAttributesChecked<AttrTys...>();
205 });
206 }
207
208 /// Injects the types into the Transform dialect. The types must implement
209 /// the TransformHandleTypeInterface and the implementation must be already
210 /// available when the type is injected. Furthermore, the types must provide
211 /// a `getMnemonic` static method returning an object convertible to
212 /// `StringRef` that is unique across all injected types.
213 template <typename... TypeTys>
215 initializers.push_back([](TransformDialect *transformDialect) {
216 transformDialect->addTypesChecked<TypeTys...>();
217 });
218 }
219
220 /// Declares that this Transform dialect extension depends on the dialect
221 /// provided as template parameter. When the Transform dialect is loaded,
222 /// dependent dialects will be loaded as well. This is intended for dialects
223 /// that contain attributes and types used in creation and canonicalization of
224 /// the injected operations, similarly to how the dialect definition may list
225 /// dependent dialects. This is *not* intended for dialects entities from
226 /// which may be produced when applying the transformations specified by ops
227 /// registered by this extension.
228 template <typename DialectTy>
230 dialectLoaders.push_back(
231 [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
232 }
233
234 /// Declares that the transformations associated with the operations
235 /// registered by this dialect extension may produce operations from the
236 /// dialect provided as template parameter while processing payload IR that
237 /// does not contain the operations from said dialect. This is similar to
238 /// dependent dialects of a pass. These dialects will be loaded along with the
239 /// transform dialect unless the extension is in the build-only mode.
240 template <typename DialectTy>
242 generatedDialectLoaders.push_back(
243 [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
244 }
245
246private:
247 /// Callbacks performing extension initialization, e.g., registering ops,
248 /// types and defining the additional data.
249 SmallVector<Initializer> initializers;
250
251 /// Callbacks loading the dependent dialects, i.e. the dialect needed for the
252 /// extension ops.
253 SmallVector<DialectLoader> dialectLoaders;
254
255 /// Callbacks loading the generated dialects, i.e. the dialects produced when
256 /// applying the transformations.
257 SmallVector<DialectLoader> generatedDialectLoaders;
258
259 /// Indicates that the extension is in build-only mode.
260 bool buildOnly;
261};
262
263template <typename OpTy>
264void TransformDialect::addOperationIfNotRegistered() {
265 std::optional<RegisteredOperationName> opName =
267 if (!opName) {
268 addOperations<OpTy>();
269#if LLVM_ENABLE_ABI_BREAKING_CHECKS
270 StringRef name = OpTy::getOperationName();
271 detail::checkImplementsTransformOpInterface(name, getContext());
272#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
273 return;
274 }
275
276 if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
277 return;
278
279 reportDuplicateOpRegistration(OpTy::getOperationName());
280}
281
282template <typename AttrTy>
283void TransformDialect::addAttributeIfNotRegistered() {
284 // Use the address of the parse method as a proxy for identifying whether we
285 // are registering the same type class for the same mnemonic.
286 StringRef mnemonic = AttrTy::getMnemonic();
287 auto [it, inserted] =
288 attributeParsingHooks.try_emplace(mnemonic, AttrTy::parse);
289 if (!inserted) {
290 const ExtensionAttributeParsingHook &parsingHook = it->getValue();
291 if (parsingHook != &AttrTy::parse)
292 reportDuplicateAttributeRegistration(mnemonic);
293 else
294 return;
295 }
296 attributePrintingHooks.try_emplace(
298 +[](mlir::Attribute attribute, AsmPrinter &printer) {
299 printer << AttrTy::getMnemonic();
300 cast<AttrTy>(attribute).print(printer);
301 });
302 addAttributes<AttrTy>();
303}
304
305template <typename Type>
306void TransformDialect::addTypeIfNotRegistered() {
307 // Use the address of the parse method as a proxy for identifying whether we
308 // are registering the same type class for the same mnemonic.
309 StringRef mnemonic = Type::getMnemonic();
310 auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse);
311 if (!inserted) {
312 const ExtensionTypeParsingHook &parsingHook = it->getValue();
313 if (parsingHook != &Type::parse)
314 reportDuplicateTypeRegistration(mnemonic);
315 else
316 return;
317 }
318 typePrintingHooks.try_emplace(
319 TypeID::get<Type>(), +[](mlir::Type type, AsmPrinter &printer) {
320 printer << Type::getMnemonic();
321 cast<Type>(type).print(printer);
322 });
323 addTypes<Type>();
324
325#if LLVM_ENABLE_ABI_BREAKING_CHECKS
326 detail::checkImplementsTransformHandleTypeInterface(TypeID::get<Type>(),
327 getContext());
328#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
329}
330
331template <typename DataTy>
332DataTy &TransformDialect::getOrCreateExtraData() {
333 TypeID typeID = TypeID::get<DataTy>();
334 auto [it, inserted] = extraData.try_emplace(typeID);
335 if (inserted)
336 it->getSecond() = std::make_unique<DataTy>(getContext());
337 return static_cast<DataTy &>(*it->getSecond());
338}
339
/// A wrapper for transform dialect extensions that forces them to be
/// constructed in the build-only mode, i.e. the dialects they declare as
/// generated are not loaded when the extension is applied. `DerivedTy` must
/// accept the `buildOnly` flag in its constructor, e.g. by re-exporting the
/// constructor of `TransformDialectExtension`.
template <typename DerivedTy>
class BuildOnly : public DerivedTy {
public:
  BuildOnly() : DerivedTy(/*buildOnly=*/true) {}
};
347
348} // namespace transform
349} // namespace mlir
350
351#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
true
Given two iterators into the same block, return `true` if `a` is before `b`.
b getContext())
… if copies could not be generated due to yet-unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point). Since `begin` can itself be invalidated due to the memref rewriting done from this method …
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void loadDialect()
Load a dialect in the context.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
TransformDialectData(MLIRContext *ctx)
Forward the TypeID of the derived class to the base.
TransformDialectExtension(bool buildOnly=false)
Extension constructor.
TransformDialectExtension< DerivedTy, ExtraDialects... > Base
void registerTransformOps()
Injects the operations into the Transform dialect.
void declareDependentDialect()
Declares that this Transform dialect extension depends on the dialect provided as template parameter.
void declareGeneratedDialect()
Declares that the transformations associated with the operations registered by this dialect extension...
void addDialectDataInitializer(Func &&func)
Registers the given function as one of the initializers for the dialect-owned data of the kind specif...
void apply(MLIRContext *context, TransformDialect *transformDialect, ExtraDialects *...) const final
Extension application hook.
void registerTypes()
Injects the types into the Transform dialect.
void addCustomInitializationStep(Func &&func)
Registers a custom initialization step to be performed when the extension is applied to the dialect w...
void registerAttributes()
Injects the attributes into the Transform dialect.
Concrete base class for CRTP TransformDialectData.
MLIRContext * getContext() const
Return the MLIR context.
TransformDialectDataBase(TypeID typeID, MLIRContext *ctx)
Must be called by the subclass with the appropriate type ID.
TypeID getTypeID() const
Returns the dynamic type ID of the subclass.
AttrTypeReplacer.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...