MLIR  18.0.0git
FrozenRewritePatternSet.cpp
Go to the documentation of this file.
1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "ByteCode.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 #include <optional>
17 
18 using namespace mlir;
19 
20 static LogicalResult
21 convertPDLToPDLInterp(ModuleOp pdlModule,
23  // Skip the conversion if the module doesn't contain pdl.
24  if (pdlModule.getOps<pdl::PatternOp>().empty())
25  return success();
26 
27  // Simplify the provided PDL module. Note that we can't use the canonicalizer
28  // here because it would create a cyclic dependency.
29  auto simplifyFn = [](Operation *op) {
30  // TODO: Add folding here if ever necessary.
31  if (isOpTriviallyDead(op))
32  op->erase();
33  };
34  pdlModule.getBody()->walk(simplifyFn);
35 
36  /// Lower the PDL pattern module to the interpreter dialect.
37  PassManager pdlPipeline(pdlModule->getName());
38 #ifdef NDEBUG
39  // We don't want to incur the hit of running the verifier when in release
40  // mode.
41  pdlPipeline.enableVerifier(false);
42 #endif
43  pdlPipeline.addPass(createPDLToPDLInterpPass(configMap));
44  if (failed(pdlPipeline.run(pdlModule)))
45  return failure();
46 
47  // Simplify again after running the lowering pipeline.
48  pdlModule.getBody()->walk(simplifyFn);
49  return success();
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // FrozenRewritePatternSet
54 //===----------------------------------------------------------------------===//
55 
57  : impl(std::make_shared<Impl>()) {}
58 
60  RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
61  ArrayRef<std::string> enabledPatternLabels)
62  : impl(std::make_shared<Impl>()) {
63  DenseSet<StringRef> disabledPatterns, enabledPatterns;
64  disabledPatterns.insert(disabledPatternLabels.begin(),
65  disabledPatternLabels.end());
66  enabledPatterns.insert(enabledPatternLabels.begin(),
67  enabledPatternLabels.end());
68 
69  // Functor used to walk all of the operations registered in the context. This
70  // is useful for patterns that get applied to multiple operations, such as
71  // interface and trait based patterns.
72  std::vector<RegisteredOperationName> opInfos;
73  auto addToOpsWhen =
74  [&](std::unique_ptr<RewritePattern> &pattern,
75  function_ref<bool(RegisteredOperationName)> callbackFn) {
76  if (opInfos.empty())
77  opInfos = pattern->getContext()->getRegisteredOperations();
78  for (RegisteredOperationName info : opInfos)
79  if (callbackFn(info))
80  impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
81  impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
82  };
83 
84  for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
85  // Don't add patterns that haven't been enabled by the user.
86  if (!enabledPatterns.empty()) {
87  auto isEnabledFn = [&](StringRef label) {
88  return enabledPatterns.count(label);
89  };
90  if (!isEnabledFn(pat->getDebugName()) &&
91  llvm::none_of(pat->getDebugLabels(), isEnabledFn))
92  continue;
93  }
94  // Don't add patterns that have been disabled by the user.
95  if (!disabledPatterns.empty()) {
96  auto isDisabledFn = [&](StringRef label) {
97  return disabledPatterns.count(label);
98  };
99  if (isDisabledFn(pat->getDebugName()) ||
100  llvm::any_of(pat->getDebugLabels(), isDisabledFn))
101  continue;
102  }
103 
104  if (std::optional<OperationName> rootName = pat->getRootKind()) {
105  impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
106  impl->nativeOpSpecificPatternList.push_back(std::move(pat));
107  continue;
108  }
109  if (std::optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
110  addToOpsWhen(pat, [&](RegisteredOperationName info) {
111  return info.hasInterface(*interfaceID);
112  });
113  continue;
114  }
115  if (std::optional<TypeID> traitID = pat->getRootTraitID()) {
116  addToOpsWhen(pat, [&](RegisteredOperationName info) {
117  return info.hasTrait(*traitID);
118  });
119  continue;
120  }
121  impl->nativeAnyOpPatterns.push_back(std::move(pat));
122  }
123 
124  // Generate the bytecode for the PDL patterns if any were provided.
125  PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
126  ModuleOp pdlModule = pdlPatterns.getModule();
127  if (!pdlModule)
128  return;
130  pdlPatterns.takeConfigMap();
131  if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
132  llvm::report_fatal_error(
133  "failed to lower PDL pattern module to the PDL Interpreter");
134 
135  // Generate the pdl bytecode.
136  impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
137  pdlModule, pdlPatterns.takeConfigs(), configMap,
138  pdlPatterns.takeConstraintFunctions(),
139  pdlPatterns.takeRewriteFunctions());
140 }
141 
static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule, DenseMap< Operation *, PDLPatternConfigSet * > &configMap)
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition: Pass.cpp:346
bool hasTrait() const
Returns true if the operation was registered with a particular trait, e.g. `hasTrait<SomeTrait>()`.
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class contains all of the necessary data for a set of PDL patterns, or pattern rewrites specified in the form of the PDL dialect.
llvm::StringMap< PDLConstraintFunction > takeConstraintFunctions()
ModuleOp getModule()
Return the internal PDL module of this pattern.
SmallVector< std::unique_ptr< PDLPatternConfigSet > > takeConfigs()
Return the set of the registered pattern configs.
llvm::StringMap< PDLRewriteFunction > takeRewriteFunctions()
DenseMap< Operation *, PDLPatternConfigSet * > takeConfigMap()
The main pass manager and pipeline builder.
Definition: PassManager.h:211
LogicalResult run(Operation *op)
Run the passes within this manager on the provided operation.
Definition: Pass.cpp:822
void enableVerifier(bool enabled=true)
Runs the verifier after each individual pass.
Definition: Pass.cpp:819
This is a "type erased" representation of a registered operation.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::unique_ptr< OperationPass< ModuleOp > > createPDLToPDLInterpPass()
Creates and returns a pass to convert PDL ops to PDL interpreter ops.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26