MLIR 23.0.0git
MLIRContext.h
Go to the documentation of this file.
//===- MLIRContext.h - MLIR Global Context Class ----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_MLIRCONTEXT_H
10#define MLIR_IR_MLIRCONTEXT_H
11
12#include "mlir/Support/LLVM.h"
13#include "mlir/Support/TypeID.h"
14#include "llvm/ADT/ArrayRef.h"
15#include <functional>
16#include <memory>
17#include <vector>
18
namespace llvm {
// Forward declaration: MLIRContext only refers to the thread pool by reference
// (setThreadPool / getThreadPool), so its full definition is not needed here.
class ThreadPoolInterface;
} // namespace llvm
22
23namespace mlir {
24namespace tracing {
25class Action;
26}
28class Dialect;
29class DialectRegistry;
30class DynamicDialect;
32class Location;
33class MLIRContextImpl;
35class StorageUniquer;
36class IRUnit;
37namespace remark::detail {
38class RemarkEngine;
39} // namespace remark::detail
40
41/// MLIRContext is the top-level object for a collection of MLIR operations. It
42/// holds immortal uniqued objects like types, and the tables used to unique
43/// them.
44///
45/// MLIRContext gets a redundant "MLIR" prefix because otherwise it ends up with
46/// a very generic name ("Context") and because it is uncommon for clients to
47/// interact with it.
48///
49/// The context wrap some multi-threading facilities, and in particular by
50/// default it will implicitly create a thread pool.
51/// This can be undesirable if multiple context exists at the same time or if a
52/// process will be long-lived and create and destroy contexts.
53/// To control better thread spawning, an externally owned ThreadPool can be
54/// injected in the context. For example:
55///
56/// llvm::DefaultThreadPool myThreadPool;
57/// while (auto *request = nextCompilationRequests()) {
58/// MLIRContext ctx(registry, MLIRContext::Threading::DISABLED);
59/// ctx.setThreadPool(myThreadPool);
60/// processRequest(request, cxt);
61/// }
62///
64public:
65 enum class Threading { DISABLED, ENABLED };
66 /// Create a new Context.
67 explicit MLIRContext(Threading multithreading = Threading::ENABLED);
68 explicit MLIRContext(const DialectRegistry &registry,
69 Threading multithreading = Threading::ENABLED);
71
72 /// Return information about all IR dialects loaded in the context.
73 std::vector<Dialect *> getLoadedDialects();
74
75 /// Return the dialect registry associated with this context.
77
78 /// Append the contents of the given dialect registry to the registry
79 /// associated with this context.
80 void appendDialectRegistry(const DialectRegistry &registry);
81
82 /// Return information about all available dialects in the registry in this
83 /// context.
84 std::vector<StringRef> getAvailableDialects();
85
86 /// Get a registered IR dialect with the given namespace. If an exact match is
87 /// not found, then return nullptr.
88 Dialect *getLoadedDialect(StringRef name);
89
90 /// Get a registered IR dialect for the given derived dialect type. The
91 /// derived type must provide a static 'getDialectNamespace' method.
92 template <typename T>
94 return static_cast<T *>(getLoadedDialect(T::getDialectNamespace()));
95 }
96
97 /// Get (or create) a dialect for the given derived dialect type. The derived
98 /// type must provide a static 'getDialectNamespace' method.
99 template <typename T>
101 return static_cast<T *>(
102 getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
103 std::unique_ptr<T> dialect(new T(this));
104 return dialect;
105 }));
106 }
107
108 /// Load a dialect in the context.
109 template <typename Dialect>
110 void loadDialect() {
111 // Do not load the dialect if it is currently loading. This can happen if a
112 // dialect initializer triggers loading the same dialect recursively.
113 if (!isDialectLoading(Dialect::getDialectNamespace()))
115 }
116
117 /// Load a list dialects in the context.
118 template <typename Dialect, typename OtherDialect, typename... MoreDialects>
119 void loadDialect() {
121 loadDialect<OtherDialect, MoreDialects...>();
122 }
123
124 /// Get (or create) a dynamic dialect for the given name.
126 getOrLoadDynamicDialect(StringRef dialectNamespace,
127 function_ref<void(DynamicDialect *)> ctor);
128
129 /// Load all dialects available in the registry in this context.
131
132 /// Get (or create) a dialect for the given derived dialect name.
133 /// The dialect will be loaded from the registry if no dialect is found.
134 /// If no dialect is loaded for this name and none is available in the
135 /// registry, returns nullptr.
136 Dialect *getOrLoadDialect(StringRef name);
137
138 /// Return true if we allow to create operation for unregistered dialects.
139 [[nodiscard]] bool allowsUnregisteredDialects();
140
141 /// Enables creating operations in unregistered dialects.
142 /// This option is **heavily discouraged**: it is convenient during testing
143 /// but it is not a good practice to use it in production code. Some system
144 /// invariants can be broken (like loading a dialect after creating
145 /// operations) without being caught by assertions or other means.
146 void allowUnregisteredDialects(bool allow = true);
147
148 /// Return true if multi-threading is enabled by the context.
150
151 /// Set the flag specifying if multi-threading is disabled by the context.
152 /// The command line debugging flag `--mlir-disable-threading` is overriding
153 /// this call and making it a no-op!
154 void disableMultithreading(bool disable = true);
155 void enableMultithreading(bool enable = true) {
156 disableMultithreading(!enable);
157 }
158
159 /// Set a new thread pool to be used in this context. This method requires
160 /// that multithreading is disabled for this context prior to the call. This
161 /// allows to share a thread pool across multiple contexts, as well as
162 /// decoupling the lifetime of the threads from the contexts. The thread pool
163 /// must outlive the context. Multi-threading will be enabled as part of this
164 /// method.
165 /// The command line debugging flag `--mlir-disable-threading` will still
166 /// prevent threading from being enabled and threading won't be enabled after
167 /// this call in this case.
168 void setThreadPool(llvm::ThreadPoolInterface &pool);
169
170 /// Return the number of threads used by the thread pool in this context. The
171 /// number of computed hardware threads can change over the lifetime of a
172 /// process based on affinity changes, so users should use the number of
173 /// threads actually in the thread pool for dispatching work. Returns 1 if
174 /// multithreading is disabled.
175 unsigned getNumThreads();
176
177 /// Return the thread pool used by this context. This method requires that
178 /// multithreading be enabled within the context, and should generally not be
179 /// used directly. Users should instead prefer the threading utilities within
180 /// Threading.h.
181 llvm::ThreadPoolInterface &getThreadPool();
182
183 /// Return true if we should attach the operation to diagnostics emitted via
184 /// Operation::emit.
186
187 /// Set the flag specifying if we should attach the operation to diagnostics
188 /// emitted via Operation::emit.
189 void printOpOnDiagnostic(bool enable);
190
191 /// Return true if we should attach the current stacktrace to diagnostics when
192 /// emitted.
194
195 /// Set the flag specifying if we should attach the current stacktrace when
196 /// emitting diagnostics.
197 void printStackTraceOnDiagnostic(bool enable);
198
199 /// Return a sorted array containing the information about all registered
200 /// operations.
202
203 /// Return a sorted array containing the information for registered operations
204 /// filtered by dialect name.
206 getRegisteredOperationsByDialect(StringRef dialectName);
207
208 /// Return true if this operation name is registered in this context.
209 bool isOperationRegistered(StringRef name);
210
211 // This is effectively private given that only MLIRContext.cpp can see the
212 // MLIRContextImpl type.
214 const MLIRContextImpl &getImpl() const { return *impl; }
215
216 /// Returns the diagnostic engine for this context.
218
219 /// Returns the remark engine for this context, or nullptr if none has been
220 /// set.
222
223 /// Set the remark engine for this context.
224 void setRemarkEngine(std::unique_ptr<remark::detail::RemarkEngine> engine);
225
226 /// Returns the storage uniquer used for creating affine constructs.
228
229 /// Returns the storage uniquer used for constructing type storage instances.
230 /// This should not be used directly.
232
233 /// Returns the storage uniquer used for constructing attribute storage
234 /// instances. This should not be used directly.
236
237 /// These APIs are tracking whether the context will be used in a
238 /// multithreading environment: this has no effect other than enabling
239 /// assertions on misuses of some APIs.
242
243 /// Get a dialect for the provided namespace and TypeID: abort the program if
244 /// a dialect exist for this namespace with different TypeID. If a dialect has
245 /// not been loaded for this namespace/TypeID yet, use the provided ctor to
246 /// create one on the fly and load it. Returns a pointer to the dialect owned
247 /// by the context.
248 /// The use of this method is in general discouraged in favor of
249 /// 'getOrLoadDialect<DialectClass>()'.
250 Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
251 function_ref<std::unique_ptr<Dialect>()> ctor);
252
253 /// Returns a hash of the registry of the context that may be used to give
254 /// a rough indicator of if the state of the context registry has changed. The
255 /// context registry correlates to loaded dialects and their entities
256 /// (attributes, operations, types, etc.).
257 llvm::hash_code getRegistryHash();
258
259 //===--------------------------------------------------------------------===//
260 // Action API
261 //===--------------------------------------------------------------------===//
262
263 /// Signatures for the action handler that can be registered with the context.
264 using HandlerTy =
265 std::function<void(function_ref<void()>, const tracing::Action &)>;
266
267 /// Register a handler for handling actions that are dispatched through this
268 /// context. A nullptr handler can be set to disable a previously set handler.
269 void registerActionHandler(HandlerTy handler);
270
271 /// Return a reference to the currently registered action handler. Its target
272 /// can be used to gain access to the handler's state, if any.
273 const HandlerTy &getActionHandler() const;
275
276 /// Return true if a valid ActionHandler is set.
277 bool hasActionHandler();
278
279 /// Dispatch the provided action to the handler if any, or just execute it.
280 void executeAction(function_ref<void()> actionFn,
281 const tracing::Action &action) {
282 if (LLVM_UNLIKELY(hasActionHandler()))
283 executeActionInternal(actionFn, action);
284 else
285 actionFn();
286 }
287
288 /// Dispatch the provided action to the handler if any, or just execute it.
289 template <typename ActionTy, typename... Args>
290 void executeAction(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
291 Args &&...args) {
292 if (LLVM_UNLIKELY(hasActionHandler()))
293 executeActionInternal<ActionTy, Args...>(actionFn, irUnits,
294 std::forward<Args>(args)...);
295 else
296 actionFn();
297 }
298
299private:
300 /// Return true if the given dialect is currently loading.
301 bool isDialectLoading(StringRef dialectNamespace);
302
303 /// Internal helper for the dispatch method.
304 void executeActionInternal(function_ref<void()> actionFn,
305 const tracing::Action &action);
306
307 /// Internal helper for the dispatch method. We get here after checking that
308 /// there is a handler, for the purpose of keeping this code out-of-line. and
309 /// avoid calling the ctor for the Action unnecessarily.
310 template <typename ActionTy, typename... Args>
311 LLVM_ATTRIBUTE_NOINLINE void
312 executeActionInternal(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
313 Args &&...args) {
314 executeActionInternal(actionFn,
315 ActionTy(irUnits, std::forward<Args>(args)...));
316 }
317
318 const std::unique_ptr<MLIRContextImpl> impl;
319
320 MLIRContext(const MLIRContext &) = delete;
321 void operator=(const MLIRContext &) = delete;
322};
323
324//===----------------------------------------------------------------------===//
325// MLIRContext CommandLine Options
326//===----------------------------------------------------------------------===//
327
328/// Register a set of useful command-line options that can be used to configure
329/// various flags within the MLIRContext. These flags are used when constructing
330/// an MLIR context for initialization.
332
333} // namespace mlir
334
335#endif // MLIR_IR_MLIRCONTEXT_H
This class is the main interface for diagnostics.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A dialect that can be defined at runtime.
IRUnit is a union of the different types of IR objects that constitute the IR structure (other than T...
Definition Unit.h:28
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This is the implementation of the MLIRContext class, using the pImpl idiom.
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
bool shouldPrintStackTraceOnDiagnostic()
Return true if we should attach the current stacktrace to diagnostics when emitted.
unsigned getNumThreads()
Return the number of threads used by the thread pool in this context.
bool isOperationRegistered(StringRef name)
Return true if this operation name is registered in this context.
MLIRContext(Threading multithreading=Threading::ENABLED)
Create a new Context.
void disableMultithreading(bool disable=true)
Set the flag specifying if multi-threading is disabled by the context.
T * getOrLoadDialect()
Get (or create) a dialect for the given derived dialect type.
void printStackTraceOnDiagnostic(bool enable)
Set the flag specifying if we should attach the current stacktrace when emitting diagnostics.
bool hasActionHandler()
Return true if a valid ActionHandler is set.
void setRemarkEngine(std::unique_ptr< remark::detail::RemarkEngine > engine)
Set the remark engine for this context.
void setThreadPool(llvm::ThreadPoolInterface &pool)
Set a new thread pool to be used in this context.
void executeAction(function_ref< void()> actionFn, const tracing::Action &action)
Dispatch the provided action to the handler if any, or just execute it.
const HandlerTy & getActionHandler() const
Return a reference to the currently registered action handler.
void enableMultithreading(bool enable=true)
remark::detail::RemarkEngine * getRemarkEngine()
Returns the remark engine for this context, or nullptr if none has been set.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperationsByDialect(StringRef dialectName)
Return a sorted array containing the information for registered operations filtered by dialect name.
void printOpOnDiagnostic(bool enable)
Set the flag specifying if we should attach the operation to diagnostics emitted via Operation::emit.
void executeAction(function_ref< void()> actionFn, ArrayRef< IRUnit > irUnits, Args &&...args)
Dispatch the provided action to the handler if any, or just execute it.
void registerActionHandler(HandlerTy handler)
Register a handler for handling actions that are dispatched through this context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
llvm::hash_code getRegistryHash()
Returns a hash of the registry of the context that may be used to give a rough indicator of if the st...
void enterMultiThreadedExecution()
These APIs are tracking whether the context will be used in a multithreading environment: this has no...
const DialectRegistry & getDialectRegistry()
Return the dialect registry associated with this context.
void loadDialect()
Load a dialect in the context.
DynamicDialect * getOrLoadDynamicDialect(StringRef dialectNamespace, function_ref< void(DynamicDialect *)> ctor)
Get (or create) a dynamic dialect for the given name.
const MLIRContextImpl & getImpl() const
StorageUniquer & getAttributeUniquer()
Returns the storage uniquer used for constructing attribute storage instances.
StorageUniquer & getAffineUniquer()
Returns the storage uniquer used for creating affine constructs.
std::function< void(function_ref< void()>, const tracing::Action &)> HandlerTy
Signature of the action handler that can be registered with the context.
StorageUniquer & getTypeUniquer()
Returns the storage uniquer used for constructing type storage instances.
llvm::ThreadPoolInterface & getThreadPool()
Return the thread pool used by this context.
std::vector< StringRef > getAvailableDialects()
Return information about all available dialects in the registry in this context.
bool isMultithreadingEnabled()
Return true if multi-threading is enabled by the context.
void allowUnregisteredDialects(bool allow=true)
Enables creating operations in unregistered dialects.
bool allowsUnregisteredDialects()
Return true if operations may be created for unregistered dialects.
DiagnosticEngine & getDiagEngine()
Returns the diagnostic engine for this context.
void loadDialect()
Load a list of dialects in the context.
MLIRContextImpl & getImpl()
void exitMultiThreadedExecution()
bool shouldPrintOpOnDiagnostic()
Return true if we should attach the operation to diagnostics emitted via Operation::emit.
void loadAllAvailableDialects()
Load all dialects available in the registry in this context.
T * getLoadedDialect()
Get a registered IR dialect for the given derived dialect type.
Definition MLIRContext.h:93
This is a "type erased" representation of a registered operation.
A utility class to get or create instances of "storage classes".
This class provides an efficient unique identifier for a specific C++ type.
Definition TypeID.h:107
static TypeID get()
Construct a type info object for the given type T.
Definition TypeID.h:245
An action is a specific action that is to be taken by the compiler, that can be toggled and controlle...
Definition Action.h:38
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:228
Include the generated interface declarations.
void registerMLIRContextCLOptions()
Register a set of useful command-line options that can be used to configure various flags within the ...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147