MLIR  21.0.0git
Canonicalizer.cpp
Go to the documentation of this file.
1 //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass converts operations into their canonical forms by
10 // folding constants, applying operation identity transformations etc.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Transforms/Passes.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_CANONICALIZER
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
26 
27 namespace {
28 /// Canonicalize operations in nested regions.
29 struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
30  Canonicalizer() = default;
31  Canonicalizer(const GreedyRewriteConfig &config,
32  ArrayRef<std::string> disabledPatterns,
33  ArrayRef<std::string> enabledPatterns)
34  : config(config) {
35  this->topDownProcessingEnabled = config.getUseTopDownTraversal();
36  this->regionSimplifyLevel = config.getRegionSimplificationLevel();
37  this->maxIterations = config.getMaxIterations();
38  this->maxNumRewrites = config.getMaxNumRewrites();
39  this->disabledPatterns = disabledPatterns;
40  this->enabledPatterns = enabledPatterns;
41  }
42 
43  /// Initialize the canonicalizer by building the set of patterns used during
44  /// execution.
45  LogicalResult initialize(MLIRContext *context) override {
46  // Set the config from possible pass options set in the meantime.
47  config.setUseTopDownTraversal(topDownProcessingEnabled);
48  config.setRegionSimplificationLevel(regionSimplifyLevel);
49  config.setMaxIterations(maxIterations);
50  config.setMaxNumRewrites(maxNumRewrites);
51 
52  RewritePatternSet owningPatterns(context);
53  for (auto *dialect : context->getLoadedDialects())
54  dialect->getCanonicalizationPatterns(owningPatterns);
56  op.getCanonicalizationPatterns(owningPatterns, context);
57 
58  patterns = std::make_shared<FrozenRewritePatternSet>(
59  std::move(owningPatterns), disabledPatterns, enabledPatterns);
60  return success();
61  }
62  void runOnOperation() override {
63  LogicalResult converged =
64  applyPatternsGreedily(getOperation(), *patterns, config);
65  // Canonicalization is best-effort. Non-convergence is not a pass failure.
66  if (testConvergence && failed(converged))
67  signalPassFailure();
68  }
70  std::shared_ptr<const FrozenRewritePatternSet> patterns;
71 };
72 } // namespace
73 
74 /// Create a Canonicalizer pass.
75 std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
76  return std::make_unique<Canonicalizer>();
77 }
78 
79 /// Creates an instance of the Canonicalizer pass with the specified config.
80 std::unique_ptr<Pass>
82  ArrayRef<std::string> disabledPatterns,
83  ArrayRef<std::string> enabledPatterns) {
84  return std::make_unique<Canonicalizer>(config, disabledPatterns,
85  enabledPatterns);
86 }
This class allows control over how the GreedyPatternRewriteDriver works.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
This is a "type erased" representation of a registered operation.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
const FrozenRewritePatternSet & patterns
std::unique_ptr< Pass > createCanonicalizerPass()
Creates an instance of the Canonicalizer pass, configured with default settings (which can be overridden by pass options on the command line).