MLIR  20.0.0git
Predicate.cpp
Go to the documentation of this file.
1 //===- Predicate.cpp - Predicate class ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Wrapper around predicates defined in TableGen.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/TableGen/Predicate.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <optional>
20 
21 using namespace mlir;
22 using namespace tblgen;
23 using llvm::Init;
24 using llvm::Record;
25 using llvm::SpecificBumpPtrAllocator;
26 
27 // Construct a Predicate from a record.
28 Pred::Pred(const Record *record) : def(record) {
29  assert(def->isSubClassOf("Pred") &&
30  "must be a subclass of TableGen 'Pred' class");
31 }
32 
33 // Construct a Predicate from an initializer.
34 Pred::Pred(const Init *init) {
35  if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
36  def = defInit->getDef();
37 }
38 
39 std::string Pred::getCondition() const {
40  // Static dispatch to subclasses.
41  if (def->isSubClassOf("CombinedPred"))
42  return static_cast<const CombinedPred *>(this)->getConditionImpl();
43  if (def->isSubClassOf("CPred"))
44  return static_cast<const CPred *>(this)->getConditionImpl();
45  llvm_unreachable("Pred::getCondition must be overridden in subclasses");
46 }
47 
48 bool Pred::isCombined() const {
49  return def && def->isSubClassOf("CombinedPred");
50 }
51 
52 ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); }
53 
54 CPred::CPred(const Record *record) : Pred(record) {
55  assert(def->isSubClassOf("CPred") &&
56  "must be a subclass of Tablegen 'CPred' class");
57 }
58 
59 CPred::CPred(const Init *init) : Pred(init) {
60  assert((!def || def->isSubClassOf("CPred")) &&
61  "must be a subclass of Tablegen 'CPred' class");
62 }
63 
64 // Get condition of the C Predicate.
65 std::string CPred::getConditionImpl() const {
66  assert(!isNull() && "null predicate does not have a condition");
67  return std::string(def->getValueAsString("predExpr"));
68 }
69 
70 CombinedPred::CombinedPred(const Record *record) : Pred(record) {
71  assert(def->isSubClassOf("CombinedPred") &&
72  "must be a subclass of Tablegen 'CombinedPred' class");
73 }
74 
75 CombinedPred::CombinedPred(const Init *init) : Pred(init) {
76  assert((!def || def->isSubClassOf("CombinedPred")) &&
77  "must be a subclass of Tablegen 'CombinedPred' class");
78 }
79 
80 const Record *CombinedPred::getCombinerDef() const {
81  assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
82  return def->getValueAsDef("kind");
83 }
84 
85 std::vector<const Record *> CombinedPred::getChildren() const {
86  assert(def->getValue("children") &&
87  "CombinedPred must have a value 'children'");
88  return def->getValueAsListOfDefs("children");
89 }
90 
91 namespace {
92 // Kinds of nodes in a logical predicate tree.
93 enum class PredCombinerKind {
94  Leaf,
95  And,
96  Or,
97  Not,
98  SubstLeaves,
99  Concat,
100  // Special kinds that are used in simplification.
101  False,
102  True
103 };
104 
105 // A node in a logical predicate tree.
106 struct PredNode {
107  PredCombinerKind kind;
108  const Pred *predicate;
110  std::string expr;
111 
112  // Prefix and suffix are used by ConcatPred.
113  std::string prefix;
114  std::string suffix;
115 };
116 } // namespace
117 
118 // Get a predicate tree node kind based on the kind used in the predicate
119 // TableGen record.
120 static PredCombinerKind getPredCombinerKind(const Pred &pred) {
121  if (!pred.isCombined())
122  return PredCombinerKind::Leaf;
123 
124  const auto &combinedPred = static_cast<const CombinedPred &>(pred);
126  combinedPred.getCombinerDef()->getName())
127  .Case("PredCombinerAnd", PredCombinerKind::And)
128  .Case("PredCombinerOr", PredCombinerKind::Or)
129  .Case("PredCombinerNot", PredCombinerKind::Not)
130  .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
131  .Case("PredCombinerConcat", PredCombinerKind::Concat);
132 }
133 
134 namespace {
135 // Substitution<pattern, replacement>.
136 using Subst = std::pair<StringRef, StringRef>;
137 } // namespace
138 
139 /// Perform the given substitutions on 'str' in-place.
140 static void performSubstitutions(std::string &str,
141  ArrayRef<Subst> substitutions) {
142  // Apply all parent substitutions from innermost to outermost.
143  for (const auto &subst : llvm::reverse(substitutions)) {
144  auto pos = str.find(std::string(subst.first));
145  while (pos != std::string::npos) {
146  str.replace(pos, subst.first.size(), std::string(subst.second));
147  // Skip the newly inserted substring, which itself may consider the
148  // pattern to match.
149  pos += subst.second.size();
150  // Find the next possible match position.
151  pos = str.find(std::string(subst.first), pos);
152  }
153  }
154 }
155 
156 // Build the predicate tree starting from the top-level predicate, which may
157 // have children, and perform leaf substitutions inplace. Note that after
158 // substitution, nodes are still pointing to the original TableGen record.
159 // All nodes are created within "allocator".
160 static PredNode *
162  SpecificBumpPtrAllocator<PredNode> &allocator,
163  ArrayRef<Subst> substitutions) {
164  auto *rootNode = allocator.Allocate();
165  new (rootNode) PredNode;
166  rootNode->kind = getPredCombinerKind(root);
167  rootNode->predicate = &root;
168  if (!root.isCombined()) {
169  rootNode->expr = root.getCondition();
170  performSubstitutions(rootNode->expr, substitutions);
171  return rootNode;
172  }
173 
174  // If the current combined predicate is a leaf substitution, append it to the
175  // list before continuing.
176  auto allSubstitutions = llvm::to_vector<4>(substitutions);
177  if (rootNode->kind == PredCombinerKind::SubstLeaves) {
178  const auto &substPred = static_cast<const SubstLeavesPred &>(root);
179  allSubstitutions.push_back(
180  {substPred.getPattern(), substPred.getReplacement()});
181 
182  // If the current predicate is a ConcatPred, record the prefix and suffix.
183  } else if (rootNode->kind == PredCombinerKind::Concat) {
184  const auto &concatPred = static_cast<const ConcatPred &>(root);
185  rootNode->prefix = std::string(concatPred.getPrefix());
186  performSubstitutions(rootNode->prefix, substitutions);
187  rootNode->suffix = std::string(concatPred.getSuffix());
188  performSubstitutions(rootNode->suffix, substitutions);
189  }
190 
191  // Build child subtrees.
192  auto combined = static_cast<const CombinedPred &>(root);
193  for (const auto *record : combined.getChildren()) {
194  auto *childTree =
195  buildPredicateTree(Pred(record), allocator, allSubstitutions);
196  rootNode->children.push_back(childTree);
197  }
198  return rootNode;
199 }
200 
201 // Simplify a predicate tree rooted at "node" using the predicates that are
202 // known to be true(false). For AND(OR) combined predicates, if any of the
203 // children is known to be false(true), the result is also false(true).
204 // Furthermore, for AND(OR) combined predicates, children that are known to be
205 // true(false) don't have to be checked dynamically.
206 static PredNode *
207 propagateGroundTruth(PredNode *node,
208  const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
209  const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
210  // If the current predicate is known to be true or false, change the kind of
211  // the node and return immediately.
212  if (knownTruePreds.count(node->predicate) != 0) {
213  node->kind = PredCombinerKind::True;
214  node->children.clear();
215  return node;
216  }
217  if (knownFalsePreds.count(node->predicate) != 0) {
218  node->kind = PredCombinerKind::False;
219  node->children.clear();
220  return node;
221  }
222 
223  // If the current node is a substitution, stop recursion now.
224  // The expressions in the leaves below this node were rewritten, but the nodes
225  // still point to the original predicate records. While the original
226  // predicate may be known to be true or false, it is not necessarily the case
227  // after rewriting.
228  // TODO: we can support ground truth for rewritten
229  // predicates by either (a) having our own unique'ing of the predicates
230  // instead of relying on TableGen record pointers or (b) taking ground truth
231  // values optionally prefixed with a list of substitutions to apply, e.g.
232  // "predX is true by itself as well as predSubY leaf substitution had been
233  // applied to it".
234  if (node->kind == PredCombinerKind::SubstLeaves) {
235  return node;
236  }
237 
238  // Otherwise, look at child nodes.
239 
240  // Move child nodes into some local variable so that they can be optimized
241  // separately and re-added if necessary.
243  std::swap(node->children, children);
244 
245  for (auto &child : children) {
246  // First, simplify the child. This maintains the predicate as it was.
247  auto *simplifiedChild =
248  propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
249 
250  // Just add the child if we don't know how to simplify the current node.
251  if (node->kind != PredCombinerKind::And &&
252  node->kind != PredCombinerKind::Or) {
253  node->children.push_back(simplifiedChild);
254  continue;
255  }
256 
257  // Second, based on the type define which known values of child predicates
258  // immediately collapse this predicate to a known value, and which others
259  // may be safely ignored.
260  // OR(..., True, ...) = True
261  // OR(..., False, ...) = OR(..., ...)
262  // AND(..., False, ...) = False
263  // AND(..., True, ...) = AND(..., ...)
264  auto collapseKind = node->kind == PredCombinerKind::And
265  ? PredCombinerKind::False
266  : PredCombinerKind::True;
267  auto eraseKind = node->kind == PredCombinerKind::And
268  ? PredCombinerKind::True
269  : PredCombinerKind::False;
270  const auto &collapseList =
271  node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
272  const auto &eraseList =
273  node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
274  if (simplifiedChild->kind == collapseKind ||
275  collapseList.count(simplifiedChild->predicate) != 0) {
276  node->kind = collapseKind;
277  node->children.clear();
278  return node;
279  }
280  if (simplifiedChild->kind == eraseKind ||
281  eraseList.count(simplifiedChild->predicate) != 0) {
282  continue;
283  }
284  node->children.push_back(simplifiedChild);
285  }
286  return node;
287 }
288 
289 // Combine a list of predicate expressions using a binary combiner. If a list
290 // is empty, return "init".
291 static std::string combineBinary(ArrayRef<std::string> children,
292  const std::string &combiner,
293  std::string init) {
294  if (children.empty())
295  return init;
296 
297  auto size = children.size();
298  if (size == 1)
299  return children.front();
300 
301  std::string str;
302  llvm::raw_string_ostream os(str);
303  os << '(' << children.front() << ')';
304  for (unsigned i = 1; i < size; ++i) {
305  os << ' ' << combiner << " (" << children[i] << ')';
306  }
307  return str;
308 }
309 
310 // Prepend negation to the only condition in the predicate expression list.
311 static std::string combineNot(ArrayRef<std::string> children) {
312  assert(children.size() == 1 && "expected exactly one child predicate of Neg");
313  return (Twine("!(") + children.front() + Twine(')')).str();
314 }
315 
316 // Recursively traverse the predicate tree in depth-first post-order and build
317 // the final expression.
318 static std::string getCombinedCondition(const PredNode &root) {
319  // Immediately return for non-combiner predicates that don't have children.
320  if (root.kind == PredCombinerKind::Leaf)
321  return root.expr;
322  if (root.kind == PredCombinerKind::True)
323  return "true";
324  if (root.kind == PredCombinerKind::False)
325  return "false";
326 
327  // Recurse into children.
328  llvm::SmallVector<std::string, 4> childExpressions;
329  childExpressions.reserve(root.children.size());
330  for (const auto &child : root.children)
331  childExpressions.push_back(getCombinedCondition(*child));
332 
333  // Combine the expressions based on the predicate node kind.
334  if (root.kind == PredCombinerKind::And)
335  return combineBinary(childExpressions, "&&", "true");
336  if (root.kind == PredCombinerKind::Or)
337  return combineBinary(childExpressions, "||", "false");
338  if (root.kind == PredCombinerKind::Not)
339  return combineNot(childExpressions);
340  if (root.kind == PredCombinerKind::Concat) {
341  assert(childExpressions.size() == 1 &&
342  "ConcatPred should only have one child");
343  return root.prefix + childExpressions.front() + root.suffix;
344  }
345 
346  // Substitutions were applied before so just ignore them.
347  if (root.kind == PredCombinerKind::SubstLeaves) {
348  assert(childExpressions.size() == 1 &&
349  "substitution predicate must have one child");
350  return childExpressions[0];
351  }
352 
353  llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
354 }
355 
356 std::string CombinedPred::getConditionImpl() const {
357  SpecificBumpPtrAllocator<PredNode> allocator;
358  auto *predicateTree = buildPredicateTree(*this, allocator, {});
359  predicateTree =
360  propagateGroundTruth(predicateTree,
361  /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
362  /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
363 
364  return getCombinedCondition(*predicateTree);
365 }
366 
367 StringRef SubstLeavesPred::getPattern() const {
368  return def->getValueAsString("pattern");
369 }
370 
372  return def->getValueAsString("replacement");
373 }
374 
375 StringRef ConcatPred::getPrefix() const {
376  return def->getValueAsString("prefix");
377 }
378 
379 StringRef ConcatPred::getSuffix() const {
380  return def->getValueAsString("suffix");
381 }
static std::string combineBinary(ArrayRef< std::string > children, const std::string &combiner, std::string init)
Definition: Predicate.cpp:291
static std::string getCombinedCondition(const PredNode &root)
Definition: Predicate.cpp:318
static PredCombinerKind getPredCombinerKind(const Pred &pred)
Definition: Predicate.cpp:120
static void performSubstitutions(std::string &str, ArrayRef< Subst > substitutions)
Perform the given substitutions on 'str' in-place.
Definition: Predicate.cpp:140
static PredNode * buildPredicateTree(const Pred &root, SpecificBumpPtrAllocator< PredNode > &allocator, ArrayRef< Subst > substitutions)
Definition: Predicate.cpp:161
static PredNode * propagateGroundTruth(PredNode *node, const llvm::SmallPtrSetImpl< Pred * > &knownTruePreds, const llvm::SmallPtrSetImpl< Pred * > &knownFalsePreds)
Definition: Predicate.cpp:207
static std::string combineNot(ArrayRef< std::string > children)
Definition: Predicate.cpp:311
std::string getConditionImpl() const
Definition: Predicate.cpp:65
CPred(const llvm::Record *record)
CombinedPred(const llvm::Record *record)
const llvm::Record * getCombinerDef() const
Definition: Predicate.cpp:80
std::vector< const llvm::Record * > getChildren() const
Definition: Predicate.cpp:85
std::string getConditionImpl() const
Definition: Predicate.cpp:356
StringRef getSuffix() const
Definition: Predicate.cpp:379
StringRef getPrefix() const
Definition: Predicate.cpp:375
bool isNull() const
Definition: Predicate.h:46
std::string getCondition() const
Definition: Predicate.cpp:39
bool isCombined() const
Definition: Predicate.cpp:48
ArrayRef< SMLoc > getLoc() const
Definition: Predicate.cpp:52
const llvm::Record * def
Definition: Predicate.h:75
StringRef getReplacement() const
Definition: Predicate.cpp:371
StringRef getPattern() const
Definition: Predicate.cpp:367
Include the generated interface declarations.