MLIR  20.0.0git
Pass.cpp
Go to the documentation of this file.
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Pass.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // PassManager/OpPassManager APIs.
22 //===----------------------------------------------------------------------===//
23 
24 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
25  return wrap(new PassManager(unwrap(ctx)));
26 }
27 
28 MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
29  MlirStringRef anchorOp) {
30  return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp)));
31 }
32 
33 void mlirPassManagerDestroy(MlirPassManager passManager) {
34  delete unwrap(passManager);
35 }
36 
37 MlirOpPassManager
38 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
39  return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
40 }
41 
42 MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
43  MlirOperation op) {
44  return wrap(unwrap(passManager)->run(unwrap(op)));
45 }
46 
47 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) {
48  return unwrap(passManager)->enableIRPrinting();
49 }
50 
51 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
52  unwrap(passManager)->enableVerifier(enable);
53 }
54 
55 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
56  MlirStringRef operationName) {
57  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
58 }
59 
60 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
61  MlirStringRef operationName) {
62  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
63 }
64 
65 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
66  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
67 }
68 
69 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
70  MlirPass pass) {
71  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
72 }
73 
74 MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
75  MlirStringRef pipelineElements,
76  MlirStringCallback callback,
77  void *userData) {
78  detail::CallbackOstream stream(callback, userData);
79  return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
80  stream));
81 }
82 
83 void mlirPrintPassPipeline(MlirOpPassManager passManager,
84  MlirStringCallback callback, void *userData) {
85  detail::CallbackOstream stream(callback, userData);
86  unwrap(passManager)->printAsTextualPipeline(stream);
87 }
88 
89 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
90  MlirStringRef pipeline,
91  MlirStringCallback callback,
92  void *userData) {
93  detail::CallbackOstream stream(callback, userData);
94  FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
95  if (succeeded(pm))
96  *unwrap(passManager) = std::move(*pm);
97  return wrap(pm);
98 }
99 
100 //===----------------------------------------------------------------------===//
101 // External Pass API.
102 //===----------------------------------------------------------------------===//
103 
104 namespace mlir {
105 class ExternalPass;
106 } // namespace mlir
108 
109 namespace mlir {
110 /// This pass class wraps external passes defined in other languages using the
111 /// MLIR C-interface
112 class ExternalPass : public Pass {
113 public:
114  ExternalPass(TypeID passID, StringRef name, StringRef argument,
115  StringRef description, std::optional<StringRef> opName,
116  ArrayRef<MlirDialectHandle> dependentDialects,
117  MlirExternalPassCallbacks callbacks, void *userData)
118  : Pass(passID, opName), id(passID), name(name), argument(argument),
119  description(description), dependentDialects(dependentDialects),
120  callbacks(callbacks), userData(userData) {
121  callbacks.construct(userData);
122  }
123 
124  ~ExternalPass() override { callbacks.destruct(userData); }
125 
126  StringRef getName() const override { return name; }
127  StringRef getArgument() const override { return argument; }
128  StringRef getDescription() const override { return description; }
129 
130  void getDependentDialects(DialectRegistry &registry) const override {
131  MlirDialectRegistry cRegistry = wrap(&registry);
132  for (MlirDialectHandle dialect : dependentDialects)
133  mlirDialectHandleInsertDialect(dialect, cRegistry);
134  }
135 
137 
138 protected:
139  LogicalResult initialize(MLIRContext *ctx) override {
140  if (callbacks.initialize)
141  return unwrap(callbacks.initialize(wrap(ctx), userData));
142  return success();
143  }
144 
145  bool canScheduleOn(RegisteredOperationName opName) const override {
146  if (std::optional<StringRef> specifiedOpName = getOpName())
147  return opName.getStringRef() == specifiedOpName;
148  return true;
149  }
150 
151  void runOnOperation() override {
152  callbacks.run(wrap(getOperation()), wrap(this), userData);
153  }
154 
155  std::unique_ptr<Pass> clonePass() const override {
156  void *clonedUserData = callbacks.clone(userData);
157  return std::make_unique<ExternalPass>(id, name, argument, description,
158  getOpName(), dependentDialects,
159  callbacks, clonedUserData);
160  }
161 
162 private:
163  TypeID id;
164  std::string name;
165  std::string argument;
166  std::string description;
167  std::vector<MlirDialectHandle> dependentDialects;
168  MlirExternalPassCallbacks callbacks;
169  void *userData;
170 };
171 } // namespace mlir
172 
173 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
174  MlirStringRef argument,
175  MlirStringRef description, MlirStringRef opName,
176  intptr_t nDependentDialects,
177  MlirDialectHandle *dependentDialects,
178  MlirExternalPassCallbacks callbacks,
179  void *userData) {
180  return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
181  unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
182  opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
183  : std::nullopt,
184  {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
185  userData)));
186 }
187 
188 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
189  unwrap(pass)->signalPassFailure();
190 }
MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the provided OpPassManager, the nested passmanager will only run on opera...
Definition: Pass.cpp:60
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, MlirStringRef argument, MlirStringRef description, MlirStringRef opName, intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Creates an external MlirPass that calls the supplied callbacks using the supplied userData.
Definition: Pass.cpp:173
MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
Definition: Pass.cpp:24
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition: Pass.cpp:51
MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the top-level PassManager, the nested passmanager will only run on operat...
Definition: Pass.cpp:55
void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition: Pass.cpp:33
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager.
Definition: Pass.cpp:89
MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition: Pass.cpp:38
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
Definition: Pass.cpp:42
void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided mlirOpPassManager.
Definition: Pass.cpp:69
void mlirExternalPassSignalFailure(MlirExternalPass pass)
This signals that the pass has failed.
Definition: Pass.cpp:188
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding user...
Definition: Pass.cpp:83
void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided top-level mlirPassManager.
Definition: Pass.cpp:65
MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
Definition: Pass.cpp:28
MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
Definition: Pass.cpp:74
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager)
Enable mlir-print-ir-after-all.
Definition: Pass.cpp:47
#define DEFINE_C_API_PTR_METHODS(name, cpptype)
Definition: Wrap.h:25
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This pass class wraps external passes defined in other languages using the MLIR C-interface.
Definition: Pass.cpp:112
StringRef getArgument() const override
Return the command line argument used when registering this pass.
Definition: Pass.cpp:127
~ExternalPass() override
Definition: Pass.cpp:124
void signalPassFailure()
Definition: Pass.cpp:136
ExternalPass(TypeID passID, StringRef name, StringRef argument, StringRef description, std::optional< StringRef > opName, ArrayRef< MlirDialectHandle > dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Definition: Pass.cpp:114
StringRef getDescription() const override
Return the command line description used when registering this pass.
Definition: Pass.cpp:128
bool canScheduleOn(RegisteredOperationName opName) const override
Indicate if the current pass can be scheduled on the given operation type.
Definition: Pass.cpp:145
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
Definition: Pass.cpp:151
StringRef getName() const override
Returns the derived pass name.
Definition: Pass.cpp:126
LogicalResult initialize(MLIRContext *ctx) override
Initialize any complex state necessary for running this pass.
Definition: Pass.cpp:139
std::unique_ptr< Pass > clonePass() const override
Create a copy of this pass, ignoring statistics and options.
Definition: Pass.cpp:155
void getDependentDialects(DialectRegistry &registry) const override
Register dependent dialects for the current pass.
Definition: Pass.cpp:130
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:47
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
The main pass manager and pipeline builder.
Definition: PassManager.h:231
The abstract base pass class.
Definition: Pass.h:51
std::optional< StringRef > getOpName() const
Returns the name of the operation that this pass operates on, or std::nullopt if this is a generic Op...
Definition: Pass.h:83
Operation * getOperation()
Return the current operation being transformed.
Definition: Pass.h:211
void signalPassFailure()
Signal that some invariant was broken when running.
Definition: Pass.h:217
This is a "type erased" representation of a registered operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
A simple raw ostream subclass that forwards write_impl calls to the user-supplied callback together w...
Definition: Utils.h:30
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition: Diagnostics.h:24
MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, MlirDialectRegistry)
Inserts the dialect associated with the provided dialect handle into the provided dialect registry.
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:105
Include the generated interface declarations.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
Structure of external MlirPass callbacks.
Definition: Pass.h:143
void(* run)(MlirOperation op, MlirExternalPass pass, void *userData)
This callback is called when the pass is run.
Definition: Pass.h:164
void *(* clone)(void *userData)
This callback is called when the pass is cloned.
Definition: Pass.h:160
MlirLogicalResult(* initialize)(MlirContext ctx, void *userData)
This callback is optional.
Definition: Pass.h:156
void(* destruct)(void *userData)
This callback is called when the pass is destroyed. This is analogous to a C++ pass destructor.
Definition: Pass.h:150
void(* construct)(void *userData)
This callback is called when the pass is created.
Definition: Pass.h:146
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
size_t length
Length of the fragment.
Definition: Support.h:75