MLIR  21.0.0git
OptReductionPass.cpp
Go to the documentation of this file.
1 //===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to
10 // run any optimization pass within it and only replaces the output module with
11 // the transformed version if it is smaller and interesting.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Pass/PassManager.h"
16 #include "mlir/Pass/PassRegistry.h"
17 #include "mlir/Reducer/Passes.h"
18 #include "mlir/Reducer/Tester.h"
19 
20 #include "llvm/Support/Debug.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_OPTREDUCTIONPASS
24 #include "mlir/Reducer/Passes.h.inc"
25 } // namespace mlir
26 
27 #define DEBUG_TYPE "mlir-reduce"
28 
29 using namespace mlir;
30 
31 namespace {
32 
33 class OptReductionPass : public impl::OptReductionPassBase<OptReductionPass> {
34 public:
35  using Base::Base;
36 
37  /// Runs the pass instance in the pass pipeline.
38  void runOnOperation() override;
39 };
40 
41 } // namespace
42 
43 /// Runs the pass instance in the pass pipeline.
44 void OptReductionPass::runOnOperation() {
45  LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");
46 
47  Tester test(testerName, testerArgs);
48 
49  ModuleOp module = this->getOperation();
50  ModuleOp moduleVariant = module.clone();
51 
52  OpPassManager passManager("builtin.module");
53  if (failed(parsePassPipeline(optPass, passManager))) {
54  module.emitError() << "\nfailed to parse pass pipeline";
55  return signalPassFailure();
56  }
57 
58  std::pair<Tester::Interestingness, int> original = test.isInteresting(module);
59  if (original.first != Tester::Interestingness::True) {
60  module.emitError() << "\nthe original input is not interested";
61  return signalPassFailure();
62  }
63 
64  // Temporarily push the variant under the main module and execute the pipeline
65  // on it.
66  module.getBody()->push_back(moduleVariant);
67  LogicalResult pipelineResult = runPipeline(passManager, moduleVariant);
68  moduleVariant->remove();
69 
70  if (failed(pipelineResult)) {
71  module.emitError() << "\nfailed to run pass pipeline";
72  return signalPassFailure();
73  }
74 
75  std::pair<Tester::Interestingness, int> reduced =
76  test.isInteresting(moduleVariant);
77 
78  if (reduced.first == Tester::Interestingness::True &&
79  reduced.second < original.second) {
80  module.getBody()->clear();
81  module.getBody()->getOperations().splice(
82  module.getBody()->begin(), moduleVariant.getBody()->getOperations());
83  LLVM_DEBUG(llvm::dbgs() << "\nSuccessful Transformed version\n\n");
84  } else {
85  LLVM_DEBUG(llvm::dbgs() << "\nUnsuccessful Transformed version\n\n");
86  }
87 
88  moduleVariant->destroy();
89 
90  LLVM_DEBUG(llvm::dbgs() << "Pass Complete\n\n");
91 }
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:46
This class is used to keep track of the testing environment of the tool.
Definition: Tester.h:31
Include the generated interface declarations.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.