MLIR  20.0.0git
Pass.cpp
Go to the documentation of this file.
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Pass.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // PassManager/OpPassManager APIs.
22 //===----------------------------------------------------------------------===//
23 
24 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
25  return wrap(new PassManager(unwrap(ctx)));
26 }
27 
28 MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
29  MlirStringRef anchorOp) {
30  return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp)));
31 }
32 
33 void mlirPassManagerDestroy(MlirPassManager passManager) {
34  delete unwrap(passManager);
35 }
36 
37 MlirOpPassManager
38 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
39  return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
40 }
41 
42 MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
43  MlirOperation op) {
44  return wrap(unwrap(passManager)->run(unwrap(op)));
45 }
46 
47 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
48  bool printBeforeAll, bool printAfterAll,
49  bool printModuleScope,
50  bool printAfterOnlyOnChange,
51  bool printAfterOnlyOnFailure,
52  MlirOpPrintingFlags flags,
53  MlirStringRef treePrintingPath) {
54  auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
55  return printBeforeAll;
56  };
57  auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
58  return printAfterAll;
59  };
60  if (unwrap(treePrintingPath).empty())
61  return unwrap(passManager)
62  ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
63  printModuleScope, printAfterOnlyOnChange,
64  printAfterOnlyOnFailure, /*out=*/llvm::errs(),
65  *unwrap(flags));
66 
67  unwrap(passManager)
68  ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
69  printModuleScope, printAfterOnlyOnChange,
70  printAfterOnlyOnFailure,
71  unwrap(treePrintingPath), *unwrap(flags));
72 }
73 
74 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
75  unwrap(passManager)->enableVerifier(enable);
76 }
77 
78 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
79  MlirStringRef operationName) {
80  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
81 }
82 
83 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
84  MlirStringRef operationName) {
85  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
86 }
87 
88 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
89  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
90 }
91 
92 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
93  MlirPass pass) {
94  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
95 }
96 
97 MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
98  MlirStringRef pipelineElements,
99  MlirStringCallback callback,
100  void *userData) {
101  detail::CallbackOstream stream(callback, userData);
102  return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
103  stream));
104 }
105 
106 void mlirPrintPassPipeline(MlirOpPassManager passManager,
107  MlirStringCallback callback, void *userData) {
108  detail::CallbackOstream stream(callback, userData);
109  unwrap(passManager)->printAsTextualPipeline(stream);
110 }
111 
112 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
113  MlirStringRef pipeline,
114  MlirStringCallback callback,
115  void *userData) {
116  detail::CallbackOstream stream(callback, userData);
117  FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
118  if (succeeded(pm))
119  *unwrap(passManager) = std::move(*pm);
120  return wrap(pm);
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // External Pass API.
125 //===----------------------------------------------------------------------===//
126 
127 namespace mlir {
128 class ExternalPass;
129 } // namespace mlir
131 
132 namespace mlir {
133 /// This pass class wraps external passes defined in other languages using the
134 /// MLIR C-interface
135 class ExternalPass : public Pass {
136 public:
137  ExternalPass(TypeID passID, StringRef name, StringRef argument,
138  StringRef description, std::optional<StringRef> opName,
139  ArrayRef<MlirDialectHandle> dependentDialects,
140  MlirExternalPassCallbacks callbacks, void *userData)
141  : Pass(passID, opName), id(passID), name(name), argument(argument),
142  description(description), dependentDialects(dependentDialects),
143  callbacks(callbacks), userData(userData) {
144  callbacks.construct(userData);
145  }
146 
147  ~ExternalPass() override { callbacks.destruct(userData); }
148 
149  StringRef getName() const override { return name; }
150  StringRef getArgument() const override { return argument; }
151  StringRef getDescription() const override { return description; }
152 
153  void getDependentDialects(DialectRegistry &registry) const override {
154  MlirDialectRegistry cRegistry = wrap(&registry);
155  for (MlirDialectHandle dialect : dependentDialects)
156  mlirDialectHandleInsertDialect(dialect, cRegistry);
157  }
158 
160 
161 protected:
162  LogicalResult initialize(MLIRContext *ctx) override {
163  if (callbacks.initialize)
164  return unwrap(callbacks.initialize(wrap(ctx), userData));
165  return success();
166  }
167 
168  bool canScheduleOn(RegisteredOperationName opName) const override {
169  if (std::optional<StringRef> specifiedOpName = getOpName())
170  return opName.getStringRef() == specifiedOpName;
171  return true;
172  }
173 
174  void runOnOperation() override {
175  callbacks.run(wrap(getOperation()), wrap(this), userData);
176  }
177 
178  std::unique_ptr<Pass> clonePass() const override {
179  void *clonedUserData = callbacks.clone(userData);
180  return std::make_unique<ExternalPass>(id, name, argument, description,
181  getOpName(), dependentDialects,
182  callbacks, clonedUserData);
183  }
184 
185 private:
186  TypeID id;
187  std::string name;
188  std::string argument;
189  std::string description;
190  std::vector<MlirDialectHandle> dependentDialects;
191  MlirExternalPassCallbacks callbacks;
192  void *userData;
193 };
194 } // namespace mlir
195 
196 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
197  MlirStringRef argument,
198  MlirStringRef description, MlirStringRef opName,
199  intptr_t nDependentDialects,
200  MlirDialectHandle *dependentDialects,
201  MlirExternalPassCallbacks callbacks,
202  void *userData) {
203  return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
204  unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
205  opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
206  : std::nullopt,
207  {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
208  userData)));
209 }
210 
211 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
212  unwrap(pass)->signalPassFailure();
213 }
MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the provided OpPassManager; the nested pass manager will only run on operations matching the provided name.
Definition: Pass.cpp:83
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, MlirStringRef argument, MlirStringRef description, MlirStringRef opName, intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Creates an external MlirPass that calls the supplied callbacks using the supplied userData.
Definition: Pass.cpp:196
MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
Definition: Pass.cpp:24
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition: Pass.cpp:74
MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the top-level PassManager; the nested pass manager will only run on operations matching the provided name.
Definition: Pass.cpp:78
void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition: Pass.cpp:33
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager.
Definition: Pass.cpp:112
MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition: Pass.cpp:38
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
Definition: Pass.cpp:42
void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided mlirOpPassManager.
Definition: Pass.cpp:92
void mlirExternalPassSignalFailure(MlirExternalPass pass)
This signals that the pass has failed.
Definition: Pass.cpp:211
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding userData to callback.
Definition: Pass.cpp:106
void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided top-level mlirPassManager.
Definition: Pass.cpp:88
MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
Definition: Pass.cpp:28
MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
Definition: Pass.cpp:97
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, MlirStringRef treePrintingPath)
Enable IR printing.
Definition: Pass.cpp:47
#define DEFINE_C_API_PTR_METHODS(name, cpptype)
Definition: Wrap.h:25
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This pass class wraps external passes defined in other languages using the MLIR C-interface.
Definition: Pass.cpp:135
StringRef getArgument() const override
Return the command line argument used when registering this pass.
Definition: Pass.cpp:150
~ExternalPass() override
Definition: Pass.cpp:147
void signalPassFailure()
Definition: Pass.cpp:159
ExternalPass(TypeID passID, StringRef name, StringRef argument, StringRef description, std::optional< StringRef > opName, ArrayRef< MlirDialectHandle > dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Definition: Pass.cpp:137
StringRef getDescription() const override
Return the command line description used when registering this pass.
Definition: Pass.cpp:151
bool canScheduleOn(RegisteredOperationName opName) const override
Indicate if the current pass can be scheduled on the given operation type.
Definition: Pass.cpp:168
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
Definition: Pass.cpp:174
StringRef getName() const override
Returns the derived pass name.
Definition: Pass.cpp:149
LogicalResult initialize(MLIRContext *ctx) override
Initialize any complex state necessary for running this pass.
Definition: Pass.cpp:162
std::unique_ptr< Pass > clonePass() const override
Create a copy of this pass, ignoring statistics and options.
Definition: Pass.cpp:178
void getDependentDialects(DialectRegistry &registry) const override
Register dependent dialects for the current pass.
Definition: Pass.cpp:153
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
Definition: PassManager.h:47
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
The main pass manager and pipeline builder.
Definition: PassManager.h:231
The abstract base pass class.
Definition: Pass.h:51
std::optional< StringRef > getOpName() const
Returns the name of the operation that this pass operates on, or std::nullopt if this is a generic OperationPass.
Definition: Pass.h:83
Operation * getOperation()
Return the current operation being transformed.
Definition: Pass.h:211
void signalPassFailure()
Signal that some invariant was broken when running.
Definition: Pass.h:217
This is a "type erased" representation of a registered operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:104
A simple raw ostream subclass that forwards write_impl calls to the user-supplied callback together with opaque user-supplied data.
Definition: Utils.h:30
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition: Diagnostics.h:24
MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, MlirDialectRegistry)
Inserts the dialect associated with the provided dialect handle into the provided dialect registry.
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:105
Include the generated interface declarations.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
Structure of external MlirPass callbacks.
Definition: Pass.h:149
void(* run)(MlirOperation op, MlirExternalPass pass, void *userData)
This callback is called when the pass is run.
Definition: Pass.h:170
void *(* clone)(void *userData)
This callback is called when the pass is cloned.
Definition: Pass.h:166
MlirLogicalResult(* initialize)(MlirContext ctx, void *userData)
This callback is optional.
Definition: Pass.h:162
void(* destruct)(void *userData)
This callback is called when the pass is destroyed. This is analogous to a C++ pass destructor.
Definition: Pass.h:156
void(* construct)(void *userData)
This callback is called when the pass is created.
Definition: Pass.h:152
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
size_t length
Length of the fragment.
Definition: Support.h:75