MLIR 19.0.0git
Canonicalizer.cpp
Go to the documentation of this file.
1 //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass converts operations into their canonical forms by
10 // folding constants, applying operation identity transformations etc.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Transforms/Passes.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

19 namespace mlir {
20 #define GEN_PASS_DEF_CANONICALIZER
21 #include "mlir/Transforms/Passes.h.inc"
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 namespace {
27 /// Canonicalize operations in nested regions.
28 struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
29  Canonicalizer() = default;
30  Canonicalizer(const GreedyRewriteConfig &config,
31  ArrayRef<std::string> disabledPatterns,
32  ArrayRef<std::string> enabledPatterns)
33  : config(config) {
34  this->topDownProcessingEnabled = config.useTopDownTraversal;
35  this->enableRegionSimplification = config.enableRegionSimplification;
36  this->maxIterations = config.maxIterations;
37  this->maxNumRewrites = config.maxNumRewrites;
38  this->disabledPatterns = disabledPatterns;
39  this->enabledPatterns = enabledPatterns;
40  }
41 
42  /// Initialize the canonicalizer by building the set of patterns used during
43  /// execution.
44  LogicalResult initialize(MLIRContext *context) override {
45  // Set the config from possible pass options set in the meantime.
46  config.useTopDownTraversal = topDownProcessingEnabled;
47  config.enableRegionSimplification = enableRegionSimplification;
48  config.maxIterations = maxIterations;
49  config.maxNumRewrites = maxNumRewrites;
50 
51  RewritePatternSet owningPatterns(context);
52  for (auto *dialect : context->getLoadedDialects())
53  dialect->getCanonicalizationPatterns(owningPatterns);
55  op.getCanonicalizationPatterns(owningPatterns, context);
56 
57  patterns = std::make_shared<FrozenRewritePatternSet>(
58  std::move(owningPatterns), disabledPatterns, enabledPatterns);
59  return success();
60  }
61  void runOnOperation() override {
62  LogicalResult converged =
63  applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
64  // Canonicalization is best-effort. Non-convergence is not a pass failure.
65  if (testConvergence && failed(converged))
66  signalPassFailure();
67  }
68  GreedyRewriteConfig config;
69  std::shared_ptr<const FrozenRewritePatternSet> patterns;
70 };
71 } // namespace
72 
73 /// Create a Canonicalizer pass.
74 std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
75  return std::make_unique<Canonicalizer>();
76 }
77 
78 /// Creates an instance of the Canonicalizer pass with the specified config.
79 std::unique_ptr<Pass>
81  ArrayRef<std::string> disabledPatterns,
82  ArrayRef<std::string> enabledPatterns) {
83  return std::make_unique<Canonicalizer>(config, disabledPatterns,
84  enabledPatterns);
85 }
This class allows control over how the GreedyPatternRewriteDriver works.
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and simplifying regions.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriter's worklist.
int64_t maxNumRewrites
This specifies the maximum number of rewrites within an iteration.
bool enableRegionSimplification
Perform control flow optimizations to the region tree after applying all patterns.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
This is a "type erased" representation of a registered operation.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner until a fixpoint is reached.
std::unique_ptr< Pass > createCanonicalizerPass()
Creates an instance of the Canonicalizer pass, configured with default settings (which can be overridden by pass options on the command line).
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26