MLIR  21.0.0git
CompositePass.cpp
Go to the documentation of this file.
1 //===- CompositePass.cpp - Composite pass code ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // CompositePass runs a set of passes repeatedly until a fixed point is reached.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Transforms/Passes.h"
14 
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Pass/PassManager.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_COMPOSITEFIXEDPOINTPASS
20 #include "mlir/Transforms/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 
25 namespace {
26 struct CompositeFixedPointPass final
27  : public impl::CompositeFixedPointPassBase<CompositeFixedPointPass> {
28  using CompositeFixedPointPassBase::CompositeFixedPointPassBase;
29 
30  CompositeFixedPointPass(
31  std::string name_, llvm::function_ref<void(OpPassManager &)> populateFunc,
32  int maxIterations) {
33  name = std::move(name_);
34  maxIter = maxIterations;
35  populateFunc(dynamicPM);
36 
37  llvm::raw_string_ostream os(pipelineStr);
38  llvm::interleave(
39  dynamicPM, [&](mlir::Pass &pass) { pass.printAsTextualPipeline(os); },
40  [&]() { os << ","; });
41  }
42 
43  LogicalResult initializeOptions(
44  StringRef options,
45  function_ref<LogicalResult(const Twine &)> errorHandler) override {
46  if (failed(CompositeFixedPointPassBase::initializeOptions(options,
47  errorHandler)))
48  return failure();
49 
50  if (failed(parsePassPipeline(pipelineStr, dynamicPM)))
51  return errorHandler("Failed to parse composite pass pipeline");
52 
53  return success();
54  }
55 
56  LogicalResult initialize(MLIRContext *context) override {
57  if (maxIter <= 0)
58  return emitError(UnknownLoc::get(context))
59  << "Invalid maxIterations value: " << maxIter << "\n";
60 
61  return success();
62  }
63 
64  void getDependentDialects(DialectRegistry &registry) const override {
65  dynamicPM.getDependentDialects(registry);
66  }
67 
68  void runOnOperation() override {
69  auto op = getOperation();
70  OperationFingerPrint fp(op);
71 
72  int currentIter = 0;
73  int maxIterVal = maxIter;
74  while (true) {
75  if (failed(runPipeline(dynamicPM, op)))
76  return signalPassFailure();
77 
78  if (currentIter++ >= maxIterVal) {
79  op->emitWarning("Composite pass \"" + llvm::Twine(name) +
80  "\"+ didn't converge in " + llvm::Twine(maxIterVal) +
81  " iterations");
82  break;
83  }
84 
85  OperationFingerPrint newFp(op);
86  if (newFp == fp)
87  break;
88 
89  fp = newFp;
90  }
91  }
92 
93 protected:
94  llvm::StringRef getName() const override { return name; }
95 
96 private:
97  OpPassManager dynamicPM;
98 };
99 } // namespace
100 
102  std::string name, llvm::function_ref<void(OpPassManager &)> populateFunc,
103  int maxIterations) {
104 
105  return std::make_unique<CompositeFixedPointPass>(std::move(name),
106  populateFunc, maxIterations);
107 }
static llvm::ManagedStatic< PassManagerOptions > options
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:46
A unique fingerprint for a specific operation, and all of its internal operations (if includeNested ...
The abstract base pass class.
Definition: Pass.h:51
void printAsTextualPipeline(raw_ostream &os)
Prints out the pass in the textual representation of pipelines.
Definition: Pass.cpp:84
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< Pass > createCompositeFixedPointPass(std::string name, llvm::function_ref< void(OpPassManager &)> populateFunc, int maxIterations=10)
Create a composite pass, which runs the provided set of passes until a fixed point or the maximum number of iterations is reached.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.