MLIR  20.0.0git
Pass.cpp
Go to the documentation of this file.
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Pass.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // PassManager/OpPassManager APIs.
22 //===----------------------------------------------------------------------===//
23 
24 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
25  return wrap(new PassManager(unwrap(ctx)));
26 }
27 
28 MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
29  MlirStringRef anchorOp) {
30  return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp)));
31 }
32 
33 void mlirPassManagerDestroy(MlirPassManager passManager) {
34  delete unwrap(passManager);
35 }
36 
37 MlirOpPassManager
38 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
39  return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
40 }
41 
42 MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
43  MlirOperation op) {
44  return wrap(unwrap(passManager)->run(unwrap(op)));
45 }
46 
47 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
48  bool printBeforeAll, bool printAfterAll,
49  bool printModuleScope,
50  bool printAfterOnlyOnChange,
51  bool printAfterOnlyOnFailure) {
52  auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
53  return printBeforeAll;
54  };
55  auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
56  return printAfterAll;
57  };
58  return unwrap(passManager)
59  ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
60  printModuleScope, printAfterOnlyOnChange,
61  printAfterOnlyOnFailure);
62 }
63 
64 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
65  unwrap(passManager)->enableVerifier(enable);
66 }
67 
68 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
69  MlirStringRef operationName) {
70  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
71 }
72 
73 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
74  MlirStringRef operationName) {
75  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
76 }
77 
78 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
79  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
80 }
81 
82 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
83  MlirPass pass) {
84  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
85 }
86 
87 MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
88  MlirStringRef pipelineElements,
89  MlirStringCallback callback,
90  void *userData) {
91  detail::CallbackOstream stream(callback, userData);
92  return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
93  stream));
94 }
95 
96 void mlirPrintPassPipeline(MlirOpPassManager passManager,
97  MlirStringCallback callback, void *userData) {
98  detail::CallbackOstream stream(callback, userData);
99  unwrap(passManager)->printAsTextualPipeline(stream);
100 }
101 
102 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
103  MlirStringRef pipeline,
104  MlirStringCallback callback,
105  void *userData) {
106  detail::CallbackOstream stream(callback, userData);
107  FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
108  if (succeeded(pm))
109  *unwrap(passManager) = std::move(*pm);
110  return wrap(pm);
111 }
112 
113 //===----------------------------------------------------------------------===//
114 // External Pass API.
115 //===----------------------------------------------------------------------===//
116 
117 namespace mlir {
118 class ExternalPass;
119 } // namespace mlir
121 
122 namespace mlir {
123 /// This pass class wraps external passes defined in other languages using the
124 /// MLIR C-interface
125 class ExternalPass : public Pass {
126 public:
127  ExternalPass(TypeID passID, StringRef name, StringRef argument,
128  StringRef description, std::optional<StringRef> opName,
129  ArrayRef<MlirDialectHandle> dependentDialects,
130  MlirExternalPassCallbacks callbacks, void *userData)
131  : Pass(passID, opName), id(passID), name(name), argument(argument),
132  description(description), dependentDialects(dependentDialects),
133  callbacks(callbacks), userData(userData) {
134  callbacks.construct(userData);
135  }
136 
137  ~ExternalPass() override { callbacks.destruct(userData); }
138 
139  StringRef getName() const override { return name; }
140  StringRef getArgument() const override { return argument; }
141  StringRef getDescription() const override { return description; }
142 
143  void getDependentDialects(DialectRegistry &registry) const override {
144  MlirDialectRegistry cRegistry = wrap(&registry);
145  for (MlirDialectHandle dialect : dependentDialects)
146  mlirDialectHandleInsertDialect(dialect, cRegistry);
147  }
148 
150 
151 protected:
152  LogicalResult initialize(MLIRContext *ctx) override {
153  if (callbacks.initialize)
154  return unwrap(callbacks.initialize(wrap(ctx), userData));
155  return success();
156  }
157 
158  bool canScheduleOn(RegisteredOperationName opName) const override {
159  if (std::optional<StringRef> specifiedOpName = getOpName())
160  return opName.getStringRef() == specifiedOpName;
161  return true;
162  }
163 
164  void runOnOperation() override {
165  callbacks.run(wrap(getOperation()), wrap(this), userData);
166  }
167 
168  std::unique_ptr<Pass> clonePass() const override {
169  void *clonedUserData = callbacks.clone(userData);
170  return std::make_unique<ExternalPass>(id, name, argument, description,
171  getOpName(), dependentDialects,
172  callbacks, clonedUserData);
173  }
174 
175 private:
176  TypeID id;
177  std::string name;
178  std::string argument;
179  std::string description;
180  std::vector<MlirDialectHandle> dependentDialects;
181  MlirExternalPassCallbacks callbacks;
182  void *userData;
183 };
184 } // namespace mlir
185 
186 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
187  MlirStringRef argument,
188  MlirStringRef description, MlirStringRef opName,
189  intptr_t nDependentDialects,
190  MlirDialectHandle *dependentDialects,
191  MlirExternalPassCallbacks callbacks,
192  void *userData) {
193  return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
194  unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
195  opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
196  : std::nullopt,
197  {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
198  userData)));
199 }
200 
201 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
202  unwrap(pass)->signalPassFailure();
203 }
MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the provided OpPassManager, the nested passmanager will only run on opera...
Definition: Pass.cpp:73
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, MlirStringRef argument, MlirStringRef description, MlirStringRef opName, intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Creates an external MlirPass that calls the supplied callbacks using the supplied userData.
Definition: Pass.cpp:186
MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
Definition: Pass.cpp:24
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition: Pass.cpp:64
MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the top-level PassManager, the nested passmanager will only run on operat...
Definition: Pass.cpp:68
void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition: Pass.cpp:33
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager.
Definition: Pass.cpp:102
MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition: Pass.cpp:38
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
Definition: Pass.cpp:42
void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided mlirOpPassManager.
Definition: Pass.cpp:82
void mlirExternalPassSignalFailure(MlirExternalPass pass)
This signals that the pass has failed.
Definition: Pass.cpp:201
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding user...
Definition: Pass.cpp:96
void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided top-level mlirPassManager.
Definition: Pass.cpp:78
MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
Definition: Pass.cpp:28
MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
Definition: Pass.cpp:87
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure)
Enable IR printing.
Definition: Pass.cpp:47
#define DEFINE_C_API_PTR_METHODS(name, cpptype)
Definition: Wrap.h:25
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This pass class wraps external passes defined in other languages using the MLIR C-interface.
Definition: Pass.cpp:125
StringRef getArgument() const override
Return the command line argument used when registering this pass.
Definition: Pass.cpp:140
~ExternalPass() override
Definition: Pass.cpp:137
void signalPassFailure()
Definition: Pass.cpp:149
ExternalPass(TypeID passID, StringRef name, StringRef argument, StringRef description, std::optional< StringRef > opName, ArrayRef< MlirDialectHandle > dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Definition: Pass.cpp:127
StringRef getDescription() const override
Return the command line description used when registering this pass.
Definition: Pass.cpp:141
bool canScheduleOn(RegisteredOperationName opName) const override
Indicate if the current pass can be scheduled on the given operation type.
Definition: Pass.cpp:158
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
Definition: Pass.cpp:164
StringRef getName() const override
Returns the derived pass name.
Definition: Pass.cpp:139
LogicalResult initialize(MLIRContext *ctx) override
Initialize any complex state necessary for running this pass.
Definition: Pass.cpp:152
std::unique_ptr< Pass > clonePass() const override
Create a copy of this pass, ignoring statistics and options.
Definition: Pass.cpp:168
void getDependentDialects(DialectRegistry &registry) const override
Register dependent dialects for the current pass.
Definition: Pass.cpp:143
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:47
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
The main pass manager and pipeline builder.
Definition: PassManager.h:231
The abstract base pass class.
Definition: Pass.h:51
std::optional< StringRef > getOpName() const
Returns the name of the operation that this pass operates on, or std::nullopt if this is a generic Op...
Definition: Pass.h:83
Operation * getOperation()
Return the current operation being transformed.
Definition: Pass.h:211
void signalPassFailure()
Signal that some invariant was broken when running.
Definition: Pass.h:217
This is a "type erased" representation of a registered operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
A simple raw ostream subclass that forwards write_impl calls to the user-supplied callback together w...
Definition: Utils.h:30
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition: Diagnostics.h:24
MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, MlirDialectRegistry)
Inserts the dialect associated with the provided dialect handle into the provided dialect registry.
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:105
Include the generated interface declarations.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
Structure of external MlirPass callbacks.
Definition: Pass.h:145
void(* run)(MlirOperation op, MlirExternalPass pass, void *userData)
This callback is called when the pass is run.
Definition: Pass.h:166
void *(* clone)(void *userData)
This callback is called when the pass is cloned.
Definition: Pass.h:162
MlirLogicalResult(* initialize)(MlirContext ctx, void *userData)
This callback is optional.
Definition: Pass.h:158
void(* destruct)(void *userData)
This callback is called when the pass is destroyed. This is analogous to a C++ pass destructor.
Definition: Pass.h:152
void(* construct)(void *userData)
This callback is called when the pass is created.
Definition: Pass.h:148
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
size_t length
Length of the fragment.
Definition: Support.h:75