14 #include "llvm/ADT/SmallPtrSet.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/TableGen/Error.h"
18 #include "llvm/TableGen/Record.h"
21 using namespace tblgen;
24 using llvm::SpecificBumpPtrAllocator;
27 Pred::Pred(
const Record *record) : def(record) {
28 assert(def->isSubClassOf(
"Pred") &&
29 "must be a subclass of TableGen 'Pred' class");
33 Pred::Pred(
const Init *init) {
34 if (
const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
35 def = defInit->getDef();
40 if (
def->isSubClassOf(
"CombinedPred"))
41 return static_cast<const CombinedPred *
>(
this)->getConditionImpl();
42 if (
def->isSubClassOf(
"CPred"))
44 llvm_unreachable(
"Pred::getCondition must be overridden in subclasses");
48 return def &&
def->isSubClassOf(
"CombinedPred");
54 assert(
def->isSubClassOf(
"CPred") &&
55 "must be a subclass of Tablegen 'CPred' class");
59 assert((!
def ||
def->isSubClassOf(
"CPred")) &&
60 "must be a subclass of Tablegen 'CPred' class");
65 assert(!
isNull() &&
"null predicate does not have a condition");
66 return std::string(
def->getValueAsString(
"predExpr"));
70 assert(
def->isSubClassOf(
"CombinedPred") &&
71 "must be a subclass of Tablegen 'CombinedPred' class");
75 assert((!
def ||
def->isSubClassOf(
"CombinedPred")) &&
76 "must be a subclass of Tablegen 'CombinedPred' class");
80 assert(
def->getValue(
"kind") &&
"CombinedPred must have a value 'kind'");
81 return def->getValueAsDef(
"kind");
85 assert(
def->getValue(
"children") &&
86 "CombinedPred must have a value 'children'");
87 return def->getValueAsListOfDefs(
"children");
92 enum class PredCombinerKind {
106 PredCombinerKind
kind;
107 const Pred *predicate;
121 return PredCombinerKind::Leaf;
123 const auto &combinedPred =
static_cast<const CombinedPred &
>(pred);
125 combinedPred.getCombinerDef()->getName())
126 .Case(
"PredCombinerAnd", PredCombinerKind::And)
127 .Case(
"PredCombinerOr", PredCombinerKind::Or)
128 .Case(
"PredCombinerNot", PredCombinerKind::Not)
129 .Case(
"PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
130 .Case(
"PredCombinerConcat", PredCombinerKind::Concat);
135 using Subst = std::pair<StringRef, StringRef>;
142 for (
const auto &subst : llvm::reverse(substitutions)) {
143 auto pos = str.find(subst.first);
144 while (pos != std::string::npos) {
145 str.replace(pos, subst.first.size(), std::string(subst.second));
148 pos += subst.second.size();
150 pos = str.find(subst.first, pos);
161 SpecificBumpPtrAllocator<PredNode> &allocator,
163 auto *rootNode = allocator.Allocate();
164 new (rootNode) PredNode;
166 rootNode->predicate = &root;
175 auto allSubstitutions = llvm::to_vector<4>(substitutions);
176 if (rootNode->kind == PredCombinerKind::SubstLeaves) {
178 allSubstitutions.push_back(
179 {substPred.getPattern(), substPred.getReplacement()});
182 }
else if (rootNode->kind == PredCombinerKind::Concat) {
183 const auto &concatPred =
static_cast<const ConcatPred &
>(root);
184 rootNode->prefix = std::string(concatPred.getPrefix());
186 rootNode->suffix = std::string(concatPred.getSuffix());
191 auto combined =
static_cast<const CombinedPred &
>(root);
192 for (
const auto *record : combined.getChildren()) {
195 rootNode->children.push_back(childTree);
211 if (knownTruePreds.count(node->predicate) != 0) {
212 node->kind = PredCombinerKind::True;
213 node->children.clear();
216 if (knownFalsePreds.count(node->predicate) != 0) {
217 node->kind = PredCombinerKind::False;
218 node->children.clear();
233 if (node->kind == PredCombinerKind::SubstLeaves) {
237 if (node->kind == PredCombinerKind::And && node->children.empty()) {
238 node->kind = PredCombinerKind::True;
242 if (node->kind == PredCombinerKind::Or && node->children.empty()) {
243 node->kind = PredCombinerKind::False;
252 std::swap(node->children, children);
254 for (
auto &child : children) {
256 auto *simplifiedChild =
260 if (node->kind != PredCombinerKind::And &&
261 node->kind != PredCombinerKind::Or) {
262 node->children.push_back(simplifiedChild);
273 auto collapseKind = node->kind == PredCombinerKind::And
274 ? PredCombinerKind::False
275 : PredCombinerKind::True;
276 auto eraseKind = node->kind == PredCombinerKind::And
277 ? PredCombinerKind::True
278 : PredCombinerKind::False;
279 const auto &collapseList =
280 node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
281 const auto &eraseList =
282 node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
283 if (simplifiedChild->kind == collapseKind ||
284 collapseList.count(simplifiedChild->predicate) != 0) {
285 node->kind = collapseKind;
286 node->children.clear();
289 if (simplifiedChild->kind == eraseKind ||
290 eraseList.count(simplifiedChild->predicate) != 0) {
293 node->children.push_back(simplifiedChild);
301 const std::string &combiner,
303 if (children.empty())
306 auto size = children.size();
308 return children.front();
311 llvm::raw_string_ostream os(str);
312 os <<
'(' << children.front() <<
')';
313 for (
unsigned i = 1; i < size; ++i) {
314 os <<
' ' << combiner <<
" (" << children[i] <<
')';
321 assert(children.size() == 1 &&
"expected exactly one child predicate of Neg");
322 return (Twine(
"!(") + children.front() + Twine(
')')).str();
329 if (root.kind == PredCombinerKind::Leaf)
331 if (root.kind == PredCombinerKind::True)
333 if (root.kind == PredCombinerKind::False)
338 childExpressions.reserve(root.children.size());
339 for (
const auto &child : root.children)
343 if (root.kind == PredCombinerKind::And)
345 if (root.kind == PredCombinerKind::Or)
347 if (root.kind == PredCombinerKind::Not)
349 if (root.kind == PredCombinerKind::Concat) {
350 assert(childExpressions.size() == 1 &&
351 "ConcatPred should only have one child");
352 return root.prefix + childExpressions.front() + root.suffix;
356 if (root.kind == PredCombinerKind::SubstLeaves) {
357 assert(childExpressions.size() == 1 &&
358 "substitution predicate must have one child");
359 return childExpressions[0];
362 llvm::PrintFatalError(root.predicate->getLoc(),
"unsupported predicate kind");
366 SpecificBumpPtrAllocator<PredNode> allocator;
377 return def->getValueAsString(
"pattern");
381 return def->getValueAsString(
"replacement");
385 return def->getValueAsString(
"prefix");
389 return def->getValueAsString(
"suffix");
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static std::string combineBinary(ArrayRef< std::string > children, const std::string &combiner, std::string init)
static std::string getCombinedCondition(const PredNode &root)
static PredCombinerKind getPredCombinerKind(const Pred &pred)
static void performSubstitutions(std::string &str, ArrayRef< Subst > substitutions)
Perform the given substitutions on 'str' in-place.
static PredNode * buildPredicateTree(const Pred &root, SpecificBumpPtrAllocator< PredNode > &allocator, ArrayRef< Subst > substitutions)
static PredNode * propagateGroundTruth(PredNode *node, const llvm::SmallPtrSetImpl< Pred * > &knownTruePreds, const llvm::SmallPtrSetImpl< Pred * > &knownFalsePreds)
static std::string combineNot(ArrayRef< std::string > children)
std::string getConditionImpl() const
CPred(const llvm::Record *record)
CombinedPred(const llvm::Record *record)
const llvm::Record * getCombinerDef() const
std::vector< const llvm::Record * > getChildren() const
std::string getConditionImpl() const
StringRef getSuffix() const
StringRef getPrefix() const
std::string getCondition() const
ArrayRef< SMLoc > getLoc() const
StringRef getReplacement() const
StringRef getPattern() const
Include the generated interface declarations.