MLIR 22.0.0git
FrozenRewritePatternSet.cpp
Go to the documentation of this file.
1//===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "ByteCode.h"
13#include <optional>
14
15using namespace mlir;
16
17// Include the PDL rewrite support.
18#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
21
// Runs the PDL-to-PDL-interpreter lowering pass over `pdlModule`.
//
// Returns success() right away when the module contains no pdl::PatternOp,
// failure() when the lowering pipeline fails, and success() otherwise.
// Operations that are trivially dead are erased from the module body both
// before and after the pipeline runs.
//
// NOTE(review): this listing has lost original line 24 (the rest of the
// parameter list and the opening brace). The body refers to `configMap`, which
// is presumably the parameter declared on that missing line -- confirm against
// the upstream file before relying on its type.
22static LogicalResult
23convertPDLToPDLInterp(ModuleOp pdlModule,
25 // Skip the conversion if the module doesn't contain pdl.
26 if (pdlModule.getOps<pdl::PatternOp>().empty())
27 return success();
28
29 // Simplify the provided PDL module. Note that we can't use the canonicalizer
30 // here because it would create a cyclic dependency.
 // The callback only erases operations; it performs no folding (see the TODO).
31 auto simplifyFn = [](Operation *op) {
32 // TODO: Add folding here if ever necessary.
33 if (isOpTriviallyDead(op))
34 op->erase();
35 };
36 pdlModule.getBody()->walk(simplifyFn);
37
38 /// Lower the PDL pattern module to the interpreter dialect.
 // The pass manager is constructed from the module operation's own name.
39 PassManager pdlPipeline(pdlModule->getName());
40#ifdef NDEBUG
41 // We don't want to incur the hit of running the verifier when in release
42 // mode.
43 pdlPipeline.enableVerifier(false);
44#endif
 // `configMap` is handed to the conversion pass at construction time.
45 pdlPipeline.addPass(createConvertPDLToPDLInterpPass(configMap));
46 if (failed(pdlPipeline.run(pdlModule)))
47 return failure();
48
49 // Simplify again after running the lowering pipeline.
50 pdlModule.getBody()->walk(simplifyFn);
51 return success();
52}
53#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
54
55//===----------------------------------------------------------------------===//
56// FrozenRewritePatternSet
57//===----------------------------------------------------------------------===//
58
// Initializes `impl` with a newly allocated, default-constructed Impl; the
// constructor body itself is empty.
// NOTE(review): the line carrying this constructor's name and parameter list
// (original line 59) is missing from this listing -- presumably the default
// constructor of FrozenRewritePatternSet; confirm against the upstream file.
60 : impl(std::make_shared<Impl>()) {}
61
// Builds the frozen pattern set from `patterns`, taking ownership of each
// native pattern that survives label filtering.
//
// Filtering (applied per native pattern, in this order):
//  * If `enabledPatternLabels` is non-empty, the pattern is skipped unless its
//    debug name or at least one of its debug labels appears in that list.
//  * If `disabledPatternLabels` is non-empty, the pattern is skipped when its
//    debug name or any of its debug labels appears in that list.
//  Because the enabled check runs first and the disabled check second, a
//  pattern matched by both lists is skipped.
//
// Bucketing of the remaining patterns:
//  * Patterns with a root operation kind are registered under that operation
//    name in `nativeOpSpecificPatternMap`.
//  * Patterns rooted on an interface or trait are registered under every
//    operation registered in the pattern's context that has that
//    interface/trait.
//  * All other patterns are stored in `nativeAnyOpPatterns`.
//
// When MLIR_ENABLE_PDL_IN_PATTERNMATCH is set and `patterns` carries a PDL
// module, the module is lowered via convertPDLToPDLInterp and PDL bytecode is
// generated from it; a lowering failure is reported through
// llvm::report_fatal_error.
//
// NOTE(review): this listing has lost original lines 62 (the constructor name),
// 76 (the remaining parameter(s) of the `addToOpsWhen` lambda, presumably the
// `callbackFn` predicate) and 131 (presumably the declaration of `configMap`
// that the `takeConfigMap()` call initializes). Confirm against the upstream
// file.
63 RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
64 ArrayRef<std::string> enabledPatternLabels)
65 : impl(std::make_shared<Impl>()) {
 // Copy both label lists into sets so each pattern check is a set lookup.
66 DenseSet<StringRef> disabledPatterns, enabledPatterns;
67 disabledPatterns.insert_range(disabledPatternLabels);
68 enabledPatterns.insert_range(enabledPatternLabels);
69
70 // Functor used to walk all of the operations registered in the context. This
71 // is useful for patterns that get applied to multiple operations, such as
72 // interface and trait based patterns.
 // `opInfos` caches the registered-operation list; it is filled on first use
 // from the context of whichever pattern reaches the functor first.
73 std::vector<RegisteredOperationName> opInfos;
74 auto addToOpsWhen =
75 [&](std::unique_ptr<RewritePattern> &pattern,
77 if (opInfos.empty())
78 opInfos = pattern->getContext()->getRegisteredOperations();
 // Record the raw pointer under every operation accepted by the predicate...
79 for (RegisteredOperationName info : opInfos)
80 if (callbackFn(info))
81 impl->nativeOpSpecificPatternMap[info].push_back(pattern.get());
 // ...and only then move ownership of the pattern into the list.
82 impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
83 };
84
85 for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
86 // Don't add patterns that haven't been enabled by the user.
87 if (!enabledPatterns.empty()) {
88 auto isEnabledFn = [&](StringRef label) {
89 return enabledPatterns.count(label);
90 };
91 if (!isEnabledFn(pat->getDebugName()) &&
92 llvm::none_of(pat->getDebugLabels(), isEnabledFn))
93 continue;
94 }
95 // Don't add patterns that have been disabled by the user.
96 if (!disabledPatterns.empty()) {
97 auto isDisabledFn = [&](StringRef label) {
98 return disabledPatterns.count(label);
99 };
100 if (isDisabledFn(pat->getDebugName()) ||
101 llvm::any_of(pat->getDebugLabels(), isDisabledFn))
102 continue;
103 }
104
 // Classification is tried in a fixed order: root operation kind, then root
 // interface, then root trait; the first one present decides the bucket.
105 if (std::optional<OperationName> rootName = pat->getRootKind()) {
106 impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
107 impl->nativeOpSpecificPatternList.push_back(std::move(pat));
108 continue;
109 }
110 if (std::optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
111 addToOpsWhen(pat, [&](RegisteredOperationName info) {
112 return info.hasInterface(*interfaceID);
113 });
114 continue;
115 }
116 if (std::optional<TypeID> traitID = pat->getRootTraitID()) {
117 addToOpsWhen(pat, [&](RegisteredOperationName info) {
118 return info.hasTrait(*traitID);
119 });
120 continue;
121 }
 // No root kind, interface or trait: store it with the any-op patterns.
122 impl->nativeAnyOpPatterns.push_back(std::move(pat));
123 }
124
125#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
126 // Generate the bytecode for the PDL patterns if any were provided.
127 PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
128 ModuleOp pdlModule = pdlPatterns.getModule();
 // Without a PDL module there is no bytecode to build.
129 if (!pdlModule)
130 return;
 // NOTE(review): the left-hand side of this initialization (original line 131)
 // is missing from the listing; the result is used below as `configMap`.
132 pdlPatterns.takeConfigMap();
133 if (failed(convertPDLToPDLInterp(pdlModule, configMap)))
134 llvm::report_fatal_error(
135 "failed to lower PDL pattern module to the PDL Interpreter");
136
137 // Generate the pdl bytecode.
138 impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
139 pdlModule, pdlPatterns.takeConfigs(), configMap,
140 pdlPatterns.takeConstraintFunctions(),
141 pdlPatterns.takeRewriteFunctions());
142#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
143}
144
return success()
if(!isCopyOut)
bool hasTrait() const
Returns true if the operation was registered with a particular trait, e.g.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
The main pass manager and pipeline builder.
This is a "type erased" representation of a registered operation.
Attribute collections provide a dictionary-like interface.
Definition Traits.h:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
std::unique_ptr<::mlir::Pass > createConvertPDLToPDLInterpPass()
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152