MLIR  16.0.0git
Pass.cpp
Go to the documentation of this file.
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Pass.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 
17 using namespace mlir;
18 
19 //===----------------------------------------------------------------------===//
20 // PassManager/OpPassManager APIs.
21 //===----------------------------------------------------------------------===//
22 
23 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
24  return wrap(new PassManager(unwrap(ctx)));
25 }
26 
27 void mlirPassManagerDestroy(MlirPassManager passManager) {
28  delete unwrap(passManager);
29 }
30 
31 MlirOpPassManager
32 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
33  return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
34 }
35 
36 MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager,
37  MlirModule module) {
38  return wrap(unwrap(passManager)->run(unwrap(module)));
39 }
40 
41 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
42  return unwrap(passManager)->enableIRPrinting();
43 }
44 
45 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
46  unwrap(passManager)->enableVerifier(enable);
47 }
48 
49 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
50  MlirStringRef operationName) {
51  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
52 }
53 
54 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
55  MlirStringRef operationName) {
56  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
57 }
58 
59 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
60  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
61 }
62 
63 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
64  MlirPass pass) {
65  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
66 }
67 
68 void mlirPrintPassPipeline(MlirOpPassManager passManager,
69  MlirStringCallback callback, void *userData) {
70  detail::CallbackOstream stream(callback, userData);
71  unwrap(passManager)->printAsTextualPipeline(stream);
72 }
73 
74 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
75  MlirStringRef pipeline) {
76  // TODO: errors are sent to std::errs() at the moment, we should pass in a
77  // stream and redirect to a diagnostic.
78  return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
79 }
80 
81 //===----------------------------------------------------------------------===//
82 // External Pass API.
83 //===----------------------------------------------------------------------===//
84 
85 namespace mlir {
86 class ExternalPass;
87 } // namespace mlir
89 
90 namespace mlir {
91 /// This pass class wraps external passes defined in other languages using the
92 /// MLIR C-interface
93 class ExternalPass : public Pass {
94 public:
95  ExternalPass(TypeID passID, StringRef name, StringRef argument,
96  StringRef description, Optional<StringRef> opName,
97  ArrayRef<MlirDialectHandle> dependentDialects,
98  MlirExternalPassCallbacks callbacks, void *userData)
99  : Pass(passID, opName), id(passID), name(name), argument(argument),
100  description(description), dependentDialects(dependentDialects),
101  callbacks(callbacks), userData(userData) {
102  callbacks.construct(userData);
103  }
104 
105  ~ExternalPass() override { callbacks.destruct(userData); }
106 
107  StringRef getName() const override { return name; }
108  StringRef getArgument() const override { return argument; }
109  StringRef getDescription() const override { return description; }
110 
111  void getDependentDialects(DialectRegistry &registry) const override {
112  MlirDialectRegistry cRegistry = wrap(&registry);
113  for (MlirDialectHandle dialect : dependentDialects)
114  mlirDialectHandleInsertDialect(dialect, cRegistry);
115  }
116 
118 
119 protected:
121  if (callbacks.initialize)
122  return unwrap(callbacks.initialize(wrap(ctx), userData));
123  return success();
124  }
125 
126  bool canScheduleOn(RegisteredOperationName opName) const override {
127  if (Optional<StringRef> specifiedOpName = getOpName())
128  return opName.getStringRef() == specifiedOpName;
129  return true;
130  }
131 
132  void runOnOperation() override {
133  callbacks.run(wrap(getOperation()), wrap(this), userData);
134  }
135 
136  std::unique_ptr<Pass> clonePass() const override {
137  void *clonedUserData = callbacks.clone(userData);
138  return std::make_unique<ExternalPass>(id, name, argument, description,
139  getOpName(), dependentDialects,
140  callbacks, clonedUserData);
141  }
142 
143 private:
144  TypeID id;
145  std::string name;
146  std::string argument;
147  std::string description;
148  std::vector<MlirDialectHandle> dependentDialects;
149  MlirExternalPassCallbacks callbacks;
150  void *userData;
151 };
152 } // namespace mlir
153 
154 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
155  MlirStringRef argument,
156  MlirStringRef description, MlirStringRef opName,
157  intptr_t nDependentDialects,
158  MlirDialectHandle *dependentDialects,
159  MlirExternalPassCallbacks callbacks,
160  void *userData) {
161  return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
162  unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
163  opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None,
164  {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
165  userData)));
166 }
167 
168 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
169  unwrap(pass)->signalPassFailure();
170 }
Include the generated interface declarations.
This pass class wraps external passes defined in other languages using the MLIR C-interface.
Definition: Pass.cpp:93
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline)
Parse a textual MLIR pass pipeline and add it to the provided OpPassManager.
Definition: Pass.cpp:74
#define DEFINE_C_API_PTR_METHODS(name, cpptype)
Definition: Wrap.h:25
MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the provided OpPassManager, the nested passmanager will only run on opera...
Definition: Pass.cpp:54
ExternalPass(TypeID passID, StringRef name, StringRef argument, StringRef description, Optional< StringRef > opName, ArrayRef< MlirDialectHandle > dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Definition: Pass.cpp:95
The main pass manager and pipeline builder.
Definition: PassManager.h:210
void *(* clone)(void *userData)
This callback is called when the pass is cloned.
Definition: Pass.h:147
MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the top-level PassManager, the nested passmanager will only run on operat...
Definition: Pass.cpp:49
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, MlirStringRef argument, MlirStringRef description, MlirStringRef opName, intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Creates an external MlirPass that calls the supplied callbacks using the supplied userData...
Definition: Pass.cpp:154
void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition: Pass.cpp:27
MlirLogicalResult(* initialize)(MlirContext ctx, void *userData)
This callback is optional.
Definition: Pass.h:143
void mlirExternalPassSignalFailure(MlirExternalPass pass)
This signals that the pass has failed.
Definition: Pass.cpp:168
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
Definition: Pass.cpp:132
A simple raw ostream subclass that forwards write_impl calls to the user-supplied callback together w...
Definition: Utils.h:30
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager)
Enable mlir-print-ir-after-all.
Definition: Pass.cpp:41
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
void signalPassFailure()
Signal that some invariant was broken when running.
Definition: Pass.h:212
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
void signalPassFailure()
Definition: Pass.cpp:117
StringRef getName() const override
Returns the derived pass name.
Definition: Pass.cpp:107
std::unique_ptr< Pass > clonePass() const override
Create a copy of this pass, ignoring statistics and options.
Definition: Pass.cpp:136
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition: Pass.cpp:45
A logical result value, essentially a boolean with named states.
Definition: Support.h:114
Optional< StringRef > getOpName() const
Returns the name of the operation that this pass operates on, or None if this is a generic OperationP...
Definition: Pass.h:85
void(* run)(MlirOperation op, MlirExternalPass pass, void *userData)
This callback is called when the pass is run.
Definition: Pass.h:151
StringRef getArgument() const override
Return the command line argument used when registering this pass.
Definition: Pass.cpp:108
size_t length
Length of the fragment.
Definition: Support.h:73
Structure of external MlirPass callbacks.
Definition: Pass.h:130
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:71
void getDependentDialects(DialectRegistry &registry) const override
Register dependent dialects for the current pass.
Definition: Pass.cpp:111
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition: Diagnostics.h:24
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
Operation * getOperation()
Return the current operation being transformed.
Definition: Pass.h:206
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:103
void(* destruct)(void *userData)
This callback is called when the pass is destroyed. This is analogous to a C++ pass destructor...
Definition: Pass.h:137
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success...
bool canScheduleOn(RegisteredOperationName opName) const override
Indicate if the current pass can be scheduled on the given operation type.
Definition: Pass.cpp:126
StringRef getDescription() const override
Return the command line description used when registering this pass.
Definition: Pass.cpp:109
The abstract base pass class.
Definition: Pass.h:50
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding user...
Definition: Pass.cpp:68
void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided mlirOpPassManager.
Definition: Pass.cpp:63
MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition: Pass.cpp:32
LogicalResult initialize(MLIRContext *ctx) override
Initialize any complex state necessary for running this pass.
Definition: Pass.cpp:120
void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided top-level mlirPassManager.
Definition: Pass.cpp:59
void(* construct)(void *userData)
This callback is called when the pass is created.
Definition: Pass.h:133
This is a "type erased" representation of a registered operation.
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager.
Definition: Pass.cpp:23
~ExternalPass() override
Definition: Pass.cpp:105
MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager, MlirModule module)
Run the provided passManager on the given module.
Definition: Pass.cpp:36
MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, MlirDialectRegistry)
Inserts the dialect associated with the provided dialect handle into the provided dialect registry...