MLIR  22.0.0git
FrozenRewritePatternSet.cpp
Go to the documentation of this file.
1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "ByteCode.h"
12 #include "mlir/Pass/PassManager.h"
13 #include <optional>
14 
15 using namespace mlir;
16 
17 // Include the PDL rewrite support.
18 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
21 
22 static LogicalResult
23 convertPDLToPDLInterp(ModuleOp pdlModule,
25  // Skip the conversion if the module doesn't contain pdl.
26  if (pdlModule.getOps<pdl::PatternOp>().empty())
27  return success();
28 
29  // Simplify the provided PDL module. Note that we can't use the canonicalizer
30  // here because it would create a cyclic dependency.
31  auto simplifyFn = [](Operation *op) {
32  // TODO: Add folding here if ever necessary.
33  if (isOpTriviallyDead(op))
34  op->erase();
35  };
36  pdlModule.getBody()->walk(simplifyFn);
37 
38  /// Lower the PDL pattern module to the interpreter dialect.
39  PassManager pdlPipeline(pdlModule->getName());
40 #ifdef NDEBUG
41  // We don't want to incur the hit of running the verifier when in release
42  // mode.
43  pdlPipeline.enableVerifier(false);
44 #endif
45  pdlPipeline.addPass(createConvertPDLToPDLInterpPass(configMap));
46  if (failed(pdlPipeline.run(pdlModule)))
47  return failure();
48 
49  // Simplify again after running the lowering pipeline.
50  pdlModule.getBody()->walk(simplifyFn);
51  return success();
52 }
53 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
54 
55 //===----------------------------------------------------------------------===//
56 // FrozenRewritePatternSet
57 //===----------------------------------------------------------------------===//
58 
60  : impl(std::make_shared<Impl>()) {}
61 
63  RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
64  ArrayRef<std::string> enabledPatternLabels)
65  : impl(std::make_shared<Impl>()) {
66  DenseSet<StringRef> disabledPatterns, enabledPatterns;
67  disabledPatterns.insert_range(disabledPatternLabels);
68  enabledPatterns.insert_range(enabledPatternLabels);
69 
70  // Functor used to walk all of the operations registered in the context. This
71  // is useful for patterns that get applied to multiple operations, such as
72  // interface and trait based patterns.
73  std::vector<RegisteredOperationName> opInfos;
74  auto addToOpsWhen =
75  [&](std::unique_ptr<RewritePattern> &pattern,
76  function_ref<bool(RegisteredOperationName)> callbackFn) {
77  if (opInfos.empty())
78  opInfos = pattern->getContext()->getRegisteredOperations();
79  for (RegisteredOperationName info : opInfos)
80  if (callbackFn(info))
81  impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
82  impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
83  };
84 
85  for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
86  // Don't add patterns that haven't been enabled by the user.
87  if (!enabledPatterns.empty()) {
88  auto isEnabledFn = [&](StringRef label) {
89  return enabledPatterns.count(label);
90  };
91  if (!isEnabledFn(pat->getDebugName()) &&
92  llvm::none_of(pat->getDebugLabels(), isEnabledFn))
93  continue;
94  }
95  // Don't add patterns that have been disabled by the user.
96  if (!disabledPatterns.empty()) {
97  auto isDisabledFn = [&](StringRef label) {
98  return disabledPatterns.count(label);
99  };
100  if (isDisabledFn(pat->getDebugName()) ||
101  llvm::any_of(pat->getDebugLabels(), isDisabledFn))
102  continue;
103  }
104 
105  if (std::optional<OperationName> rootName = pat->getRootKind()) {
106  impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
107  impl->nativeOpSpecificPatternList.push_back(std::move(pat));
108  continue;
109  }
110  if (std::optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
111  addToOpsWhen(pat, [&](RegisteredOperationName info) {
112  return info.hasInterface(*interfaceID);
113  });
114  continue;
115  }
116  if (std::optional<TypeID> traitID = pat->getRootTraitID()) {
117  addToOpsWhen(pat, [&](RegisteredOperationName info) {
118  return info.hasTrait(*traitID);
119  });
120  continue;
121  }
122  impl->nativeAnyOpPatterns.push_back(std::move(pat));
123  }
124 
125 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
126  // Generate the bytecode for the PDL patterns if any were provided.
127  PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
128  ModuleOp pdlModule = pdlPatterns.getModule();
129  if (!pdlModule)
130  return;
132  pdlPatterns.takeConfigMap();
133  if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
134  llvm::report_fatal_error(
135  "failed to lower PDL pattern module to the PDL Interpreter");
136 
137  // Generate the pdl bytecode.
138  impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
139  pdlModule, pdlPatterns.takeConfigs(), configMap,
140  pdlPatterns.takeConstraintFunctions(),
141  pdlPatterns.takeRewriteFunctions());
142 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
143 }
144 
bool hasTrait() const
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
bool hasInterface() const
Returns true if this operation has the given interface registered to it.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
The main pass manager and pipeline builder.
Definition: PassManager.h:232
This is a "type erased" representation of a registered operation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: llvm/Support/LogicalResult.h
Include the generated interface declarations.
std::unique_ptr< OperationPass< ModuleOp > > createConvertPDLToPDLInterpPass(DenseMap< Operation *, PDLPatternConfigSet * > &configMap)
Creates and returns a pass to convert PDL ops to PDL interpreter ops.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns