MLIR 16.0.0git
Canonicalizer.cpp
Go to the documentation of this file.
1 //===- Canonicalizer.cpp - Canonicalize MLIR operations -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass converts operations into their canonical forms by
10 // folding constants, applying operation identity transformations etc.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/Passes.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 /// Canonicalize operations in nested regions.
23 struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
24  Canonicalizer() = default;
25  Canonicalizer(const GreedyRewriteConfig &config,
26  ArrayRef<std::string> disabledPatterns,
27  ArrayRef<std::string> enabledPatterns) {
28  this->topDownProcessingEnabled = config.useTopDownTraversal;
29  this->enableRegionSimplification = config.enableRegionSimplification;
30  this->maxIterations = config.maxIterations;
31  this->disabledPatterns = disabledPatterns;
32  this->enabledPatterns = enabledPatterns;
33  }
34 
35  /// Initialize the canonicalizer by building the set of patterns used during
36  /// execution.
37  LogicalResult initialize(MLIRContext *context) override {
38  RewritePatternSet owningPatterns(context);
39  for (auto *dialect : context->getLoadedDialects())
40  dialect->getCanonicalizationPatterns(owningPatterns);
42  op.getCanonicalizationPatterns(owningPatterns, context);
43 
44  patterns = FrozenRewritePatternSet(std::move(owningPatterns),
45  disabledPatterns, enabledPatterns);
46  return success();
47  }
48  void runOnOperation() override {
49  GreedyRewriteConfig config;
50  config.useTopDownTraversal = topDownProcessingEnabled;
51  config.enableRegionSimplification = enableRegionSimplification;
52  config.maxIterations = maxIterations;
53  (void)applyPatternsAndFoldGreedily(getOperation(), patterns, config);
54  }
55 
56  FrozenRewritePatternSet patterns;
57 };
58 } // namespace
59 
60 /// Create a Canonicalizer pass.
61 std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
62  return std::make_unique<Canonicalizer>();
63 }
64 
65 /// Creates an instance of the Canonicalizer pass with the specified config.
66 std::unique_ptr<Pass>
68  ArrayRef<std::string> disabledPatterns,
69  ArrayRef<std::string> enabledPatterns) {
70  return std::make_unique<Canonicalizer>(config, disabledPatterns,
71  enabledPatterns);
72 }
Include the generated interface declarations.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
This class allows control over how the GreedyPatternRewriteDriver works.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriters worklist.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
int64_t maxIterations
This specifies the maximum number of times the rewriter will iterate between applying patterns and simplifying regions.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
This is a "type erased" representation of a registered operation.
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
std::unique_ptr< Pass > createCanonicalizerPass()
Creates an instance of the Canonicalizer pass, configured with default settings (which can be overridden by pass options on the command line).