MLIR  16.0.0git
TransformDialect.h
Go to the documentation of this file.
1 //===- TransformDialect.h - Transform Dialect Definition --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
10 #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
11 
13 #include "mlir/IR/Dialect.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Support/LLVM.h"
16 #include "llvm/ADT/StringMap.h"
17 
18 namespace mlir {
19 namespace transform {
20 #ifndef NDEBUG
21 namespace detail {
22 /// Asserts that the operations provided as template arguments implement the
23 /// TransformOpInterface and MemoryEffectsOpInterface. This must be a dynamic
24 /// assertion since interface implementations may be registered at runtime.
25 template <typename OpTy>
26 static inline void checkImplementsTransformInterface(MLIRContext *context) {
27  // Since the operation is being inserted into the Transform dialect and the
28  // dialect does not implement the interface fallback, only check for the op
29  // itself having the interface implementation.
31  *RegisteredOperationName::lookup(OpTy::getOperationName(), context);
32  assert((opName.hasInterface<TransformOpInterface>() ||
33  opName.hasTrait<OpTrait::IsTerminator>()) &&
34  "non-terminator ops injected into the transform dialect must "
35  "implement TransformOpInterface");
36  assert(opName.hasInterface<MemoryEffectOpInterface>() &&
37  "ops injected into the transform dialect must implement "
38  "MemoryEffectsOpInterface");
39 }
40 } // namespace detail
41 #endif // NDEBUG
42 } // namespace transform
43 } // namespace mlir
44 
45 #include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
46 
47 namespace mlir {
48 namespace transform {
49 
50 /// Base class for extensions of the Transform dialect that supports injecting
51 /// operations into the Transform dialect at load time. Concrete extensions are
52 /// expected to derive this class and register operations in the constructor.
53 /// They can be registered with the DialectRegistry and automatically applied
54 /// to the Transform dialect when it is loaded.
55 ///
56 /// Derived classes are expected to define a `void init()` function in which
57 /// they can call various protected methods of the base class to register
58 /// extension operations and declare their dependencies.
59 ///
60 /// By default, the extension is configured both for construction of the
61 /// Transform IR and for its application to some payload. If only the
62 /// construction is desired, the extension can be switched to "build-only" mode
63 /// that avoids loading the dialects that are only necessary for transforming
64 /// the payload. To perform the switch, the extension must be wrapped into the
65 /// `BuildOnly` class template (see below) when it is registered, as in:
66 ///
67 /// dialectRegistry.addExtension<BuildOnly<MyTransformDialectExt>>();
68 ///
69 /// instead of:
70 ///
71 /// dialectRegistry.addExtension<MyTransformDialectExt>();
72 ///
73 /// Derived classes must reexport the constructor of this class or otherwise
74 /// forward its boolean argument to support this behavior.
75 template <typename DerivedTy, typename... ExtraDialects>
77  : public DialectExtension<DerivedTy, TransformDialect, ExtraDialects...> {
78  using Initializer = std::function<void(TransformDialect *)>;
79  using DialectLoader = std::function<void(MLIRContext *)>;
80 
81 public:
82  /// Extension application hook. Actually loads the dependent dialects and
83  /// registers the additional operations. Not expected to be called directly.
84  void apply(MLIRContext *context, TransformDialect *transformDialect,
85  ExtraDialects *...) const final {
86  for (const DialectLoader &loader : dialectLoaders)
87  loader(context);
88 
89  // Only load generated dialects if the user intends to apply
90  // transformations specified by the extension.
91  if (!buildOnly)
92  for (const DialectLoader &loader : generatedDialectLoaders)
93  loader(context);
94 
95  for (const Initializer &init : opInitializers)
96  init(transformDialect);
97  transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns));
98  }
99 
100 protected:
101  using Base = TransformDialectExtension<DerivedTy, ExtraDialects...>;
102 
103  /// Extension constructor. The argument indicates whether to skip generated
104  /// dialects when applying the extension.
105  explicit TransformDialectExtension(bool buildOnly = false)
106  : buildOnly(buildOnly) {
107  static_cast<DerivedTy *>(this)->init();
108  }
109 
110  /// Hook for derived classes to inject constructor behavior.
111  void init() {}
112 
113  /// Injects the operations into the Transform dialect. The operations must
114  /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the
115  /// implementations must be already available when the operation is injected.
116  template <typename... OpTys>
118  opInitializers.push_back([](TransformDialect *transformDialect) {
119  transformDialect->addOperationsChecked<OpTys...>();
120  });
121  }
122 
123  /// Declares that this Transform dialect extension depends on the dialect
124  /// provided as template parameter. When the Transform dialect is loaded,
125  /// dependent dialects will be loaded as well. This is intended for dialects
126  /// that contain attributes and types used in creation and canonicalization of
127  /// the injected operations, similarly to how the dialect definition may list
128  /// dependent dialects. This is *not* intended for dialects entities from
129  /// which may be produced when applying the transformations specified by ops
130  /// registered by this extension.
131  template <typename DialectTy>
133  dialectLoaders.push_back(
134  [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
135  }
136 
137  /// Declares that the transformations associated with the operations
138  /// registered by this dialect extension may produce operations from the
139  /// dialect provided as template parameter while processing payload IR that
140  /// does not contain the operations from said dialect. This is similar to
141  /// dependent dialects of a pass. These dialects will be loaded along with the
142  /// transform dialect unless the extension is in the build-only mode.
143  template <typename DialectTy>
145  generatedDialectLoaders.push_back(
146  [](MLIRContext *context) { context->loadDialect<DialectTy>(); });
147  }
148 
149  /// Injects the named constraint to make it available for use with the
150  /// PDLMatchOp in the transform dialect.
151  void registerPDLMatchConstraintFn(StringRef name,
152  PDLConstraintFunction &&fn) {
153  pdlMatchConstraintFns.try_emplace(name,
154  std::forward<PDLConstraintFunction>(fn));
155  }
156  template <typename ConstraintFnTy>
157  void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn) {
158  pdlMatchConstraintFns.try_emplace(
160  std::forward<ConstraintFnTy>(fn)));
161  }
162 
163 private:
164  SmallVector<Initializer> opInitializers;
165 
166  /// Callbacks loading the dependent dialects, i.e. the dialect needed for the
167  /// extension ops.
168  SmallVector<DialectLoader> dialectLoaders;
169 
170  /// Callbacks loading the generated dialects, i.e. the dialects produced when
171  /// applying the transformations.
172  SmallVector<DialectLoader> generatedDialectLoaders;
173 
174  /// A list of constraints that should be made available to PDL patterns
175  /// processed by PDLMatchOp in the Transform dialect.
176  ///
177  /// Declared as mutable so its contents can be moved in the `apply` const
178  /// method, which is only called once.
179  mutable llvm::StringMap<PDLConstraintFunction> pdlMatchConstraintFns;
180 
181  /// Indicates that the extension is in build-only mode.
182  bool buildOnly;
183 };
184 
185 /// A wrapper for transform dialect extensions that forces them to be
186 /// constructed in the build-only mode.
187 template <typename DerivedTy>
188 class BuildOnly : public DerivedTy {
189 public:
190  BuildOnly() : DerivedTy(/*buildOnly=*/true) {}
191 };
192 
193 } // namespace transform
194 } // namespace mlir
195 
196 #include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc"
197 
198 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
Include the generated interface declarations.
bool hasTrait() const
Returns true if the operation has a particular trait.
A wrapper for transform dialect extensions that forces them to be constructed in the build-only mode...
TransformDialectExtension(bool buildOnly=false)
Extension constructor.
void declareGeneratedDialect()
Declares that the transformations associated with the operations registered by this dialect extension...
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
void loadDialect()
Load a dialect in the context.
Definition: MLIRContext.h:102
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:707
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:806
void init()
Hook for derived classes to inject constructor behavior.
This class represents a dialect extension anchored on the given set of dialects.
static Optional< RegisteredOperationName > lookup(StringRef name, MLIRContext *ctx)
Lookup the registered operation information for the given operation.
void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn)
void registerTransformOps()
Injects the operations into the Transform dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
This is a "type erased" representation of a registered operation.
static void checkImplementsTransformInterface(MLIRContext *context)
Asserts that the operations provided as template arguments implement the TransformOpInterface and Mem...
void declareDependentDialect()
Declares that this Transform dialect extension depends on the dialect provided as template parameter...
void apply(MLIRContext *context, TransformDialect *transformDialect, ExtraDialects *...) const final
Extension application hook.
void registerPDLMatchConstraintFn(StringRef name, PDLConstraintFunction &&fn)
Injects the named constraint to make it available for use with the PDLMatchOp in the transform dialec...
std::enable_if_t< std::is_convertible< ConstraintFnT, PDLConstraintFunction >::value, PDLConstraintFunction > buildConstraintFn(ConstraintFnT &&constraintFn)
Build a constraint function from the given function ConstraintFnT.