MLIR  16.0.0git
Canonicalizer.cpp
Go to the documentation of this file.
1 //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass converts operations into their canonical forms by
10 // folding constants, applying operation identity transformations etc.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Transforms/Passes.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

19 namespace mlir {
20 #define GEN_PASS_DEF_CANONICALIZER
21 #include "mlir/Transforms/Passes.h.inc"
22 } // namespace mlir
23 
24 using namespace mlir;
25 
26 namespace {
27 /// Canonicalize operations in nested regions.
28 struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
29  Canonicalizer() = default;
30  Canonicalizer(const GreedyRewriteConfig &config,
31  ArrayRef<std::string> disabledPatterns,
32  ArrayRef<std::string> enabledPatterns) {
33  this->topDownProcessingEnabled = config.useTopDownTraversal;
34  this->enableRegionSimplification = config.enableRegionSimplification;
35  this->maxIterations = config.maxIterations;
36  this->disabledPatterns = disabledPatterns;
37  this->enabledPatterns = enabledPatterns;
38  }
39 
40  /// Initialize the canonicalizer by building the set of patterns used during
41  /// execution.
42  LogicalResult initialize(MLIRContext *context) override {
43  RewritePatternSet owningPatterns(context);
44  for (auto *dialect : context->getLoadedDialects())
45  dialect->getCanonicalizationPatterns(owningPatterns);
47  op.getCanonicalizationPatterns(owningPatterns, context);
48 
49  patterns = FrozenRewritePatternSet(std::move(owningPatterns),
50  disabledPatterns, enabledPatterns);
51  return success();
52  }
53  void runOnOperation() override {
54  GreedyRewriteConfig config;
55  config.useTopDownTraversal = topDownProcessingEnabled;
56  config.enableRegionSimplification = enableRegionSimplification;
57  config.maxIterations = maxIterations;
58  (void)applyPatternsAndFoldGreedily(getOperation(), patterns, config);
59  }
60 
61  FrozenRewritePatternSet patterns;
62 };
63 } // namespace
64 
65 /// Create a Canonicalizer pass.
66 std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
67  return std::make_unique<Canonicalizer>();
68 }
69 
70 /// Creates an instance of the Canonicalizer pass with the specified config.
71 std::unique_ptr<Pass>
73  ArrayRef<std::string> disabledPatterns,
74  ArrayRef<std::string> enabledPatterns) {
75  return std::make_unique<Canonicalizer>(config, disabledPatterns,
76  enabledPatterns);
77 }
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and simplifying regions.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriters worklist.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
This is a "type erased" representation of a registered operation.
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest-benefit patterns in a greedy, worklist-driven manner.
std::unique_ptr< Pass > createCanonicalizerPass()
Creates an instance of the Canonicalizer pass, configured with default settings (which can be overridden by pass options on the command line).
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26