MLIR  19.0.0git
Predicate.cpp
Go to the documentation of this file.
1 //===- Predicate.cpp - Predicate class ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Wrapper around predicates defined in TableGen.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 #include "llvm/ADT/SmallPtrSet.h"
15 #include "llvm/ADT/StringExtras.h"
16 #include "llvm/ADT/StringSwitch.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "llvm/TableGen/Error.h"
19 #include "llvm/TableGen/Record.h"
20 
21 using namespace mlir;
22 using namespace tblgen;
23 
24 // Construct a Predicate from a record.
25 Pred::Pred(const llvm::Record *record) : def(record) {
26  assert(def->isSubClassOf("Pred") &&
27  "must be a subclass of TableGen 'Pred' class");
28 }
29 
30 // Construct a Predicate from an initializer.
31 Pred::Pred(const llvm::Init *init) {
32  if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
33  def = defInit->getDef();
34 }
35 
36 std::string Pred::getCondition() const {
37  // Static dispatch to subclasses.
38  if (def->isSubClassOf("CombinedPred"))
39  return static_cast<const CombinedPred *>(this)->getConditionImpl();
40  if (def->isSubClassOf("CPred"))
41  return static_cast<const CPred *>(this)->getConditionImpl();
42  llvm_unreachable("Pred::getCondition must be overridden in subclasses");
43 }
44 
45 bool Pred::isCombined() const {
46  return def && def->isSubClassOf("CombinedPred");
47 }
48 
49 ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); }
50 
51 CPred::CPred(const llvm::Record *record) : Pred(record) {
52  assert(def->isSubClassOf("CPred") &&
53  "must be a subclass of Tablegen 'CPred' class");
54 }
55 
56 CPred::CPred(const llvm::Init *init) : Pred(init) {
57  assert((!def || def->isSubClassOf("CPred")) &&
58  "must be a subclass of Tablegen 'CPred' class");
59 }
60 
61 // Get condition of the C Predicate.
62 std::string CPred::getConditionImpl() const {
63  assert(!isNull() && "null predicate does not have a condition");
64  return std::string(def->getValueAsString("predExpr"));
65 }
66 
67 CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
68  assert(def->isSubClassOf("CombinedPred") &&
69  "must be a subclass of Tablegen 'CombinedPred' class");
70 }
71 
72 CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
73  assert((!def || def->isSubClassOf("CombinedPred")) &&
74  "must be a subclass of Tablegen 'CombinedPred' class");
75 }
76 
77 const llvm::Record *CombinedPred::getCombinerDef() const {
78  assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
79  return def->getValueAsDef("kind");
80 }
81 
82 std::vector<llvm::Record *> CombinedPred::getChildren() const {
83  assert(def->getValue("children") &&
84  "CombinedPred must have a value 'children'");
85  return def->getValueAsListOfDefs("children");
86 }
87 
88 namespace {
89 // Kinds of nodes in a logical predicate tree.
90 enum class PredCombinerKind {
91  Leaf,
92  And,
93  Or,
94  Not,
95  SubstLeaves,
96  Concat,
97  // Special kinds that are used in simplification.
98  False,
99  True
100 };
101 
102 // A node in a logical predicate tree.
103 struct PredNode {
104  PredCombinerKind kind;
105  const Pred *predicate;
107  std::string expr;
108 
109  // Prefix and suffix are used by ConcatPred.
110  std::string prefix;
111  std::string suffix;
112 };
113 } // namespace
114 
115 // Get a predicate tree node kind based on the kind used in the predicate
116 // TableGen record.
117 static PredCombinerKind getPredCombinerKind(const Pred &pred) {
118  if (!pred.isCombined())
119  return PredCombinerKind::Leaf;
120 
121  const auto &combinedPred = static_cast<const CombinedPred &>(pred);
123  combinedPred.getCombinerDef()->getName())
124  .Case("PredCombinerAnd", PredCombinerKind::And)
125  .Case("PredCombinerOr", PredCombinerKind::Or)
126  .Case("PredCombinerNot", PredCombinerKind::Not)
127  .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
128  .Case("PredCombinerConcat", PredCombinerKind::Concat);
129 }
130 
131 namespace {
132 // Substitution<pattern, replacement>.
133 using Subst = std::pair<StringRef, StringRef>;
134 } // namespace
135 
136 /// Perform the given substitutions on 'str' in-place.
137 static void performSubstitutions(std::string &str,
138  ArrayRef<Subst> substitutions) {
139  // Apply all parent substitutions from innermost to outermost.
140  for (const auto &subst : llvm::reverse(substitutions)) {
141  auto pos = str.find(std::string(subst.first));
142  while (pos != std::string::npos) {
143  str.replace(pos, subst.first.size(), std::string(subst.second));
144  // Skip the newly inserted substring, which itself may consider the
145  // pattern to match.
146  pos += subst.second.size();
147  // Find the next possible match position.
148  pos = str.find(std::string(subst.first), pos);
149  }
150  }
151 }
152 
153 // Build the predicate tree starting from the top-level predicate, which may
154 // have children, and perform leaf substitutions inplace. Note that after
155 // substitution, nodes are still pointing to the original TableGen record.
156 // All nodes are created within "allocator".
157 static PredNode *
159  llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
160  ArrayRef<Subst> substitutions) {
161  auto *rootNode = allocator.Allocate();
162  new (rootNode) PredNode;
163  rootNode->kind = getPredCombinerKind(root);
164  rootNode->predicate = &root;
165  if (!root.isCombined()) {
166  rootNode->expr = root.getCondition();
167  performSubstitutions(rootNode->expr, substitutions);
168  return rootNode;
169  }
170 
171  // If the current combined predicate is a leaf substitution, append it to the
172  // list before continuing.
173  auto allSubstitutions = llvm::to_vector<4>(substitutions);
174  if (rootNode->kind == PredCombinerKind::SubstLeaves) {
175  const auto &substPred = static_cast<const SubstLeavesPred &>(root);
176  allSubstitutions.push_back(
177  {substPred.getPattern(), substPred.getReplacement()});
178 
179  // If the current predicate is a ConcatPred, record the prefix and suffix.
180  } else if (rootNode->kind == PredCombinerKind::Concat) {
181  const auto &concatPred = static_cast<const ConcatPred &>(root);
182  rootNode->prefix = std::string(concatPred.getPrefix());
183  performSubstitutions(rootNode->prefix, substitutions);
184  rootNode->suffix = std::string(concatPred.getSuffix());
185  performSubstitutions(rootNode->suffix, substitutions);
186  }
187 
188  // Build child subtrees.
189  auto combined = static_cast<const CombinedPred &>(root);
190  for (const auto *record : combined.getChildren()) {
191  auto *childTree =
192  buildPredicateTree(Pred(record), allocator, allSubstitutions);
193  rootNode->children.push_back(childTree);
194  }
195  return rootNode;
196 }
197 
198 // Simplify a predicate tree rooted at "node" using the predicates that are
199 // known to be true(false). For AND(OR) combined predicates, if any of the
200 // children is known to be false(true), the result is also false(true).
201 // Furthermore, for AND(OR) combined predicates, children that are known to be
202 // true(false) don't have to be checked dynamically.
203 static PredNode *
204 propagateGroundTruth(PredNode *node,
205  const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
206  const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
207  // If the current predicate is known to be true or false, change the kind of
208  // the node and return immediately.
209  if (knownTruePreds.count(node->predicate) != 0) {
210  node->kind = PredCombinerKind::True;
211  node->children.clear();
212  return node;
213  }
214  if (knownFalsePreds.count(node->predicate) != 0) {
215  node->kind = PredCombinerKind::False;
216  node->children.clear();
217  return node;
218  }
219 
220  // If the current node is a substitution, stop recursion now.
221  // The expressions in the leaves below this node were rewritten, but the nodes
222  // still point to the original predicate records. While the original
223  // predicate may be known to be true or false, it is not necessarily the case
224  // after rewriting.
225  // TODO: we can support ground truth for rewritten
226  // predicates by either (a) having our own unique'ing of the predicates
227  // instead of relying on TableGen record pointers or (b) taking ground truth
228  // values optionally prefixed with a list of substitutions to apply, e.g.
229  // "predX is true by itself as well as predSubY leaf substitution had been
230  // applied to it".
231  if (node->kind == PredCombinerKind::SubstLeaves) {
232  return node;
233  }
234 
235  // Otherwise, look at child nodes.
236 
237  // Move child nodes into some local variable so that they can be optimized
238  // separately and re-added if necessary.
240  std::swap(node->children, children);
241 
242  for (auto &child : children) {
243  // First, simplify the child. This maintains the predicate as it was.
244  auto *simplifiedChild =
245  propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
246 
247  // Just add the child if we don't know how to simplify the current node.
248  if (node->kind != PredCombinerKind::And &&
249  node->kind != PredCombinerKind::Or) {
250  node->children.push_back(simplifiedChild);
251  continue;
252  }
253 
254  // Second, based on the type define which known values of child predicates
255  // immediately collapse this predicate to a known value, and which others
256  // may be safely ignored.
257  // OR(..., True, ...) = True
258  // OR(..., False, ...) = OR(..., ...)
259  // AND(..., False, ...) = False
260  // AND(..., True, ...) = AND(..., ...)
261  auto collapseKind = node->kind == PredCombinerKind::And
262  ? PredCombinerKind::False
263  : PredCombinerKind::True;
264  auto eraseKind = node->kind == PredCombinerKind::And
265  ? PredCombinerKind::True
266  : PredCombinerKind::False;
267  const auto &collapseList =
268  node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
269  const auto &eraseList =
270  node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
271  if (simplifiedChild->kind == collapseKind ||
272  collapseList.count(simplifiedChild->predicate) != 0) {
273  node->kind = collapseKind;
274  node->children.clear();
275  return node;
276  }
277  if (simplifiedChild->kind == eraseKind ||
278  eraseList.count(simplifiedChild->predicate) != 0) {
279  continue;
280  }
281  node->children.push_back(simplifiedChild);
282  }
283  return node;
284 }
285 
286 // Combine a list of predicate expressions using a binary combiner. If a list
287 // is empty, return "init".
288 static std::string combineBinary(ArrayRef<std::string> children,
289  const std::string &combiner,
290  std::string init) {
291  if (children.empty())
292  return init;
293 
294  auto size = children.size();
295  if (size == 1)
296  return children.front();
297 
298  std::string str;
299  llvm::raw_string_ostream os(str);
300  os << '(' << children.front() << ')';
301  for (unsigned i = 1; i < size; ++i) {
302  os << ' ' << combiner << " (" << children[i] << ')';
303  }
304  return os.str();
305 }
306 
307 // Prepend negation to the only condition in the predicate expression list.
308 static std::string combineNot(ArrayRef<std::string> children) {
309  assert(children.size() == 1 && "expected exactly one child predicate of Neg");
310  return (Twine("!(") + children.front() + Twine(')')).str();
311 }
312 
313 // Recursively traverse the predicate tree in depth-first post-order and build
314 // the final expression.
315 static std::string getCombinedCondition(const PredNode &root) {
316  // Immediately return for non-combiner predicates that don't have children.
317  if (root.kind == PredCombinerKind::Leaf)
318  return root.expr;
319  if (root.kind == PredCombinerKind::True)
320  return "true";
321  if (root.kind == PredCombinerKind::False)
322  return "false";
323 
324  // Recurse into children.
325  llvm::SmallVector<std::string, 4> childExpressions;
326  childExpressions.reserve(root.children.size());
327  for (const auto &child : root.children)
328  childExpressions.push_back(getCombinedCondition(*child));
329 
330  // Combine the expressions based on the predicate node kind.
331  if (root.kind == PredCombinerKind::And)
332  return combineBinary(childExpressions, "&&", "true");
333  if (root.kind == PredCombinerKind::Or)
334  return combineBinary(childExpressions, "||", "false");
335  if (root.kind == PredCombinerKind::Not)
336  return combineNot(childExpressions);
337  if (root.kind == PredCombinerKind::Concat) {
338  assert(childExpressions.size() == 1 &&
339  "ConcatPred should only have one child");
340  return root.prefix + childExpressions.front() + root.suffix;
341  }
342 
343  // Substitutions were applied before so just ignore them.
344  if (root.kind == PredCombinerKind::SubstLeaves) {
345  assert(childExpressions.size() == 1 &&
346  "substitution predicate must have one child");
347  return childExpressions[0];
348  }
349 
350  llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
351 }
352 
353 std::string CombinedPred::getConditionImpl() const {
354  llvm::SpecificBumpPtrAllocator<PredNode> allocator;
355  auto *predicateTree = buildPredicateTree(*this, allocator, {});
356  predicateTree =
357  propagateGroundTruth(predicateTree,
358  /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
359  /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
360 
361  return getCombinedCondition(*predicateTree);
362 }
363 
364 StringRef SubstLeavesPred::getPattern() const {
365  return def->getValueAsString("pattern");
366 }
367 
369  return def->getValueAsString("replacement");
370 }
371 
372 StringRef ConcatPred::getPrefix() const {
373  return def->getValueAsString("prefix");
374 }
375 
376 StringRef ConcatPred::getSuffix() const {
377  return def->getValueAsString("suffix");
378 }
static PredNode * buildPredicateTree(const Pred &root, llvm::SpecificBumpPtrAllocator< PredNode > &allocator, ArrayRef< Subst > substitutions)
Definition: Predicate.cpp:158
static std::string combineBinary(ArrayRef< std::string > children, const std::string &combiner, std::string init)
Definition: Predicate.cpp:288
static std::string getCombinedCondition(const PredNode &root)
Definition: Predicate.cpp:315
static PredCombinerKind getPredCombinerKind(const Pred &pred)
Definition: Predicate.cpp:117
static void performSubstitutions(std::string &str, ArrayRef< Subst > substitutions)
Perform the given substitutions on 'str' in-place.
Definition: Predicate.cpp:137
static PredNode * propagateGroundTruth(PredNode *node, const llvm::SmallPtrSetImpl< Pred * > &knownTruePreds, const llvm::SmallPtrSetImpl< Pred * > &knownFalsePreds)
Definition: Predicate.cpp:204
static std::string combineNot(ArrayRef< std::string > children)
Definition: Predicate.cpp:308
std::string getConditionImpl() const
Definition: Predicate.cpp:62
CPred(const llvm::Record *record)
Definition: Predicate.cpp:51
const llvm::Record * getCombinerDef() const
Definition: Predicate.cpp:77
std::vector< llvm::Record * > getChildren() const
Definition: Predicate.cpp:82
std::string getConditionImpl() const
Definition: Predicate.cpp:353
CombinedPred(const llvm::Record *record)
Definition: Predicate.cpp:67
StringRef getSuffix() const
Definition: Predicate.cpp:376
StringRef getPrefix() const
Definition: Predicate.cpp:372
bool isNull() const
Definition: Predicate.h:46
std::string getCondition() const
Definition: Predicate.cpp:36
bool isCombined() const
Definition: Predicate.cpp:45
ArrayRef< SMLoc > getLoc() const
Definition: Predicate.cpp:49
const llvm::Record * def
Definition: Predicate.h:75
StringRef getReplacement() const
Definition: Predicate.cpp:368
StringRef getPattern() const
Definition: Predicate.cpp:364
Include the generated interface declarations.