MLIR  19.0.0git
FrozenRewritePatternSet.cpp
Go to the documentation of this file.
1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "ByteCode.h"
12 #include "mlir/Pass/Pass.h"
13 #include "mlir/Pass/PassManager.h"
14 #include <optional>
15 
16 using namespace mlir;
17 
18 // Include the PDL rewrite support.
19 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
22 
/// Lowers the PDL patterns held in `pdlModule`, in place, to the PDL
/// interpreter dialect.
///
/// Steps, in the order written below:
///  1. Return success() immediately if `pdlModule.getOps<pdl::PatternOp>()`
///     is empty (nothing to lower).
///  2. Walk the module body and erase every op for which isOpTriviallyDead()
///     returns true. The canonicalizer is deliberately not used here (see the
///     comment in the body about a cyclic dependency).
///  3. Run a PassManager containing createPDLToPDLInterpPass(configMap) on
///     the module. In NDEBUG (release) builds the IR verifier is disabled.
///  4. Repeat the trivially-dead-op cleanup on the lowered module.
///
/// \param pdlModule  Module holding the PDL patterns; it is modified in place.
/// \param configMap  Forwarded unchanged to createPDLToPDLInterpPass.
///                   NOTE(review): its declaration (listing line 25) is missing
///                   from this extraction, so its exact type is not visible
///                   here.
/// \returns failure() if the lowering pipeline fails, success() otherwise.
23 static LogicalResult
24 convertPDLToPDLInterp(ModuleOp pdlModule,
26  // Skip the conversion if the module doesn't contain pdl.
27  if (pdlModule.getOps<pdl::PatternOp>().empty())
28  return success();
29 
30  // Simplify the provided PDL module. Note that we can't use the canonicalizer
31  // here because it would create a cyclic dependency.
32  auto simplifyFn = [](Operation *op) {
33  // TODO: Add folding here if ever necessary.
34  if (isOpTriviallyDead(op))
35  op->erase();
36  };
// NOTE(review): simplifyFn erases the op that is currently being visited.
// This relies on the Block::walk traversal tolerating erasure of the visited
// op (MLIR's walk is post-order by default). Confirm this still holds if the
// walk call or its order is ever changed.
37  pdlModule.getBody()->walk(simplifyFn);
38 
39  /// Lower the PDL pattern module to the interpreter dialect.
40  PassManager pdlPipeline(pdlModule->getName());
41 #ifdef NDEBUG
42  // We don't want to incur the hit of running the verifier when in release
43  // mode.
44  pdlPipeline.enableVerifier(false);
45 #endif
46  pdlPipeline.addPass(createPDLToPDLInterpPass(configMap));
47  if (failed(pdlPipeline.run(pdlModule)))
48  return failure();
49 
50  // Simplify again after running the lowering pipeline.
51  pdlModule.getBody()->walk(simplifyFn);
52  return success();
53 }
54 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
55 
56 //===----------------------------------------------------------------------===//
57 // FrozenRewritePatternSet
58 //===----------------------------------------------------------------------===//
59 
// Constructor initializer: allocates a default-constructed `Impl` with
// std::make_shared. No patterns are added here.
// NOTE(review): the declarator (listing line 60) is missing from this
// extraction. This is presumably the default constructor; confirm.
61  : impl(std::make_shared<Impl>()) {}
62 
64  RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
65  ArrayRef<std::string> enabledPatternLabels)
66  : impl(std::make_shared<Impl>()) {
// Builds the frozen pattern set. It takes ownership of the native patterns in
// `patterns` and sorts them into the lookup structures on `impl`.
//
// Native patterns are first filtered by label:
//  - If `enabledPatternLabels` is non-empty, a pattern is kept only if its
//    debug name, or one of its debug labels, appears there.
//  - If `disabledPatternLabels` is non-empty, a pattern is dropped if its
//    debug name, or one of its debug labels, appears there.
// The disabled check runs after the enabled check, so a pattern matched by
// both lists is dropped. Filtered-out patterns are not moved out of
// `patterns`.
//
// Each surviving pattern goes into exactly one bucket, in priority order:
//  1. It has a root operation kind:
//     nativeOpSpecificPatternMap[root].
//  2. It has a root interface:
//     nativeOpSpecificPatternMap[op] for every registered op with that
//     interface.
//  3. It has a root trait:
//     nativeOpSpecificPatternMap[op] for every registered op with that trait.
//  4. Otherwise: nativeAnyOpPatterns.
// Op-specific patterns are owned by nativeOpSpecificPatternList. The map
// holds only non-owning pointers obtained via get().
//
// When MLIR_ENABLE_PDL_IN_PATTERNMATCH is set and `patterns` carries a PDL
// module, that module is lowered and compiled into impl->pdlByteCode. A
// lowering failure aborts via llvm::report_fatal_error.
67  DenseSet<StringRef> disabledPatterns, enabledPatterns;
68  disabledPatterns.insert(disabledPatternLabels.begin(),
69  disabledPatternLabels.end());
70  enabledPatterns.insert(enabledPatternLabels.begin(),
71  enabledPatternLabels.end());
72 
73  // Functor used to walk all of the operations registered in the context. This
74  // is useful for patterns that get applied to multiple operations, such as
75  // interface and trait based patterns.
// The registered-operation list is fetched from the pattern's MLIRContext on
// first use and cached in `opInfos`. The pattern is always moved into
// nativeOpSpecificPatternList, even when `callbackFn` matches no registered
// op. In that case it is owned but never indexed by the map.
76  std::vector<RegisteredOperationName> opInfos;
77  auto addToOpsWhen =
78  [&](std::unique_ptr<RewritePattern> &pattern,
79  function_ref<bool(RegisteredOperationName)> callbackFn) {
80  if (opInfos.empty())
81  opInfos = pattern->getContext()->getRegisteredOperations();
82  for (RegisteredOperationName info : opInfos)
83  if (callbackFn(info))
84  impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
85  impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
86  };
87 
88  for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
89  // Don't add patterns that haven't been enabled by the user.
90  if (!enabledPatterns.empty()) {
91  auto isEnabledFn = [&](StringRef label) {
92  return enabledPatterns.count(label);
93  };
94  if (!isEnabledFn(pat->getDebugName()) &&
95  llvm::none_of(pat->getDebugLabels(), isEnabledFn))
96  continue;
97  }
98  // Don't add patterns that have been disabled by the user.
99  if (!disabledPatterns.empty()) {
100  auto isDisabledFn = [&](StringRef label) {
101  return disabledPatterns.count(label);
102  };
103  if (isDisabledFn(pat->getDebugName()) ||
104  llvm::any_of(pat->getDebugLabels(), isDisabledFn))
105  continue;
106  }
107 
// Bucket selection: the first matching root (kind, interface, trait) wins.
// Each branch ends in `continue`, so only a rootless pattern reaches the
// any-op list.
108  if (std::optional<OperationName> rootName = pat->getRootKind()) {
109  impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
110  impl->nativeOpSpecificPatternList.push_back(std::move(pat));
111  continue;
112  }
113  if (std::optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
114  addToOpsWhen(pat, [&](RegisteredOperationName info) {
115  return info.hasInterface(*interfaceID);
116  });
117  continue;
118  }
119  if (std::optional<TypeID> traitID = pat->getRootTraitID()) {
120  addToOpsWhen(pat, [&](RegisteredOperationName info) {
121  return info.hasTrait(*traitID);
122  });
123  continue;
124  }
125  impl->nativeAnyOpPatterns.push_back(std::move(pat));
126  }
127 
128 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
129  // Generate the bytecode for the PDL patterns if any were provided.
130  PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
131  ModuleOp pdlModule = pdlPatterns.getModule();
132  if (!pdlModule)
133  return;
// NOTE(review): listing line 134, which declares the `configMap` local that
// receives the result of takeConfigMap(), is missing from this extraction.
135  pdlPatterns.takeConfigMap();
// A constructor cannot return a LogicalResult, so a lowering failure is
// reported as a fatal error instead of being propagated.
136  if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
137  llvm::report_fatal_error(
138  "failed to lower PDL pattern module to the PDL Interpreter");
139 
140  // Generate the pdl bytecode.
141  impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
142  pdlModule, pdlPatterns.takeConfigs(), configMap,
143  pdlPatterns.takeConstraintFunctions(),
144  pdlPatterns.takeRewriteFunctions());
145 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
146 }
147 
bool hasTrait() const
Returns true if the operation was registered with a particular trait.
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
The main pass manager and pipeline builder.
Definition: PassManager.h:232
This is a "type erased" representation of a registered operation.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
bool isOpTriviallyDead(Operation *op)
Returns true if the given operation is unused and has no memory side effects that would prevent erasing it.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::unique_ptr< OperationPass< ModuleOp > > createPDLToPDLInterpPass()
Creates and returns a pass to convert PDL ops to PDL interpreter ops.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26