MLIR 23.0.0git
Canonicalizer.cpp
Go to the documentation of this file.
//===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass converts operations into their canonical forms by
// folding constants, applying operation identity transformations, and
// applying the canonicalization patterns provided by dialects and operations.
//
//===----------------------------------------------------------------------===//
13
15
17#include "mlir/Pass/Pass.h"
19#include "llvm/ADT/DenseSet.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_CANONICALIZERPASS
23#include "mlir/Transforms/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27
28namespace {
29/// Canonicalize operations in nested regions.
30struct Canonicalizer : public impl::CanonicalizerPassBase<Canonicalizer> {
31 using impl::CanonicalizerPassBase<Canonicalizer>::CanonicalizerPassBase;
32 Canonicalizer(const GreedyRewriteConfig &config,
33 ArrayRef<std::string> disabledPatterns,
34 ArrayRef<std::string> enabledPatterns)
35 : config(config) {
36 this->topDownProcessingEnabled = config.getUseTopDownTraversal();
37 this->regionSimplifyLevel = config.getRegionSimplificationLevel();
38 this->maxIterations = config.getMaxIterations();
39 this->maxNumRewrites = config.getMaxNumRewrites();
40 this->cseBetweenIterations = config.isCSEBetweenIterationsEnabled();
41 this->disabledPatterns = disabledPatterns;
42 this->enabledPatterns = enabledPatterns;
43 }
44
45 void getDependentDialects(DialectRegistry &registry) const override {
46 // Force-load any dialects named via the `filter-dialects` option. The
47 // allocator is resolved later from the MLIRContext's own registry.
48 for (const std::string &name : filterDialects)
49 registry.addDialectToPreload(StringRef(name));
50 }
51
52 /// Initialize the canonicalizer by building the set of patterns used during
53 /// execution.
54 LogicalResult initialize(MLIRContext *context) override {
55 // Set the config from possible pass options set in the meantime.
56 config.setUseTopDownTraversal(topDownProcessingEnabled);
57 config.setRegionSimplificationLevel(regionSimplifyLevel);
58 config.setMaxIterations(maxIterations);
59 config.setMaxNumRewrites(maxNumRewrites);
60 config.enableCSEBetweenIterations(cseBetweenIterations);
61
62 llvm::DenseSet<TypeID> allowedDialects;
63 for (const std::string &name : filterDialects) {
64 Dialect *dialect = context->getLoadedDialect(name);
65 assert(dialect && "filter-dialect should have been preloaded by the "
66 "PassManager via getDependentDialects");
67 allowedDialects.insert(dialect->getTypeID());
68 }
69 auto isAllowed = [&](Dialect *dialect) {
70 return allowedDialects.empty() ||
71 allowedDialects.contains(dialect->getTypeID());
72 };
73
74 RewritePatternSet owningPatterns(context);
75 for (auto *dialect : context->getLoadedDialects())
76 if (isAllowed(dialect))
77 dialect->getCanonicalizationPatterns(owningPatterns);
78 for (RegisteredOperationName op : context->getRegisteredOperations())
79 if (isAllowed(&op.getDialect()))
80 op.getCanonicalizationPatterns(owningPatterns, context);
81
82 patterns = std::make_shared<FrozenRewritePatternSet>(
83 std::move(owningPatterns), disabledPatterns, enabledPatterns);
84 return success();
85 }
86 void runOnOperation() override {
87 LogicalResult converged =
88 applyPatternsGreedily(getOperation(), *patterns, config);
89 // Canonicalization is best-effort. Non-convergence is not a pass failure.
90 if (testConvergence && failed(converged))
91 signalPassFailure();
92 }
93 GreedyRewriteConfig config;
94 std::shared_ptr<const FrozenRewritePatternSet> patterns;
95};
96} // namespace
97
98/// Creates an instance of the Canonicalizer pass with the specified config.
99std::unique_ptr<Pass>
101 ArrayRef<std::string> disabledPatterns,
102 ArrayRef<std::string> enabledPatterns) {
103 return std::make_unique<Canonicalizer>(config, disabledPatterns,
104 enabledPatterns);
105}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
void addDialectToPreload(StringRef name)
Request that the dialect with the given name be preloaded into the MLIRContext, without providing an ...
Definition Dialect.cpp:241
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Definition Dialect.h:57
This class allows control over how the GreedyPatternRewriteDriver works.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr<::mlir::Pass > createCanonicalizerPass()