MLIR 23.0.0git
Canonicalizer.cpp
Go to the documentation of this file.
//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass converts operations into their canonical forms by
// folding constants, applying operation identity transformations, etc.
//
//===----------------------------------------------------------------------===//
13
15
16#include "mlir/Pass/Pass.h"
18
19namespace mlir {
20#define GEN_PASS_DEF_CANONICALIZERPASS
21#include "mlir/Transforms/Passes.h.inc"
22} // namespace mlir
23
24using namespace mlir;
25
26namespace {
27/// Canonicalize operations in nested regions.
28struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
29 using impl::CanonicalizerPassBase<Canonicalizer>::CanonicalizerPassBase;
30 Canonicalizer(const GreedyRewriteConfig &config,
31 ArrayRef<std::string> disabledPatterns,
32 ArrayRef<std::string> enabledPatterns)
33 : config(config) {
34 this->topDownProcessingEnabled = config.getUseTopDownTraversal();
35 this->regionSimplifyLevel = config.getRegionSimplificationLevel();
36 this->maxIterations = config.getMaxIterations();
37 this->maxNumRewrites = config.getMaxNumRewrites();
38 this->cseBetweenIterations = config.isCSEBetweenIterationsEnabled();
39 this->disabledPatterns = disabledPatterns;
40 this->enabledPatterns = enabledPatterns;
41 }
42
43 /// Initialize the canonicalizer by building the set of patterns used during
44 /// execution.
45 LogicalResult initialize(MLIRContext *context) override {
46 // Set the config from possible pass options set in the meantime.
47 config.setUseTopDownTraversal(topDownProcessingEnabled);
48 config.setRegionSimplificationLevel(regionSimplifyLevel);
49 config.setMaxIterations(maxIterations);
50 config.setMaxNumRewrites(maxNumRewrites);
51 config.enableCSEBetweenIterations(cseBetweenIterations);
52
53 RewritePatternSet owningPatterns(context);
54 for (auto *dialect : context->getLoadedDialects())
55 dialect->getCanonicalizationPatterns(owningPatterns);
56 for (RegisteredOperationName op : context->getRegisteredOperations())
57 op.getCanonicalizationPatterns(owningPatterns, context);
58
59 patterns = std::make_shared<FrozenRewritePatternSet>(
60 std::move(owningPatterns), disabledPatterns, enabledPatterns);
61 return success();
62 }
63 void runOnOperation() override {
64 LogicalResult converged =
65 applyPatternsGreedily(getOperation(), *patterns, config);
66 // Canonicalization is best-effort. Non-convergence is not a pass failure.
67 if (testConvergence && failed(converged))
68 signalPassFailure();
69 }
70 GreedyRewriteConfig config;
71 std::shared_ptr<const FrozenRewritePatternSet> patterns;
72};
73} // namespace
74
75/// Creates an instance of the Canonicalizer pass with the specified config.
76std::unique_ptr<Pass>
78 ArrayRef<std::string> disabledPatterns,
79 ArrayRef<std::string> enabledPatterns) {
80 return std::make_unique<Canonicalizer>(config, disabledPatterns,
81 enabledPatterns);
82}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
This class allows control over how the GreedyPatternRewriteDriver works.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr<::mlir::Pass > createCanonicalizerPass()