9#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
10#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/StringMap.h"
36 : typeID(typeID), ctx(ctx) {}
61template <
typename DerivedTy>
69#if LLVM_ENABLE_ABI_BREAKING_CHECKS
74void checkImplementsTransformOpInterface(StringRef name,
MLIRContext *context);
79void checkImplementsTransformHandleTypeInterface(
TypeID typeID,
86#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
116template <
typename DerivedTy,
typename... ExtraDialects>
119 using Initializer = std::function<
void(TransformDialect *)>;
126 ExtraDialects *...)
const final {
127 for (
const DialectLoader &loader : dialectLoaders)
133 for (
const DialectLoader &loader : generatedDialectLoaders)
136 for (
const Initializer &
init : initializers)
137 init(transformDialect);
146 : buildOnly(buildOnly) {
157 template <
typename Func>
160 dialectLoaders.push_back(
173 template <
typename DataTy,
typename Func>
175 static_assert(std::is_base_of_v<detail::TransformDialectDataBase, DataTy>,
176 "only classes deriving TransformDialectData are accepted");
178 std::function<
void(DataTy &)> initializer =
func;
179 initializers.push_back(
180 [
init = std::move(initializer)](TransformDialect *transformDialect) {
181 init(transformDialect->getOrCreateExtraData<DataTy>());
191 template <
typename... OpTys>
193 initializers.push_back([](TransformDialect *transformDialect) {
194 transformDialect->addOperationsChecked<OpTys...>();
201 template <
typename... AttrTys>
203 initializers.push_back([](TransformDialect *transformDialect) {
204 transformDialect->addAttributesChecked<AttrTys...>();
213 template <
typename... TypeTys>
215 initializers.push_back([](TransformDialect *transformDialect) {
216 transformDialect->addTypesChecked<TypeTys...>();
228 template <
typename DialectTy>
230 dialectLoaders.push_back(
240 template <
typename DialectTy>
242 generatedDialectLoaders.push_back(
263template <
typename OpTy>
264void TransformDialect::addOperationIfNotRegistered() {
265 std::optional<RegisteredOperationName> opName =
268 addOperations<OpTy>();
269#if LLVM_ENABLE_ABI_BREAKING_CHECKS
270 StringRef name = OpTy::getOperationName();
271 detail::checkImplementsTransformOpInterface(name,
getContext());
279 reportDuplicateOpRegistration(OpTy::getOperationName());
282template <
typename AttrTy>
283void TransformDialect::addAttributeIfNotRegistered() {
286 StringRef mnemonic = AttrTy::getMnemonic();
288 attributeParsingHooks.try_emplace(mnemonic, AttrTy::parse);
290 const ExtensionAttributeParsingHook &parsingHook = it->getValue();
291 if (parsingHook != &AttrTy::parse)
292 reportDuplicateAttributeRegistration(mnemonic);
296 attributePrintingHooks.try_emplace(
298 +[](mlir::Attribute attribute, AsmPrinter &printer) {
299 printer << AttrTy::getMnemonic();
300 cast<AttrTy>(attribute).print(printer);
302 addAttributes<AttrTy>();
305template <
typename Type>
306void TransformDialect::addTypeIfNotRegistered() {
309 StringRef mnemonic = Type::getMnemonic();
310 auto [it,
inserted] = typeParsingHooks.try_emplace(mnemonic, Type::parse);
312 const ExtensionTypeParsingHook &parsingHook = it->getValue();
313 if (parsingHook != &Type::parse)
314 reportDuplicateTypeRegistration(mnemonic);
318 typePrintingHooks.try_emplace(
320 printer << Type::getMnemonic();
321 cast<Type>(type).print(printer);
325#if LLVM_ENABLE_ABI_BREAKING_CHECKS
331template <
typename DataTy>
332DataTy &TransformDialect::getOrCreateExtraData() {
334 auto [it,
inserted] = extraData.try_emplace(typeID);
336 it->getSecond() = std::make_unique<DataTy>(
getContext());
337 return static_cast<DataTy &
>(*it->getSecond());
342template <
typename DerivedTy>
true
Given two iterators into the same block, return "true" if a is before `b.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
MLIRContext is the top-level object for a collection of MLIR operations.
void loadDialect()
Load a dialect in the context.
static std::optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
This class provides an efficient unique identifier for a specific C++ type.
static TypeID get()
Construct a type info object for the given type T.
TransformDialectExtension(bool buildOnly=false)
Extension constructor.
TransformDialectExtension< DerivedTy, ExtraDialects... > Base
void registerTransformOps()
Injects the operations into the Transform dialect.
void declareDependentDialect()
Declares that this Transform dialect extension depends on the dialect provided as template parameter.
void declareGeneratedDialect()
Declares that the transformations associated with the operations registered by this dialect extension...
void addDialectDataInitializer(Func &&func)
Registers the given function as one of the initializers for the dialect-owned data of the kind specif...
void apply(MLIRContext *context, TransformDialect *transformDialect, ExtraDialects *...) const final
Extension application hook.
void registerTypes()
Injects the types into the Transform dialect.
void addCustomInitializationStep(Func &&func)
Registers a custom initialization step to be performed when the extension is applied to the dialect w...
void registerAttributes()
Injects the attributes into the Transform dialect.
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...