MLIR 23.0.0git
OptReductionPass.cpp
Go to the documentation of this file.
1//===- OptReductionPass.cpp - Optimization Reduction Pass Wrapper ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to
10// run any optimization pass within it and only replaces the output module with
11// the transformed version if it is smaller and interesting.
12//
13//===----------------------------------------------------------------------===//
14
17#include "mlir/Reducer/Passes.h"
18#include "mlir/Reducer/Tester.h"
19
20#include "llvm/Support/DebugLog.h"
21#include "llvm/Support/MemoryBuffer.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_OPTREDUCTIONPASS
25#include "mlir/Reducer/Passes.h.inc"
26} // namespace mlir
27
28#define DEBUG_TYPE "mlir-reduce"
29
30using namespace mlir;
31
32namespace {
33
34class OptReductionPass : public impl::OptReductionPassBase<OptReductionPass> {
35public:
36 using Base::Base;
37
38 /// Runs the pass instance in the pass pipeline.
39 void runOnOperation() override;
40};
41
42} // namespace
43
44/// Runs the pass instance in the pass pipeline.
45void OptReductionPass::runOnOperation() {
46 LDBG() << "\nOptimization Reduction pass: ";
47
48 Tester test(testerName, testerArgs);
49 Operation *topOp = this->getOperation();
50
51 std::string pipelineStr = optPass;
52 if (pipelineStr.empty()) {
53 if (!optPassFile.empty()) {
54 auto fileOrErr = llvm::MemoryBuffer::getFile(optPassFile);
55 if (std::error_code ec = fileOrErr.getError()) {
56 topOp->emitError() << "Could not open pass pipeline file: "
57 << optPassFile << " (" << ec.message() << ")";
58 return signalPassFailure();
59 }
60 pipelineStr = fileOrErr.get()->getBuffer().trim().str();
61 }
62 }
63
64 PassManager passManager(topOp->getName());
65 if (failed(parsePassPipeline(pipelineStr, passManager))) {
66 topOp->emitError() << "\nfailed to parse pass pipeline";
67 return signalPassFailure();
68 }
69
70 std::pair<Tester::Interestingness, int> original = test.isInteresting(topOp);
71 if (original.first != Tester::Interestingness::True) {
72 topOp->emitError() << "\nthe original input is not interested";
73 return signalPassFailure();
74 }
75 Operation *topOpVariant = topOp->clone();
76
77 LogicalResult pipelineResult = passManager.run(topOpVariant);
78 if (failed(pipelineResult)) {
79 topOp->emitError() << "\nfailed to run pass pipeline";
80 return signalPassFailure();
81 }
82
83 std::pair<Tester::Interestingness, int> reduced =
84 test.isInteresting(topOpVariant);
85
86 if (reduced.first == Tester::Interestingness::True &&
87 reduced.second < original.second) {
88 topOp->getRegion(0).getBlocks().clear();
89 topOp->getRegion(0).getBlocks().splice(
90 topOp->getRegion(0).getBlocks().begin(),
91 topOpVariant->getRegion(0).getBlocks());
92
93 LDBG() << "\nSuccessful Transformed version\n";
94 } else {
95 LDBG() << "\nUnsuccessful Transformed version\n";
96 }
97
98 topOpVariant->destroy();
99
100 LDBG() << "Pass Complete\n";
101}
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
void destroy()
Destroys this operation and its subclass data.
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:116
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
BlockListType & getBlocks()
Definition Region.h:45
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LogicalResult parsePassPipeline(StringRef pipeline, OpPassManager &pm, raw_ostream &errorStream=llvm::errs())
Parse the textual representation of a pass pipeline, adding the result to 'pm' on success.