19 #include "llvm/ADT/DenseSet.h"
22 #define GEN_PASS_DEF_CANONICALIZERPASS
23 #include "mlir/Transforms/Passes.h.inc"
31 using impl::CanonicalizerPassBase<Canonicalizer>::CanonicalizerPassBase;
32 Canonicalizer(
const GreedyRewriteConfig &config,
33 ArrayRef<std::string> disabledPatterns,
34 ArrayRef<std::string> enabledPatterns)
36 this->topDownProcessingEnabled = config.getUseTopDownTraversal();
37 this->regionSimplifyLevel = config.getRegionSimplificationLevel();
38 this->maxIterations = config.getMaxIterations();
39 this->maxNumRewrites = config.getMaxNumRewrites();
40 this->cseBetweenIterations = config.isCSEBetweenIterationsEnabled();
41 this->disabledPatterns = disabledPatterns;
42 this->enabledPatterns = enabledPatterns;
45 void getDependentDialects(DialectRegistry &registry)
const override {
48 for (
const std::string &name : filterDialects)
54 LogicalResult
initialize(MLIRContext *context)
override {
56 config.setUseTopDownTraversal(topDownProcessingEnabled);
57 config.setRegionSimplificationLevel(regionSimplifyLevel);
58 config.setMaxIterations(maxIterations);
59 config.setMaxNumRewrites(maxNumRewrites);
60 config.enableCSEBetweenIterations(cseBetweenIterations);
62 llvm::DenseSet<TypeID> allowedDialects;
63 for (
const std::string &name : filterDialects) {
65 assert(dialect &&
"filter-dialect should have been preloaded by the "
66 "PassManager via getDependentDialects");
67 allowedDialects.insert(dialect->
getTypeID());
69 auto isAllowed = [&](Dialect *dialect) {
70 return allowedDialects.empty() ||
71 allowedDialects.contains(dialect->getTypeID());
74 RewritePatternSet owningPatterns(context);
76 if (isAllowed(dialect))
77 dialect->getCanonicalizationPatterns(owningPatterns);
79 if (isAllowed(&op.getDialect()))
80 op.getCanonicalizationPatterns(owningPatterns, context);
82 patterns = std::make_shared<FrozenRewritePatternSet>(
83 std::move(owningPatterns), disabledPatterns, enabledPatterns);
86 void runOnOperation()
override {
87 LogicalResult converged =
90 if (testConvergence &&
failed(converged))
93 GreedyRewriteConfig config;
94 std::shared_ptr<const FrozenRewritePatternSet> patterns;
103 return std::make_unique<Canonicalizer>(config, disabledPatterns,
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
void addDialectToPreload(StringRef name)
Request that the dialect with the given name be preloaded into the MLIRContext, without providing an ...
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
This class allows control over how the GreedyPatternRewriteDriver works.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
ArrayRef< RegisteredOperationName > getRegisteredOperations()
Return a sorted array containing the information about all registered operations.
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr<::mlir::Pass > createCanonicalizerPass()