MLIR  22.0.0git
Pass.cpp
Go to the documentation of this file.
1 //===- Pass.cpp - C Interface for General Pass Management APIs ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Pass.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Pass.h"
13 #include "mlir/CAPI/Support.h"
14 #include "mlir/CAPI/Utils.h"
15 #include "mlir/Pass/PassManager.h"
16 #include "llvm/Support/ErrorHandling.h"
17 #include <optional>
18 
19 using namespace mlir;
20 
21 //===----------------------------------------------------------------------===//
22 // PassManager/OpPassManager APIs.
23 //===----------------------------------------------------------------------===//
24 
25 MlirPassManager mlirPassManagerCreate(MlirContext ctx) {
26  return wrap(new PassManager(unwrap(ctx)));
27 }
28 
29 MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx,
30  MlirStringRef anchorOp) {
31  return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp)));
32 }
33 
34 void mlirPassManagerDestroy(MlirPassManager passManager) {
35  delete unwrap(passManager);
36 }
37 
38 MlirOpPassManager
39 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) {
40  return wrap(static_cast<OpPassManager *>(unwrap(passManager)));
41 }
42 
43 MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager,
44  MlirOperation op) {
45  return wrap(unwrap(passManager)->run(unwrap(op)));
46 }
47 
48 void mlirPassManagerEnableIRPrinting(MlirPassManager passManager,
49  bool printBeforeAll, bool printAfterAll,
50  bool printModuleScope,
51  bool printAfterOnlyOnChange,
52  bool printAfterOnlyOnFailure,
53  MlirOpPrintingFlags flags,
54  MlirStringRef treePrintingPath) {
55  auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) {
56  return printBeforeAll;
57  };
58  auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) {
59  return printAfterAll;
60  };
61  if (unwrap(treePrintingPath).empty())
62  return unwrap(passManager)
63  ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass,
64  printModuleScope, printAfterOnlyOnChange,
65  printAfterOnlyOnFailure, /*out=*/llvm::errs(),
66  *unwrap(flags));
67 
68  unwrap(passManager)
69  ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass,
70  printModuleScope, printAfterOnlyOnChange,
71  printAfterOnlyOnFailure,
72  unwrap(treePrintingPath), *unwrap(flags));
73 }
74 
75 void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) {
76  unwrap(passManager)->enableVerifier(enable);
77 }
78 
79 void mlirPassManagerEnableTiming(MlirPassManager passManager) {
80  unwrap(passManager)->enableTiming();
81 }
82 
83 void mlirPassManagerEnableStatistics(MlirPassManager passManager,
84  MlirPassDisplayMode displayMode) {
85  PassDisplayMode mode;
86  switch (displayMode) {
88  mode = PassDisplayMode::List;
89  break;
92  break;
93  }
94  unwrap(passManager)->enableStatistics(mode);
95 }
96 
97 MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager,
98  MlirStringRef operationName) {
99  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
100 }
101 
102 MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager,
103  MlirStringRef operationName) {
104  return wrap(&unwrap(passManager)->nest(unwrap(operationName)));
105 }
106 
107 void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) {
108  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
109 }
110 
111 void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
112  MlirPass pass) {
113  unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
114 }
115 
116 MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
117  MlirStringRef pipelineElements,
118  MlirStringCallback callback,
119  void *userData) {
120  detail::CallbackOstream stream(callback, userData);
121  return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
122  stream));
123 }
124 
125 void mlirPrintPassPipeline(MlirOpPassManager passManager,
126  MlirStringCallback callback, void *userData) {
127  detail::CallbackOstream stream(callback, userData);
128  unwrap(passManager)->printAsTextualPipeline(stream);
129 }
130 
131 MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager,
132  MlirStringRef pipeline,
133  MlirStringCallback callback,
134  void *userData) {
135  detail::CallbackOstream stream(callback, userData);
136  FailureOr<OpPassManager> pm = parsePassPipeline(unwrap(pipeline), stream);
137  if (succeeded(pm))
138  *unwrap(passManager) = std::move(*pm);
139  return wrap(pm);
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // External Pass API.
144 //===----------------------------------------------------------------------===//
145 
146 namespace mlir {
147 class ExternalPass;
148 } // namespace mlir
150 
151 namespace mlir {
152 /// This pass class wraps external passes defined in other languages using the
153 /// MLIR C-interface
154 class ExternalPass : public Pass {
155 public:
156  ExternalPass(TypeID passID, StringRef name, StringRef argument,
157  StringRef description, std::optional<StringRef> opName,
158  ArrayRef<MlirDialectHandle> dependentDialects,
159  MlirExternalPassCallbacks callbacks, void *userData)
160  : Pass(passID, opName), id(passID), name(name), argument(argument),
161  description(description), dependentDialects(dependentDialects),
162  callbacks(callbacks), userData(userData) {
163  if (callbacks.construct)
164  callbacks.construct(userData);
165  }
166 
167  ~ExternalPass() override {
168  if (callbacks.destruct)
169  callbacks.destruct(userData);
170  }
171 
172  StringRef getName() const override { return name; }
173  StringRef getArgument() const override { return argument; }
174  StringRef getDescription() const override { return description; }
175 
176  void getDependentDialects(DialectRegistry &registry) const override {
177  MlirDialectRegistry cRegistry = wrap(&registry);
178  for (MlirDialectHandle dialect : dependentDialects)
179  mlirDialectHandleInsertDialect(dialect, cRegistry);
180  }
181 
183 
184 protected:
185  LogicalResult initialize(MLIRContext *ctx) override {
186  if (callbacks.initialize)
187  return unwrap(callbacks.initialize(wrap(ctx), userData));
188  return success();
189  }
190 
191  bool canScheduleOn(RegisteredOperationName opName) const override {
192  if (std::optional<StringRef> specifiedOpName = getOpName())
193  return opName.getStringRef() == specifiedOpName;
194  return true;
195  }
196 
197  void runOnOperation() override {
198  callbacks.run(wrap(getOperation()), wrap(this), userData);
199  }
200 
201  std::unique_ptr<Pass> clonePass() const override {
202  void *clonedUserData = callbacks.clone(userData);
203  return std::make_unique<ExternalPass>(id, name, argument, description,
204  getOpName(), dependentDialects,
205  callbacks, clonedUserData);
206  }
207 
208 private:
209  TypeID id;
210  std::string name;
211  std::string argument;
212  std::string description;
213  std::vector<MlirDialectHandle> dependentDialects;
214  MlirExternalPassCallbacks callbacks;
215  void *userData;
216 };
217 } // namespace mlir
218 
219 MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
220  MlirStringRef argument,
221  MlirStringRef description, MlirStringRef opName,
222  intptr_t nDependentDialects,
223  MlirDialectHandle *dependentDialects,
224  MlirExternalPassCallbacks callbacks,
225  void *userData) {
226  return wrap(static_cast<mlir::Pass *>(new mlir::ExternalPass(
227  unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
228  opName.length > 0 ? std::optional<StringRef>(unwrap(opName))
229  : std::nullopt,
230  {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
231  userData)));
232 }
233 
234 void mlirExternalPassSignalFailure(MlirExternalPass pass) {
235  unwrap(pass)->signalPassFailure();
236 }
MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the provided OpPassManager, the nested passmanager will only run on opera...
Definition: Pass.cpp:102
MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, MlirStringRef argument, MlirStringRef description, MlirStringRef opName, intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Creates an external MlirPass that calls the supplied callbacks using the supplied userData.
Definition: Pass.cpp:219
MlirPassManager mlirPassManagerCreate(MlirContext ctx)
Create a new top-level PassManager with the default anchor.
Definition: Pass.cpp:25
void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable)
Enable / disable verify-each.
Definition: Pass.cpp:75
void mlirPassManagerEnableStatistics(MlirPassManager passManager, MlirPassDisplayMode displayMode)
Enable pass statistics.
Definition: Pass.cpp:83
MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName)
Nest an OpPassManager under the top-level PassManager, the nested passmanager will only run on operat...
Definition: Pass.cpp:97
void mlirPassManagerEnableTiming(MlirPassManager passManager)
Enable pass timing.
Definition: Pass.cpp:79
void mlirPassManagerDestroy(MlirPassManager passManager)
Destroy the provided PassManager.
Definition: Pass.cpp:34
MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, MlirStringCallback callback, void *userData)
Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager.
Definition: Pass.cpp:131
MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager)
Cast a top-level PassManager to a generic OpPassManager.
Definition: Pass.cpp:39
MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op)
Run the provided passManager on the given op.
Definition: Pass.cpp:43
void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided mlirOpPassManager.
Definition: Pass.cpp:111
void mlirExternalPassSignalFailure(MlirExternalPass pass)
This signals that the pass has failed.
Definition: Pass.cpp:234
void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData)
Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding user...
Definition: Pass.cpp:125
void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass)
Add a pass and transfer ownership to the provided top-level mlirPassManager.
Definition: Pass.cpp:107
MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp)
Create a new top-level PassManager anchored on anchorOp.
Definition: Pass.cpp:29
MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, MlirStringRef pipelineElements, MlirStringCallback callback, void *userData)
Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager.
Definition: Pass.cpp:116
void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, MlirStringRef treePrintingPath)
Enable IR printing.
Definition: Pass.cpp:48
#define DEFINE_C_API_PTR_METHODS(name, cpptype)
Definition: Wrap.h:25
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This pass class wraps external passes defined in other languages using the MLIR C-interface.
Definition: Pass.cpp:154
StringRef getArgument() const override
Return the command line argument used when registering this pass.
Definition: Pass.cpp:173
~ExternalPass() override
Definition: Pass.cpp:167
void signalPassFailure()
Definition: Pass.cpp:182
ExternalPass(TypeID passID, StringRef name, StringRef argument, StringRef description, std::optional< StringRef > opName, ArrayRef< MlirDialectHandle > dependentDialects, MlirExternalPassCallbacks callbacks, void *userData)
Definition: Pass.cpp:156
StringRef getDescription() const override
Return the command line description used when registering this pass.
Definition: Pass.cpp:174
bool canScheduleOn(RegisteredOperationName opName) const override
Indicate if the current pass can be scheduled on the given operation type.
Definition: Pass.cpp:191
void runOnOperation() override
The polymorphic API that runs the pass over the currently held operation.
Definition: Pass.cpp:197
StringRef getName() const override
Returns the derived pass name.
Definition: Pass.cpp:172
LogicalResult initialize(MLIRContext *ctx) override
Initialize any complex state necessary for running this pass.
Definition: Pass.cpp:185
std::unique_ptr< Pass > clonePass() const override
Create a copy of this pass, ignoring statistics and options.
Definition: Pass.cpp:201
void getDependentDialects(DialectRegistry &registry) const override
Register dependent dialects for the current pass.
Definition: Pass.cpp:176
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:46
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
The main pass manager and pipeline builder.
Definition: PassManager.h:232
The abstract base pass class.
Definition: Pass.h:51
std::optional< StringRef > getOpName() const
Returns the name of the operation that this pass operates on, or std::nullopt if this is a generic Op...
Definition: Pass.h:83
Operation * getOperation()
Return the current operation being transformed.
Definition: Pass.h:212
void signalPassFailure()
Signal that some invariant was broken when running.
Definition: Pass.h:218
This is a "type erased" representation of a registered operation.
This class provides an efficient unique identifier for a specific C++ type.
Definition: TypeID.h:107
A simple raw ostream subclass that forwards write_impl calls to the user-supplied callback together w...
Definition: Utils.h:30
MlirPassDisplayMode
Enumerated type of pass display modes.
Definition: Pass.h:97
@ MLIR_PASS_DISPLAY_MODE_LIST
Definition: Pass.h:98
@ MLIR_PASS_DISPLAY_MODE_PIPELINE
Definition: Pass.h:99
mlir::Diagnostic & unwrap(MlirDiagnostic diagnostic)
Definition: Diagnostics.h:19
MlirDiagnostic wrap(mlir::Diagnostic &diagnostic)
Definition: Diagnostics.h:24
MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, MlirDialectRegistry)
Inserts the dialect associated with the provided dialect handle into the provided dialect registry.
void(* MlirStringCallback)(MlirStringRef, void *)
A callback for returning string references.
Definition: Support.h:105
Include the generated interface declarations.
PassDisplayMode
An enum describing the different display modes for the information within the pass manager.
Definition: PassManager.h:199
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.
Structure of external MlirPass callbacks.
Definition: Pass.h:165
void(* run)(MlirOperation op, MlirExternalPass pass, void *userData)
This callback is called when the pass is run.
Definition: Pass.h:186
void *(* clone)(void *userData)
This callback is called when the pass is cloned.
Definition: Pass.h:182
MlirLogicalResult(* initialize)(MlirContext ctx, void *userData)
This callback is optional.
Definition: Pass.h:178
void(* destruct)(void *userData)
This callback is called when the pass is destroyed. This is analogous to a C++ pass destructor.
Definition: Pass.h:172
void(* construct)(void *userData)
This callback is called when the pass is created.
Definition: Pass.h:168
A logical result value, essentially a boolean with named states.
Definition: Support.h:116
A pointer to a sized fragment of a string, not necessarily null-terminated.
Definition: Support.h:73
size_t length
Length of the fragment.
Definition: Support.h:75