MLIR  22.0.0git
Pass.cpp
Go to the documentation of this file.
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Pass.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 //===----------------------------------------------------------------------===//
21 // PassManager/OpPassManager APIs.
22 //===----------------------------------------------------------------------===//
23 
24 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
25  return wrap(new PassManager(unwrap(ctx)));
26 }
27 
28 MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
29  MlirStringRef anchorOp) {
30  return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp)));
31 }
32 
33 void mlirPassManagerDestroy(MlirPassManager passManager) {
34  delete unwrap(passManager);
35 }
36 
37 MlirOpPassManager
38 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
39  return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
40 }
41 
42 MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
43  MlirOperation op) {
44  return wrap(unwrap(passManager)->run(unwrap(op)));
45 }
46 
47 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
48  bool printBeforeAll, bool printAfterAll,
49  bool printModuleScope,
50  bool printAfterOnlyOnChange,
51  bool printAfterOnlyOnFailure,
52  MlirOpPrintingFlags flags,
53  MlirStringRef treePrintingPath) {
54  auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
55  return printBeforeAll;
56  };
57  auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
58  return printAfterAll;
59  };
60  if (unwrap(treePrintingPath).empty())
61  return unwrap(passManager)
62  ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
63  printModuleScope, printAfterOnlyOnChange,
64  printAfterOnlyOnFailure, /*out=*/llvm::errs(),
65  *unwrap(flags));
66 
67  unwrap(passManager)
68  ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
69  printModuleScope, printAfterOnlyOnChange,
70  printAfterOnlyOnFailure,
71  unwrap(treePrintingPath), *unwrap(flags));
72 }
73 
74 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
75  unwrap(passManager)->enableVerifier(enable);
76 }
77 
78 void mlirPassManagerEnableTiming(MlirPassManager passManager) {
79  unwrap(passManager)->enableTiming();
80 }
81 
82 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
83  MlirStringRef operationName) {
84  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
85 }
86 
87 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
88  MlirStringRef operationName) {
89  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
90 }
91 
92 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
93  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
94 }
95 
96 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
97  MlirPass pass) {
98  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
99 }
100 
101 MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
102  MlirStringRef pipelineElements,
103  MlirStringCallback callback,
104  void *userData) {
105  detail::CallbackOstream stream(callback, userData);
106  return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
107  stream));
108 }
109 
110 void mlirPrintPassPipeline(MlirOpPassManager passManager,
111  MlirStringCallback callback, void *userData) {
112  detail::CallbackOstream stream(callback, userData);
113  unwrap(passManager)->printAsTextualPipeline(stream);
114 }
115 
116 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
117  MlirStringRef pipeline,
118  MlirStringCallback callback,
119  void *userData) {
120  detail::CallbackOstream stream(callback, userData);
121  FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
122  if (succeeded(pm))
123  *unwrap(passManager) = std::move(*pm);
124  return wrap(pm);
125 }
126 
127 //===----------------------------------------------------------------------===//
128 // External Pass API.
129 //===----------------------------------------------------------------------===//
130 
131 namespace mlir {
132 class ExternalPass;
133 } // namespace mlir
135 
136 namespace mlir {
137 /// This pass class wraps external passes defined in other languages using the
138 /// MLIR C-interface
139 class ExternalPass : public Pass {
140 public:
141  ExternalPass(TypeID passID, StringRef name, StringRef argument,
142  StringRef description, std::optional<StringRef> opName,
143  ArrayRef<MlirDialectHandle> dependentDialects,
144  MlirExternalPassCallbacks callbacks, void *userData)
145  : Pass(passID, opName), id(passID), name(name), argument(argument),
146  description(description), dependentDialects(dependentDialects),
147  callbacks(callbacks), userData(userData) {
148  callbacks.construct(userData);
149  }
150 
151  ~ExternalPass() override { callbacks.destruct(userData); }
152 
153  StringRef getName() const override { return name; }
154  StringRef getArgument() const override { return argument; }
155  StringRef getDescription() const override { return description; }
156 
157  void getDependentDialects(DialectRegistry &registry) const override {
158  MlirDialectRegistry cRegistry = wrap(&registry);
159  for (MlirDialectHandle dialect : dependentDialects)
160  mlirDialectHandleInsertDialect(dialect, cRegistry);
161  }
162 
164 
165 protected:
166  LogicalResult initialize(MLIRContext *ctx) override {
167  if (callbacks.initialize)
168  return unwrap(callbacks.initialize(wrap(ctx), userData));
169  return success();
170  }
171 
172  bool canScheduleOn(RegisteredOperationName opName) const override {
173  if (std::optional<StringRef> specifiedOpName = getOpName())
174  return opName.getStringRef() == specifiedOpName;
175  return true;
176  }
177 
178  void runOnOperation() override {
179  callbacks.run(wrap(getOperation()), wrap(this), userData);
180  }
181 
182  std::unique_ptr<Pass> clonePass() const override {
183  void *clonedUserData = callbacks.clone(userData);
184  return std::make_unique<ExternalPass>(id, name, argument, description,
185  getOpName(), dependentDialects,
186  callbacks, clonedUserData);
187  }
188 
189 private:
190  TypeID id;
191  std::string name;
192  std::string argument;
193  std::string description;
194  std::vector<MlirDialectHandle> dependentDialects;
195  MlirExternalPassCallbacks callbacks;
196  void *userData;
197 };
198 } // namespace mlir
199 
200 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
201  MlirStringRef argument,
202  MlirStringRef description, MlirStringRef opName,
203  intptr_t nDependentDialects,
204  MlirDialectHandle *dependentDialects,
205  MlirExternalPassCallbacks callbacks,
206  void *userData) {
207  return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
208  unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
209  opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
210  : std::nullopt,
211  {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
212  userData)));
213 }
214 
215 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
216  unwrap(pass)->signalPassFailure();
217 }
MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the provided OpPassManager, the nested passmanager will only run on operations matching the provided name.
Definition: Pass.cpp:87
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, MlirStringRef argument, MlirStringRef description, MlirStringRef opName, intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Creates an external MlirPass that calls the supplied callbacks using the supplied userData.
Definition: Pass.cpp:200
MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
Definition: Pass.cpp:24
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition: Pass.cpp:74
MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the top-level PassManager, the nested passmanager will only run on operations matching the provided name.
Definition: Pass.cpp:82
void mlirPassManagerEnableTiming(MlirPassManager passManager)
Enable pass timing.
Definition: Pass.cpp:78
void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition: Pass.cpp:33
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager.
Definition: Pass.cpp:116
MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition: Pass.cpp:38
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
Definition: Pass.cpp:42
void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided mlirOpPassManager.
Definition: Pass.cpp:96
void mlirExternalPassSignalFailure(MlirExternalPass pass)
This signals that the pass has failed.
Definition: Pass.cpp:215
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding userData to callback.
Definition: Pass.cpp:110
void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided top-level mlirPassManager.
Definition: Pass.cpp:92
MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
Definition: Pass.cpp:28
MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
Definition: Pass.cpp:101
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, MlirStringRef treePrintingPath)
Enable IR printing.
Definition: Pass.cpp:47
#define DEFINE_C_API_PTR_METHODS(name, cpptype)
Definition: Wrap.h:25
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This pass class wraps external passes defined in other languages using the MLIR C-interface.
Definition: Pass.cpp:139
StringRef getArgument() const override
Return the command line argument used when registering this pass.
Definition: Pass.cpp:154
~ExternalPass() override
Definition: Pass.cpp:151
void signalPassFailure()
Definition: Pass.cpp:163
ExternalPass(TypeID passID, StringRef name, StringRef argument, StringRef description, std::optional< StringRef > opName, ArrayRef< MlirDialectHandle > dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Definition: Pass.cpp:141
StringRef getDescription() const override
Return the command line description used when registering this pass.
Definition: Pass.cpp:155
bool canScheduleOn(RegisteredOperationName opName) const override
Indicate if the current pass can be scheduled on the given operation type.
Definition: Pass.cpp:172
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
Definition: Pass.cpp:178
StringRef getName() const override
Returns the derived pass name.
Definition: Pass.cpp:153
LogicalResult initialize(MLIRContext *ctx) override
Initialize any complex state necessary for running this pass.
Definition: Pass.cpp:166
std::unique_ptr< Pass > clonePass() const override
Create a copy of this pass, ignoring statistics and options.
Definition: Pass.cpp:182
void getDependentDialects(DialectRegistry &registry) const override
Register dependent dialects for the current pass.
Definition: Pass.cpp:157
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents a pass manager that runs passes on either a specific operation type, or any isolated operation.
Definition: PassManager.h:46
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
The main pass manager and pipeline builder.
Definition: PassManager.h:232
The abstract base pass class.
Definition: Pass.h:51
std::optional< StringRef > getOpName() const
Returns the name of the operation that this pass operates on, or std::nullopt if this is a generic OperationPass.
Definition: Pass.h:83
Operation * getOperation()
Return the current operation being transformed.
Definition: Pass.h:212
void signalPassFailure()
Signal that some invariant was broken when running.
Definition: Pass.h:218
This is a "type erased" representation of a registered operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
A simple raw ostream subclass that forwards write_impl calls to the user-supplied callback together with opaque user-supplied data.
Definition: Utils.h:30
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition: Diagnostics.h:24
MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, MlirDialectRegistry)
Inserts the dialect associated with the provided dialect handle into the provided dialect registry.
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:105
Include the generated interface declarations.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
Structure of external MlirPass callbacks.
Definition: Pass.h:153
void(* run)(MlirOperation op, MlirExternalPass pass, void *userData)
This callback is called when the pass is run.
Definition: Pass.h:174
void *(* clone)(void *userData)
This callback is called when the pass is cloned.
Definition: Pass.h:170
MlirLogicalResult(* initialize)(MlirContext ctx, void *userData)
This callback is optional.
Definition: Pass.h:166
void(* destruct)(void *userData)
This callback is called when the pass is destroyed. This is analogous to a C++ pass destructor.
Definition: Pass.h:160
void(* construct)(void *userData)
This callback is called when the pass is created.
Definition: Pass.h:156
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
size_t length
Length of the fragment.
Definition: Support.h:75