9 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
10 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/StringMap.h"
36 : typeID(typeID), ctx(ctx) {}
61 template <
typename DerivedTy>
86 #include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
116 template <
typename DerivedTy,
typename... ExtraDialects>
119 using Initializer = std::function<void(TransformDialect *)>;
120 using DialectLoader = std::function<void(
MLIRContext *)>;
126 ExtraDialects *...)
const final {
127 for (
const DialectLoader &loader : dialectLoaders)
133 for (
const DialectLoader &loader : generatedDialectLoaders)
136 for (
const Initializer &
init : initializers)
137 init(transformDialect);
146 : buildOnly(buildOnly) {
157 template <
typename Func>
159 std::function<void(
MLIRContext *)> initializer = func;
160 dialectLoaders.push_back(
173 template <
typename DataTy,
typename Func>
175 static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>,
176 "only classes deriving TransformDialectData are accepted");
178 std::function<void(DataTy &)> initializer = func;
179 initializers.push_back(
180 [
init = std::move(initializer)](TransformDialect *transformDialect) {
181 init(transformDialect->getOrCreateExtraData<DataTy>());
191 template <
typename... OpTys>
193 initializers.push_back([](TransformDialect *transformDialect) {
194 transformDialect->addOperationsChecked<OpTys...>();
203 template <
typename... TypeTys>
205 initializers.push_back([](TransformDialect *transformDialect) {
206 transformDialect->addTypesChecked<TypeTys...>();
218 template <
typename DialectTy>
220 dialectLoaders.push_back(
230 template <
typename DialectTy>
232 generatedDialectLoaders.push_back(
253 template <
typename OpTy>
254 void TransformDialect::addOperationIfNotRegistered() {
255 std::optional<RegisteredOperationName> opName =
258 addOperations<OpTy>();
260 StringRef name = OpTy::getOperationName();
266 if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
269 reportDuplicateOpRegistration(OpTy::getOperationName());
272 template <
typename Type>
273 void TransformDialect::addTypeIfNotRegistered() {
276 StringRef mnemonic = Type::getMnemonic();
277 auto [it, inserted] = typeParsingHooks.try_emplace(mnemonic,
Type::parse);
279 const ExtensionTypeParsingHook &parsingHook = it->getValue();
281 reportDuplicateTypeRegistration(mnemonic);
285 typePrintingHooks.try_emplace(
286 TypeID::get<Type>(), +[](
mlir::Type type, AsmPrinter &printer) {
287 printer << Type::getMnemonic();
288 cast<Type>(type).print(printer);
298 template <
typename DataTy>
299 DataTy &TransformDialect::getOrCreateExtraData() {
300 TypeID typeID = TypeID::get<DataTy>();
301 auto [it, inserted] = extraData.try_emplace(typeID);
303 it->getSecond() = std::make_unique<DataTy>(
getContext());
304 return static_cast<DataTy &
>(*it->getSecond());
309 template <
typename DerivedTy>
static MLIRContext * getContext(OpFoldResult val)
This class represents a dialect extension anchored on the given set of dialects.
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Base class for extensions of the Transform dialect that supports injecting operations into the Transform dialect at load time.
void init()
Hook for derived classes to inject constructor behavior.
TransformDialectExtension(bool buildOnly=false)
Extension constructor.
void registerTransformOps()
Injects the operations into the Transform dialect.
void declareDependentDialect()
Declares that this Transform dialect extension depends on the dialect provided as template parameter.
void declareGeneratedDialect()
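Declares that the transformations associated with the operations registered by this dialect extension may produce operations from the dialect provided as template parameter while processing payload IR that does not contain the operations from said dialect.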
void addDialectDataInitializer(Func &&func)
Registers the given function as one of the initializers for the dialect-owned data of the kind specified as template argument.
void apply(MLIRContext *context, TransformDialect *transformDialect, ExtraDialects *...) const final
Extension application hook.
void registerTypes()
Injects the types into the Transform dialect.
void addCustomInitializationStep(Func &&func)
Registers a custom initialization step to be performed when the extension is applied to the dialect while loading.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.