MLIR  21.0.0git
FrozenRewritePatternSet.cpp
Go to the documentation of this file.
1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "ByteCode.h"
12 #include "mlir/Pass/Pass.h"
13 #include "mlir/Pass/PassManager.h"
14 #include <optional>
15 
16 using namespace mlir;
17 
18 // Include the PDL rewrite support.
19 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
22 
23 static LogicalResult
24 convertPDLToPDLInterp(ModuleOp pdlModule,
26  // Skip the conversion if the module doesn't contain pdl.
27  if (pdlModule.getOps<pdl::PatternOp>().empty())
28  return success();
29 
30  // Simplify the provided PDL module. Note that we can't use the canonicalizer
31  // here because it would create a cyclic dependency.
32  auto simplifyFn = [](Operation *op) {
33  // TODO: Add folding here if ever necessary.
34  if (isOpTriviallyDead(op))
35  op->erase();
36  };
37  pdlModule.getBody()->walk(simplifyFn);
38 
39  /// Lower the PDL pattern module to the interpreter dialect.
40  PassManager pdlPipeline(pdlModule->getName());
41 #ifdef NDEBUG
42  // We don't want to incur the hit of running the verifier when in release
43  // mode.
44  pdlPipeline.enableVerifier(false);
45 #endif
46  pdlPipeline.addPass(createConvertPDLToPDLInterpPass(configMap));
47  if (failed(pdlPipeline.run(pdlModule)))
48  return failure();
49 
50  // Simplify again after running the lowering pipeline.
51  pdlModule.getBody()->walk(simplifyFn);
52  return success();
53 }
54 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
55 
56 //===----------------------------------------------------------------------===//
57 // FrozenRewritePatternSet
58 //===----------------------------------------------------------------------===//
59 
61  : impl(std::make_shared<Impl>()) {}
62 
64  RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
65  ArrayRef<std::string> enabledPatternLabels)
66  : impl(std::make_shared<Impl>()) {
67  DenseSet<StringRef> disabledPatterns, enabledPatterns;
68  disabledPatterns.insert_range(disabledPatternLabels);
69  enabledPatterns.insert_range(enabledPatternLabels);
70 
71  // Functor used to walk all of the operations registered in the context. This
72  // is useful for patterns that get applied to multiple operations, such as
73  // interface and trait based patterns.
74  std::vector<RegisteredOperationName> opInfos;
75  auto addToOpsWhen =
76  [&](std::unique_ptr<RewritePattern> &pattern,
77  function_ref<bool(RegisteredOperationName)> callbackFn) {
78  if (opInfos.empty())
79  opInfos = pattern->getContext()->getRegisteredOperations();
80  for (RegisteredOperationName info : opInfos)
81  if (callbackFn(info))
82  impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
83  impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
84  };
85 
86  for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
87  // Don't add patterns that haven't been enabled by the user.
88  if (!enabledPatterns.empty()) {
89  auto isEnabledFn = [&](StringRef label) {
90  return enabledPatterns.count(label);
91  };
92  if (!isEnabledFn(pat->getDebugName()) &&
93  llvm::none_of(pat->getDebugLabels(), isEnabledFn))
94  continue;
95  }
96  // Don't add patterns that have been disabled by the user.
97  if (!disabledPatterns.empty()) {
98  auto isDisabledFn = [&](StringRef label) {
99  return disabledPatterns.count(label);
100  };
101  if (isDisabledFn(pat->getDebugName()) ||
102  llvm::any_of(pat->getDebugLabels(), isDisabledFn))
103  continue;
104  }
105 
106  if (std::optional<OperationName> rootName = pat->getRootKind()) {
107  impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
108  impl->nativeOpSpecificPatternList.push_back(std::move(pat));
109  continue;
110  }
111  if (std::optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
112  addToOpsWhen(pat, [&](RegisteredOperationName info) {
113  return info.hasInterface(*interfaceID);
114  });
115  continue;
116  }
117  if (std::optional<TypeID> traitID = pat->getRootTraitID()) {
118  addToOpsWhen(pat, [&](RegisteredOperationName info) {
119  return info.hasTrait(*traitID);
120  });
121  continue;
122  }
123  impl->nativeAnyOpPatterns.push_back(std::move(pat));
124  }
125 
126 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
127  // Generate the bytecode for the PDL patterns if any were provided.
128  PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
129  ModuleOp pdlModule = pdlPatterns.getModule();
130  if (!pdlModule)
131  return;
133  pdlPatterns.takeConfigMap();
134  if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
135  llvm::report_fatal_error(
136  "failed to lower PDL pattern module to the PDL Interpreter");
137 
138  // Generate the pdl bytecode.
139  impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
140  pdlModule, pdlPatterns.takeConfigs(), configMap,
141  pdlPatterns.takeConstraintFunctions(),
142  pdlPatterns.takeRewriteFunctions());
143 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
144 }
145 
bool hasTrait() const
Returns true if the operation was registered with a particular trait.
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
The main pass manager and pipeline builder.
Definition: PassManager.h:230
This is a "type erased" representation of a registered operation.
Include the generated interface declarations.
std::unique_ptr< OperationPass< ModuleOp > > createConvertPDLToPDLInterpPass(DenseMap< Operation *, PDLPatternConfigSet * > &configMap)
Creates and returns a pass to convert PDL ops to PDL interpreter ops.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns